Notify socket changes using a SoketKey

Currently, all socket changes are identified using a Network
object. However, the Network object is null for all tethering
interface sockets, which means that the socket cannot be
identified in some places. Therefore, the Network object should
be replaced with a SocketKey object, which includes both the
network and interface index.

Bug: 278018903
Test: atest FrameworksNetTests android.net.cts.NsdManagerTest
Change-Id: Ib49981a4071ecab18c7cf3a8827d1459529492a9
This commit is contained in:
Paul Hu
2023-05-18 11:53:05 +08:00
parent 2f2ae4ef33
commit 2f236e9ca4
7 changed files with 184 additions and 89 deletions

View File

@@ -56,7 +56,8 @@ private const val TIMEOUT_MS = 10_000L
private val TEST_ADDR = parseNumericAddress("2001:db8::123")
private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
private val TEST_NETWORK_1 = mock(Network::class.java)
private val TEST_NETWORK_2 = mock(Network::class.java)
private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SUBTYPE = "_subtype"
@@ -145,7 +146,7 @@ class MdnsAdvertiserTest {
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@ class MdnsAdvertiserTest {
mockInterfaceAdvertiser1, SERVICE_ID_1) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
verify(mockInterfaceAdvertiser1).destroyNow()
}
@@ -177,8 +178,8 @@ class MdnsAdvertiserTest {
socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@ class MdnsAdvertiserTest {
// Callbacks for matching network and all networks both get the socket
postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
}
val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@ class MdnsAdvertiserTest {
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
// destroyNow can be called multiple times
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()

View File

@@ -68,12 +68,15 @@ public class MdnsMultinetworkSocketClientTest {
@Mock private MdnsServiceBrowserListener mListener;
@Mock private MdnsSocketClientBase.Callback mCallback;
@Mock private SocketCreationCallback mSocketCreationCallback;
@Mock private SocketKey mSocketKey;
private MdnsMultinetworkSocketClient mSocketClient;
private Handler mHandler;
@Before
public void setUp() throws SocketException {
MockitoAnnotations.initMocks(this);
doReturn(mNetwork).when(mSocketKey).getNetwork();
final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
thread.start();
mHandler = new Handler(thread.getLooper());
@@ -123,12 +126,16 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
}
final SocketKey tetherSocketKey1 = mock(SocketKey.class);
final SocketKey tetherSocketKey2 = mock(SocketKey.class);
doReturn(null).when(tetherSocketKey1).getNetwork();
doReturn(null).when(tetherSocketKey2).getNetwork();
// Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(null, tetherIfaceSock1, List.of());
callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null, tetherIfaceSock2, List.of());
callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
// Send packet to IPv4 with target network and verify sending has been called.
@@ -164,7 +171,7 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
// Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
final ArgumentCaptor<PacketHandler> handlerCaptor =
@@ -214,9 +221,11 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(null, socket2, List.of());
callback.onSocketCreated(null, socket3, List.of());
final SocketKey socketKey2 = mock(SocketKey.class);
doReturn(null).when(socketKey2).getNetwork();
callback.onSocketCreated(mSocketKey, mSocket, List.of());
callback.onSocketCreated(socketKey2, socket2, List.of());
callback.onSocketCreated(socketKey2, socket3, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
@@ -241,9 +250,9 @@ public class MdnsMultinetworkSocketClientTest {
final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
// Notify socket created for all networks.
callback2.onSocketCreated(mNetwork, mSocket, List.of());
callback2.onSocketCreated(null, socket2, List.of());
callback2.onSocketCreated(null, socket3, List.of());
callback2.onSocketCreated(mSocketKey, mSocket, List.of());
callback2.onSocketCreated(socketKey2, socket2, List.of());
callback2.onSocketCreated(socketKey2, socket3, List.of());
verify(socketCreationCb2).onSocketCreated(mNetwork);
verify(socketCreationCb2, times(2)).onSocketCreated(null);
@@ -286,17 +295,17 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null /* network */, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */);
verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mProvider).unrequestSocket(callback);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
}
@Test
@@ -306,15 +315,15 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null /* network */, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
// Notify socket destroyed
callback.onInterfaceDestroyed(null /* network */, mSocket);
callback.onInterfaceDestroyed(mSocketKey, mSocket);
verifyNoMoreInteractions(mSocketCreationCallback);
callback.onInterfaceDestroyed(null /* network */, otherSocket);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
callback.onInterfaceDestroyed(mSocketKey, otherSocket);
verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
}
}

View File

@@ -157,6 +157,7 @@ public class MdnsSocketProviderTest {
TETHERED_IFACE_NAME);
doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
WIFI_P2P_IFACE_NAME);
doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
thread.start();
mHandler = new Handler(thread.getLooper());
@@ -227,30 +228,30 @@ public class MdnsSocketProviderTest {
private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
private class SocketEvent {
public final Network mNetwork;
public final SocketKey mSocketKey;
public final List<LinkAddress> mAddresses;
SocketEvent(Network network, List<LinkAddress> addresses) {
mNetwork = network;
SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
mSocketKey = socketKey;
mAddresses = Collections.unmodifiableList(addresses);
}
}
private class SocketCreatedEvent extends SocketEvent {
SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
private class InterfaceDestroyedEvent extends SocketEvent {
InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
private class AddressesChangedEvent extends SocketEvent {
AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
@@ -258,27 +259,27 @@ public class MdnsSocketProviderTest {
new ArrayTrackRecord<SocketEvent>().newReadHead();
@Override
public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) {
mHistory.add(new SocketCreatedEvent(network, addresses));
mHistory.add(new SocketCreatedEvent(socketKey, addresses));
}
@Override
public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
}
@Override
public void onAddressesChanged(Network network, MdnsInterfaceSocket socket,
public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) {
mHistory.add(new AddressesChangedEvent(network, addresses));
mHistory.add(new AddressesChangedEvent(socketKey, addresses));
}
public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof SocketCreatedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(addresses, event.mAddresses);
}
@@ -286,7 +287,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof InterfaceDestroyedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
}
public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +295,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof AddressesChangedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(event.mAddresses, addresses);
}