Do not send socket destroyed on unregistration

When a SocketCallback is unregistered from MdnsSocketProvider, do not
send socket destroyed callbacks. Callers may not expect getting
callbacks after unregistration, and the current callbacks are also
broken when an unrequested socket is still in use by another requester.

MdnsAdvertiser already does not depend on getting this callback, as it
only unregisters the SocketCallback after it is done using the socket.
This change fixes MdnsMultinetworkSocketClient to destroy the socket by
itself when unrequesting.

Bug: 276177548
Test: atest
(cherry picked from https://android-review.googlesource.com/q/commit:5fe9bacc63c1b6a77878f23d5f53a07fc482f354)
Merged-In: If95f833e293f3aab91128aab1c9852ebfd41995d
Change-Id: If95f833e293f3aab91128aab1c9852ebfd41995d
This commit is contained in:
Remi NGUYEN VAN
2023-05-15 11:15:18 +09:00
committed by Cherrypicker Worker
parent 60437e59de
commit 6721aa3570
4 changed files with 141 additions and 17 deletions

View File

@@ -34,7 +34,6 @@ import java.net.Inet4Address;
import java.net.Inet6Address; import java.net.Inet6Address;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* The {@link MdnsMultinetworkSocketClient} manages the multinetwork socket for mDns * The {@link MdnsMultinetworkSocketClient} manages the multinetwork socket for mDns
@@ -48,9 +47,8 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
@NonNull private final Handler mHandler; @NonNull private final Handler mHandler;
@NonNull private final MdnsSocketProvider mSocketProvider; @NonNull private final MdnsSocketProvider mSocketProvider;
private final Map<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks = private final ArrayMap<MdnsServiceBrowserListener, InterfaceSocketCallback> mRequestedNetworks =
new ArrayMap<>(); new ArrayMap<>();
private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets = new ArrayMap<>();
private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers = private final ArrayMap<MdnsInterfaceSocket, ReadPacketHandler> mSocketPacketHandlers =
new ArrayMap<>(); new ArrayMap<>();
private MdnsSocketClientBase.Callback mCallback = null; private MdnsSocketClientBase.Callback mCallback = null;
@@ -63,7 +61,11 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
} }
private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback { private class InterfaceSocketCallback implements MdnsSocketProvider.SocketCallback {
@NonNull
private final SocketCreationCallback mSocketCreationCallback; private final SocketCreationCallback mSocketCreationCallback;
@NonNull
private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
new ArrayMap<>();
InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) { InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
mSocketCreationCallback = socketCreationCallback; mSocketCreationCallback = socketCreationCallback;
@@ -88,10 +90,47 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
@Override @Override
public void onInterfaceDestroyed(@Nullable Network network, public void onInterfaceDestroyed(@Nullable Network network,
@NonNull MdnsInterfaceSocket socket) { @NonNull MdnsInterfaceSocket socket) {
mSocketPacketHandlers.remove(socket); notifySocketDestroyed(socket);
mActiveNetworkSockets.remove(socket); maybeCleanupPacketHandler(socket);
}
private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
final Network network = mActiveNetworkSockets.remove(socket);
mSocketCreationCallback.onSocketDestroyed(network); mSocketCreationCallback.onSocketDestroyed(network);
} }
void onNetworkUnrequested() {
for (int i = mActiveNetworkSockets.size() - 1; i >= 0; i--) {
// Iterate from the end so the socket can be removed
final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i);
notifySocketDestroyed(socket);
maybeCleanupPacketHandler(socket);
}
}
}
private boolean isSocketActive(@NonNull MdnsInterfaceSocket socket) {
for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
if (isc.mActiveNetworkSockets.containsKey(socket)) {
return true;
}
}
return false;
}
private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
sockets.putAll(isc.mActiveNetworkSockets);
}
return sockets;
}
private void maybeCleanupPacketHandler(@NonNull MdnsInterfaceSocket socket) {
if (isSocketActive(socket)) return;
mSocketPacketHandlers.remove(socket);
} }
private class ReadPacketHandler implements MulticastPacketReader.PacketHandler { private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
@@ -149,6 +188,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
return; return;
} }
mSocketProvider.unrequestSocket(callback); mSocketProvider.unrequestSocket(callback);
callback.onNetworkUnrequested();
} }
private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) { private void sendMdnsPacket(@NonNull DatagramPacket packet, @Nullable Network targetNetwork) {
@@ -156,9 +196,10 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
instanceof Inet6Address; instanceof Inet6Address;
final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress() final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
instanceof Inet4Address; instanceof Inet4Address;
for (int i = 0; i < mActiveNetworkSockets.size(); i++) { final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
final MdnsInterfaceSocket socket = mActiveNetworkSockets.keyAt(i); for (int i = 0; i < activeSockets.size(); i++) {
final Network network = mActiveNetworkSockets.valueAt(i); final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
final Network network = activeSockets.valueAt(i);
// Check ip capability and network before sending packet // Check ip capability and network before sending packet
if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4())) if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
&& isNetworkMatched(targetNetwork, network)) { && isNetworkMatched(targetNetwork, network)) {

View File

@@ -599,8 +599,6 @@ public class MdnsSocketProvider {
if (matchRequestedNetwork(network)) continue; if (matchRequestedNetwork(network)) continue;
final SocketInfo info = mNetworkSockets.removeAt(i); final SocketInfo info = mNetworkSockets.removeAt(i);
info.mSocket.destroy(); info.mSocket.destroy();
// Still notify to unrequester for socket destroy.
cb.onInterfaceDestroyed(network, info.mSocket);
mSharedLog.log("Remove socket on net:" + network + " after unrequestSocket"); mSharedLog.log("Remove socket on net:" + network + " after unrequestSocket");
} }
@@ -610,8 +608,6 @@ public class MdnsSocketProvider {
for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) { for (int i = mTetherInterfaceSockets.size() - 1; i >= 0; i--) {
final SocketInfo info = mTetherInterfaceSockets.valueAt(i); final SocketInfo info = mTetherInterfaceSockets.valueAt(i);
info.mSocket.destroy(); info.mSocket.destroy();
// Still notify to unrequester for socket destroy.
cb.onInterfaceDestroyed(null /* network */, info.mSocket);
mSharedLog.log("Remove socket on ifName:" + mTetherInterfaceSockets.keyAt(i) mSharedLog.log("Remove socket on ifName:" + mTetherInterfaceSockets.keyAt(i)
+ " after unrequestSocket"); + " after unrequestSocket");
} }

View File

@@ -24,7 +24,9 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import android.net.InetAddresses; import android.net.InetAddresses;
@@ -77,12 +79,17 @@ public class MdnsMultinetworkSocketClientTest {
} }
private SocketCallback expectSocketCallback() { private SocketCallback expectSocketCallback() {
return expectSocketCallback(mListener, mNetwork);
}
private SocketCallback expectSocketCallback(MdnsServiceBrowserListener listener,
Network requestedNetwork) {
final ArgumentCaptor<SocketCallback> callbackCaptor = final ArgumentCaptor<SocketCallback> callbackCaptor =
ArgumentCaptor.forClass(SocketCallback.class); ArgumentCaptor.forClass(SocketCallback.class);
mHandler.post(() -> mSocketClient.notifyNetworkRequested( mHandler.post(() -> mSocketClient.notifyNetworkRequested(
mListener, mNetwork, mSocketCreationCallback)); listener, requestedNetwork, mSocketCreationCallback));
verify(mProvider, timeout(DEFAULT_TIMEOUT)) verify(mProvider, timeout(DEFAULT_TIMEOUT))
.requestSocket(eq(mNetwork), callbackCaptor.capture()); .requestSocket(eq(requestedNetwork), callbackCaptor.capture());
return callbackCaptor.getValue(); return callbackCaptor.getValue();
} }
@@ -169,4 +176,83 @@ public class MdnsMultinetworkSocketClientTest {
new String[] { "Android", "local" } /* serviceHost */) new String[] { "Android", "local" } /* serviceHost */)
), response.answers); ), response.answers);
} }
@Test
public void testSocketRemovedAfterNetworkUnrequested() throws IOException {
// Request a socket
final SocketCallback callback = expectSocketCallback(mListener, mNetwork);
final DatagramPacket ipv4Packet = new DatagramPacket(BUFFER, 0 /* offset */, BUFFER.length,
InetAddresses.parseNumericAddress("192.0.2.1"), 0 /* port */);
doReturn(true).when(mSocket).hasJoinedIpv4();
doReturn(true).when(mSocket).hasJoinedIpv6();
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
// Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
// Send IPv4 packet and verify sending has been called.
mSocketClient.sendMulticastPacket(ipv4Packet);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket).send(ipv4Packet);
// Request another socket with null network
final MdnsServiceBrowserListener listener2 = mock(MdnsServiceBrowserListener.class);
final Network network2 = mock(Network.class);
final MdnsInterfaceSocket socket2 = mock(MdnsInterfaceSocket.class);
final SocketCallback callback2 = expectSocketCallback(listener2, null);
doReturn(true).when(socket2).hasJoinedIpv4();
doReturn(true).when(socket2).hasJoinedIpv6();
doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
// Notify socket created for two networks.
callback2.onSocketCreated(mNetwork, mSocket, List.of());
callback2.onSocketCreated(network2, socket2, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
verify(mSocketCreationCallback).onSocketCreated(network2);
// Send IPv4 packet and verify sending to two sockets.
mSocketClient.sendMulticastPacket(ipv4Packet);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, times(2)).send(ipv4Packet);
verify(socket2).send(ipv4Packet);
// Unrequest another socket
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(listener2));
verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback2);
// Send IPv4 packet again and verify only sending via mSocket
mSocketClient.sendMulticastPacket(ipv4Packet);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, times(3)).send(ipv4Packet);
verify(socket2).send(ipv4Packet);
// Unrequest remaining socket
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
verify(mProvider, timeout(DEFAULT_TIMEOUT)).unrequestSocket(callback);
// Send IPv4 packet and verify no more sending.
mSocketClient.sendMulticastPacket(ipv4Packet);
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mSocket, times(3)).send(ipv4Packet);
verify(socket2).send(ipv4Packet);
}
@Test
public void testNotifyNetworkUnrequested_SocketsOnNullNetwork() {
final MdnsInterfaceSocket otherSocket = mock(MdnsInterfaceSocket.class);
final SocketCallback callback = expectSocketCallback(
mListener, null /* requestedNetwork */);
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null /* network */, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mProvider).unrequestSocket(callback);
verify(mSocketCreationCallback, times(2)).onSocketDestroyed(null /* network */);
}
} }

