Stop monitoring sockets until all sockets are unrequested

Now MdnsSocketProvider is stopped when there is no client request
left in NsdService, but this does not trigger
SocketCallback.onInterfaceDestroyed callbacks. If the network of
the socket is then lost while MdnsSocketProvider is not
monitoring, no callback will be fired. Users of the socket
(MdnsDiscoveryManager and MdnsAdvertiser) may keep using it
without ever getting notified. So ignore the stop and wait until
all sockets are unrequested. Then the socket destroy should be
notified to all users.

Bug: 267978487
Test: atest FrameworksNetTests
Change-Id: I7a8bb0550262fe397b91f1236a8dbca1cf2c7518
This commit is contained in:
Paul Hu
2023-02-18 11:41:07 +08:00
parent 22f17371b3
commit 58f2060614
4 changed files with 90 additions and 16 deletions

View File

@@ -317,7 +317,7 @@ public class NsdService extends INsdManager.Stub {
if (!mIsMonitoringSocketsStarted) return; if (!mIsMonitoringSocketsStarted) return;
if (isAnyRequestActive()) return; if (isAnyRequestActive()) return;
mMdnsSocketProvider.stopMonitoringSockets(); mMdnsSocketProvider.requestStopWhenInactive();
mIsMonitoringSocketsStarted = false; mIsMonitoringSocketsStarted = false;
} }

View File

@@ -82,6 +82,7 @@ public class MdnsSocketProvider {
private final List<String> mTetheredInterfaces = new ArrayList<>(); private final List<String> mTetheredInterfaces = new ArrayList<>();
private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE]; private final byte[] mPacketReadBuffer = new byte[READ_BUFFER_SIZE];
private boolean mMonitoringSockets = false; private boolean mMonitoringSockets = false;
private boolean mRequestStop = false;
public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) { public MdnsSocketProvider(@NonNull Context context, @NonNull Looper looper) {
this(context, looper, new Dependencies()); this(context, looper, new Dependencies());
@@ -179,6 +180,7 @@ public class MdnsSocketProvider {
/*** Start monitoring sockets by listening callbacks for sockets creation or removal */ /*** Start monitoring sockets by listening callbacks for sockets creation or removal */
public void startMonitoringSockets() { public void startMonitoringSockets() {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
mRequestStop = false; // Reset stop request flag.
if (mMonitoringSockets) { if (mMonitoringSockets) {
Log.d(TAG, "Already monitoring sockets."); Log.d(TAG, "Already monitoring sockets.");
return; return;
@@ -195,22 +197,34 @@ public class MdnsSocketProvider {
mMonitoringSockets = true; mMonitoringSockets = true;
} }
/*** Stop monitoring sockets and unregister callbacks */ private void maybeStopMonitoringSockets() {
public void stopMonitoringSockets() { if (!mMonitoringSockets) return; // Already unregistered.
if (!mRequestStop) return; // No stop request.
// Only unregister the network callback if there is no socket request.
if (mCallbacksToRequestedNetworks.isEmpty()) {
mContext.getSystemService(ConnectivityManager.class)
.unregisterNetworkCallback(mNetworkCallback);
final TetheringManager tetheringManager = mContext.getSystemService(
TetheringManager.class);
tetheringManager.unregisterTetheringEventCallback(mTetheringEventCallback);
mHandler.post(mNetlinkMonitor::stop);
mMonitoringSockets = false;
}
}
/*** Request to stop monitoring sockets and unregister callbacks */
public void requestStopWhenInactive() {
ensureRunningOnHandlerThread(mHandler); ensureRunningOnHandlerThread(mHandler);
if (!mMonitoringSockets) { if (!mMonitoringSockets) {
Log.d(TAG, "Monitoring sockets hasn't been started."); Log.d(TAG, "Monitoring sockets hasn't been started.");
return; return;
} }
if (DBG) Log.d(TAG, "Stop monitoring sockets."); if (DBG) Log.d(TAG, "Try to stop monitoring sockets.");
mContext.getSystemService(ConnectivityManager.class) mRequestStop = true;
.unregisterNetworkCallback(mNetworkCallback); maybeStopMonitoringSockets();
final TetheringManager tetheringManager = mContext.getSystemService(TetheringManager.class);
tetheringManager.unregisterTetheringEventCallback(mTetheringEventCallback);
mHandler.post(mNetlinkMonitor::stop);
mMonitoringSockets = false;
} }
/*** Check whether the target network is matched current network */ /*** Check whether the target network is matched current network */
@@ -450,6 +464,9 @@ public class MdnsSocketProvider {
cb.onInterfaceDestroyed(new Network(INetd.LOCAL_NET_ID), info.mSocket); cb.onInterfaceDestroyed(new Network(INetd.LOCAL_NET_ID), info.mSocket);
} }
mTetherInterfaceSockets.clear(); mTetherInterfaceSockets.clear();
// Try to unregister network callback.
maybeStopMonitoringSockets();
} }
/*** Callbacks for listening socket changes */ /*** Callbacks for listening socket changes */

View File

@@ -932,7 +932,7 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
verify(mDiscoveryManager).unregisterListener(eq(serviceTypeWithLocalDomain), any()); verify(mDiscoveryManager).unregisterListener(eq(serviceTypeWithLocalDomain), any());
verify(discListener, timeout(TIMEOUT_MS)).onDiscoveryStopped(SERVICE_TYPE); verify(discListener, timeout(TIMEOUT_MS)).onDiscoveryStopped(SERVICE_TYPE);
verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets(); verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).requestStopWhenInactive();
} }
@Test @Test
@@ -1016,7 +1016,7 @@ public class NsdServiceTest {
// Verify the listener has been unregistered. // Verify the listener has been unregistered.
verify(mDiscoveryManager, timeout(TIMEOUT_MS)) verify(mDiscoveryManager, timeout(TIMEOUT_MS))
.unregisterListener(eq(constructedServiceType), any()); .unregisterListener(eq(constructedServiceType), any());
verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).stopMonitoringSockets(); verify(mSocketProvider, timeout(CLEANUP_DELAY_MS + TIMEOUT_MS)).requestStopWhenInactive();
} }
@Test @Test
@@ -1090,7 +1090,7 @@ public class NsdServiceTest {
verify(mAdvertiser).removeService(idCaptor.getValue()); verify(mAdvertiser).removeService(idCaptor.getValue());
verify(regListener, timeout(TIMEOUT_MS)).onServiceUnregistered( verify(regListener, timeout(TIMEOUT_MS)).onServiceUnregistered(
argThat(info -> matches(info, regInfo))); argThat(info -> matches(info, regInfo)));
verify(mSocketProvider, timeout(TIMEOUT_MS)).stopMonitoringSockets(); verify(mSocketProvider, timeout(TIMEOUT_MS)).requestStopWhenInactive();
} }
@Test @Test

