diff --git a/core/java/android/net/NetworkUtils.java b/core/java/android/net/NetworkUtils.java index 0b92b95128..1e004e4942 100644 --- a/core/java/android/net/NetworkUtils.java +++ b/core/java/android/net/NetworkUtils.java @@ -476,4 +476,35 @@ public class NetworkUtils { return true; } + + /** + * Safely multiple a value by a rational. + *
+ * Internally it uses integer-based math whenever possible, but switches + * over to double-based math if values would overflow. + * @hide + */ + public static long multiplySafeByRational(long value, long num, long den) { + if (den == 0) { + throw new ArithmeticException("Invalid Denominator"); + } + long x = value; + long y = num; + + // Logic shamelessly borrowed from Math.multiplyExact() + long r = x * y; + long ax = Math.abs(x); + long ay = Math.abs(y); + if (((ax | ay) >>> 31 != 0)) { + // Some bits greater than 2^31 that might cause overflow + // Check the result using the divide operator + // and check for the special case of Long.MIN_VALUE * -1 + if (((y != 0) && (r / y != x)) || + (x == Long.MIN_VALUE && y == -1)) { + // Use double math to avoid overflowing + return (long) (((double) num / den) * value); + } + } + return r / den; + } } diff --git a/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java b/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java index 551498f2c0..e83d2a90bf 100644 --- a/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java +++ b/tests/net/java/com/android/server/net/NetworkStatsCollectionTest.java @@ -23,11 +23,12 @@ import static android.net.NetworkStats.TAG_NONE; import static android.net.NetworkStats.UID_ALL; import static android.net.NetworkStatsHistory.FIELD_ALL; import static android.net.NetworkTemplate.buildTemplateMobileAll; +import static android.net.NetworkUtils.multiplySafeByRational; import static android.os.Process.myUid; import static android.text.format.DateUtils.HOUR_IN_MILLIS; import static android.text.format.DateUtils.MINUTE_IN_MILLIS; -import static com.android.server.net.NetworkStatsCollection.multiplySafe; +import static com.android.testutils.MiscAssertsKt.assertThrows; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -505,23 +506,25 @@ public class NetworkStatsCollectionTest { } @Test - public void testMultiplySafe() { - assertEquals(25, multiplySafe(50, 1, 2)); - assertEquals(100, multiplySafe(50, 2, 1)); + public void testMultiplySafeRational() { + assertEquals(25, multiplySafeByRational(50, 1, 2)); + assertEquals(100, multiplySafeByRational(50, 2, 1)); - assertEquals(-10, multiplySafe(30, -1, 3)); - assertEquals(0, multiplySafe(30, 0, 3)); - assertEquals(10, multiplySafe(30, 1, 3)); - assertEquals(20, multiplySafe(30, 2, 3)); - assertEquals(30, multiplySafe(30, 3, 3)); - assertEquals(40, multiplySafe(30, 4, 3)); + assertEquals(-10, multiplySafeByRational(30, -1, 3)); + assertEquals(0, multiplySafeByRational(30, 0, 3)); + assertEquals(10, multiplySafeByRational(30, 1, 3)); + assertEquals(20, multiplySafeByRational(30, 2, 3)); + assertEquals(30, multiplySafeByRational(30, 3, 3)); + assertEquals(40, multiplySafeByRational(30, 4, 3)); assertEquals(100_000_000_000L, - multiplySafe(300_000_000_000L, 10_000_000_000L, 30_000_000_000L)); + multiplySafeByRational(300_000_000_000L, 10_000_000_000L, 30_000_000_000L)); assertEquals(100_000_000_010L, - multiplySafe(300_000_000_000L, 10_000_000_001L, 30_000_000_000L)); + multiplySafeByRational(300_000_000_000L, 10_000_000_001L, 30_000_000_000L)); assertEquals(823_202_048L, - multiplySafe(4_939_212_288L, 2_121_815_528L, 12_730_893_165L)); + multiplySafeByRational(4_939_212_288L, 2_121_815_528L, 12_730_893_165L)); + + assertThrows(ArithmeticException.class, () -> multiplySafeByRational(30, 3, 0)); } /**