diff --git a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java index 9c0d33fe5a..1307a849f1 100644 --- a/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java +++ b/tests/net/java/com/android/server/net/NetworkStatsServiceTest.java @@ -245,7 +245,7 @@ public class NetworkStatsServiceTest extends NetworkStatsBaseTest { @Override public NetworkStatsSubscriptionsMonitor makeSubscriptionsMonitor( - @NonNull Context context, @NonNull Executor executor, + @NonNull Context context, @NonNull Looper looper, @NonNull Executor executor, @NonNull NetworkStatsService service) { return mNetworkStatsSubscriptionsMonitor; diff --git a/tests/net/java/com/android/server/net/NetworkStatsSubscriptionsMonitorTest.java b/tests/net/java/com/android/server/net/NetworkStatsSubscriptionsMonitorTest.java index 058856dcd6..16fed39bcc 100644 --- a/tests/net/java/com/android/server/net/NetworkStatsSubscriptionsMonitorTest.java +++ b/tests/net/java/com/android/server/net/NetworkStatsSubscriptionsMonitorTest.java @@ -17,6 +17,7 @@ package com.android.server.net; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.eq; @@ -29,14 +30,13 @@ import static org.mockito.Mockito.when; import android.annotation.NonNull; import android.content.Context; -import android.os.Looper; +import android.os.test.TestLooper; import android.telephony.PhoneStateListener; import android.telephony.ServiceState; import android.telephony.SubscriptionManager; import android.telephony.TelephonyManager; import com.android.internal.util.CollectionUtils; -import com.android.server.net.NetworkStatsSubscriptionsMonitor.Delegate; import com.android.server.net.NetworkStatsSubscriptionsMonitor.RatTypeListener; import org.junit.Before; @@ -64,20 +64,17 @@ public final class NetworkStatsSubscriptionsMonitorTest { @Mock private PhoneStateListener mPhoneStateListener; @Mock private SubscriptionManager mSubscriptionManager; @Mock private TelephonyManager mTelephonyManager; - @Mock private Delegate mDelegate; + @Mock private NetworkStatsSubscriptionsMonitor.Delegate mDelegate; private final List mTestSubList = new ArrayList<>(); private final Executor mExecutor = Executors.newSingleThreadExecutor(); private NetworkStatsSubscriptionsMonitor mMonitor; + private TestLooper mTestLooper = new TestLooper(); @Before public void setUp() { MockitoAnnotations.initMocks(this); - if (Looper.myLooper() == null) { - Looper.prepare(); - } - when(mTelephonyManager.createForSubscriptionId(anyInt())).thenReturn(mTelephonyManager); when(mContext.getSystemService(eq(Context.TELEPHONY_SUBSCRIPTION_SERVICE))) @@ -85,7 +82,8 @@ public final class NetworkStatsSubscriptionsMonitorTest { when(mContext.getSystemService(eq(Context.TELEPHONY_SERVICE))) .thenReturn(mTelephonyManager); - mMonitor = new NetworkStatsSubscriptionsMonitor(mContext, mExecutor, mDelegate); + mMonitor = new NetworkStatsSubscriptionsMonitor(mContext, mTestLooper.getLooper(), + mExecutor, mDelegate); } @Test @@ -117,16 +115,18 @@ public final class NetworkStatsSubscriptionsMonitorTest { when(serviceState.getDataNetworkType()).thenReturn(type); final RatTypeListener match = CollectionUtils .find(listeners, it -> it.getSubId() == subId); - if (match != null) { - match.onServiceStateChanged(serviceState); + if (match == null) { + fail("Could not find listener with subId: " + subId); } + match.onServiceStateChanged(serviceState); } private void addTestSub(int subId, String subscriberId) { // add SubId to TestSubList. - if (!mTestSubList.contains(subId)) { - mTestSubList.add(subId); - } + if (mTestSubList.contains(subId)) fail("The subscriber list already contains this ID"); + + mTestSubList.add(subId); + final int[] subList = convertArrayListToIntArray(mTestSubList); when(mSubscriptionManager.getCompleteActiveSubscriptionIdList()).thenReturn(subList); when(mTelephonyManager.getSubscriberId(subId)).thenReturn(subscriberId);