View File

@@ -349,8 +349,8 @@ public class MdnsSocketProviderTest {
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
testCallback1.expectedNoCallback(); testCallback1.expectedNoCallback();
testCallback2.expectedNoCallback(); testCallback2.expectedNoCallback();
// Expect the socket destroy for tethered interface. // There was still a tethered interface, but no callback should be sent once unregistered
testCallback3.expectedInterfaceDestroyedForNetwork(null /* network */); testCallback3.expectedNoCallback();
} }
private RtNetlinkAddressMessage createNetworkAddressUpdateNetLink( private RtNetlinkAddressMessage createNetworkAddressUpdateNetLink(
@@ -528,7 +528,8 @@ public class MdnsSocketProviderTest {
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
mHandler.post(()-> mSocketProvider.unrequestSocket(testCallback)); mHandler.post(()-> mSocketProvider.unrequestSocket(testCallback));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
testCallback.expectedInterfaceDestroyedForNetwork(TEST_NETWORK); // No callback sent when unregistered
testCallback.expectedNoCallback();
verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class)); verify(mCm, times(1)).unregisterNetworkCallback(any(NetworkCallback.class));
verify(mTm, times(1)).unregisterTetheringEventCallback(any()); verify(mTm, times(1)).unregisterTetheringEventCallback(any());