View File

@@ -27,6 +27,7 @@ import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -109,12 +110,15 @@ public class MdnsSocketProviderTest {
final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest"); final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
thread.start(); thread.start();
mHandler = new Handler(thread.getLooper()); mHandler = new Handler(thread.getLooper());
mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps);
}
private void startMonitoringSockets() {
final ArgumentCaptor<NetworkCallback> nwCallbackCaptor = final ArgumentCaptor<NetworkCallback> nwCallbackCaptor =
ArgumentCaptor.forClass(NetworkCallback.class); ArgumentCaptor.forClass(NetworkCallback.class);
final ArgumentCaptor<TetheringEventCallback> teCallbackCaptor = final ArgumentCaptor<TetheringEventCallback> teCallbackCaptor =
ArgumentCaptor.forClass(TetheringEventCallback.class); ArgumentCaptor.forClass(TetheringEventCallback.class);
mSocketProvider = new MdnsSocketProvider(mContext, thread.getLooper(), mDeps);
mHandler.post(mSocketProvider::startMonitoringSockets); mHandler.post(mSocketProvider::startMonitoringSockets);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm).registerNetworkCallback(any(), nwCallbackCaptor.capture(), any()); verify(mCm).registerNetworkCallback(any(), nwCallbackCaptor.capture(), any());
@@ -205,6 +209,8 @@ public class MdnsSocketProviderTest {
@Test @Test
public void testSocketRequestAndUnrequestSocket() { public void testSocketRequestAndUnrequestSocket() {
startMonitoringSockets();
final TestSocketCallback testCallback1 = new TestSocketCallback(); final TestSocketCallback testCallback1 = new TestSocketCallback();
mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback1)); mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback1));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
@@ -275,6 +281,8 @@ public class MdnsSocketProviderTest {
@Test @Test
public void testAddressesChanged() throws Exception { public void testAddressesChanged() throws Exception {
startMonitoringSockets();
final TestSocketCallback testCallback = new TestSocketCallback(); final TestSocketCallback testCallback = new TestSocketCallback();
mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback)); mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
@@ -297,4 +305,53 @@ public class MdnsSocketProviderTest {
testCallback.expectedAddressesChangedForNetwork( testCallback.expectedAddressesChangedForNetwork(
TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6)); TEST_NETWORK, List.of(LINKADDRV4, LINKADDRV6));
} }
@Test
public void testStartAndStopMonitoringSockets() {
// Stop monitoring sockets before start. Should not unregister any network callback.
mHandler.post(mSocketProvider::requestStopWhenInactive);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, never()).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, never()).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
// Start sockets monitoring.
startMonitoringSockets();
// Request a socket then unrequest it. Expect no network callback unregistration.
final TestSocketCallback testCallback = new TestSocketCallback();
mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
testCallback.expectedNoCallback();
mHandler.post(()-> mSocketProvider.unrequestSocket(testCallback));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, never()).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, never()).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
// Request stop and it should unregister network callback immediately because there is no
// socket request.
mHandler.post(mSocketProvider::requestStopWhenInactive);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, times(1)).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
// Start sockets monitoring and request a socket again.
mHandler.post(mSocketProvider::startMonitoringSockets);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, times(2)).registerNetworkCallback(any(), any(NetworkCallback.class), any());
verify(mTm, times(2)).registerTetheringEventCallback(
any(), any(TetheringEventCallback.class));
final TestSocketCallback testCallback2 = new TestSocketCallback();
mHandler.post(() -> mSocketProvider.requestSocket(TEST_NETWORK, testCallback2));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
testCallback2.expectedNoCallback();
// Try to stop monitoring sockets but should be ignored and wait until all socket are
// unrequested.
mHandler.post(mSocketProvider::requestStopWhenInactive);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, times(1)).unregisterTetheringEventCallback(any());
// Unrequest the socket then network callbacks should be unregistered.
mHandler.post(()-> mSocketProvider.unrequestSocket(testCallback2));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mCm, times(2)).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, times(2)).unregisterTetheringEventCallback(any(TetheringEventCallback.class));
}
} }