Notify socket changes using a SoketKey

Currently, all socket changes are identified using a Network
object. However, the Network object is null for all tethering
interface sockets, which means that the socket cannot be
identified in some places. Therefore, the Network object should
be replaced with a SocketKey object, which includes both the
network and interface index.

Bug: 278018903
Test: atest FrameworksNetTests android.net.cts.NsdManagerTest
Change-Id: Ib49981a4071ecab18c7cf3a8827d1459529492a9
This commit is contained in:
Paul Hu
2023-05-18 11:53:05 +08:00
parent 2f2ae4ef33
commit 2f236e9ca4
7 changed files with 184 additions and 89 deletions

View File

@@ -287,7 +287,7 @@ public class MdnsAdvertiser {
}
@Override
public void onSocketCreated(@NonNull Network network,
public void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> addresses) {
MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
@@ -311,14 +311,14 @@ public class MdnsAdvertiser {
}
@Override
public void onInterfaceDestroyed(@NonNull Network network,
public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) {
final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
if (advertiser != null) advertiser.destroyNow();
}
@Override
public void onAddressesChanged(@NonNull Network network,
public void onAddressesChanged(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
if (advertiser != null) advertiser.updateAddresses(addresses);

View File

@@ -64,7 +64,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
@NonNull
private final SocketCreationCallback mSocketCreationCallback;
@NonNull
private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
new ArrayMap<>();
InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -72,32 +72,32 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
}
@Override
public void onSocketCreated(@Nullable Network network,
public void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
// The socket may be already created by other request before, try to get the stored
// ReadPacketHandler.
ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
if (handler == null) {
// First request to create this socket. Initial a ReadPacketHandler for this socket.
handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
handler = new ReadPacketHandler(socketKey);
mSocketPacketHandlers.put(socket, handler);
}
socket.addPacketHandler(handler);
mActiveNetworkSockets.put(socket, network);
mSocketCreationCallback.onSocketCreated(network);
mActiveNetworkSockets.put(socket, socketKey);
mSocketCreationCallback.onSocketCreated(socketKey.getNetwork());
}
@Override
public void onInterfaceDestroyed(@Nullable Network network,
public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) {
notifySocketDestroyed(socket);
maybeCleanupPacketHandler(socket);
}
private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
final Network network = mActiveNetworkSockets.remove(socket);
if (!isAnySocketActive(network)) {
mSocketCreationCallback.onAllSocketsDestroyed(network);
final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
if (!isAnySocketActive(socketKey)) {
mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork());
}
}
@@ -121,18 +121,18 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
return false;
}
private boolean isAnySocketActive(@Nullable Network network) {
private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
if (isc.mActiveNetworkSockets.containsValue(network)) {
if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
return true;
}
}
return false;
}
private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
sockets.putAll(isc.mActiveNetworkSockets);
@@ -146,17 +146,15 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
}
private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
private final Network mNetwork;
private final int mInterfaceIndex;
@NonNull private final SocketKey mSocketKey;
ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
mNetwork = network;
mInterfaceIndex = interfaceIndex;
ReadPacketHandler(@NonNull SocketKey socketKey) {
mSocketKey = socketKey;
}
@Override
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
processResponsePacket(recvbuf, length, mSocketKey);
}
}
@@ -220,10 +218,10 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
instanceof Inet6Address;
final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
instanceof Inet4Address;
final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
for (int i = 0; i < activeSockets.size(); i++) {
final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
final Network network = activeSockets.valueAt(i);
final Network network = activeSockets.valueAt(i).getNetwork();
// Check ip capability and network before sending packet
if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
// Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
@@ -239,8 +237,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
}
}
private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
@NonNull Network network) {
private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
int packetNumber = ++mReceivedPacketNumber;
final MdnsPacket response;
@@ -250,14 +247,16 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
Log.e(TAG, e.getMessage(), e);
if (mCallback != null) {
mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
mCallback.onFailedToParseMdnsResponse(
packetNumber, e.code, socketKey.getNetwork());
}
}
return;
}
if (mCallback != null) {
mCallback.onResponseReceived(response, interfaceIndex, network);
mCallback.onResponseReceived(
response, socketKey.getInterfaceIndex(), socketKey.getNetwork());
}
}

View File

@@ -258,6 +258,11 @@ public class MdnsSocketProvider {
@NonNull final NetLinkMonitorCallBack cb) {
return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
}
/*** Get interface index by given socket */
public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
return socket.getInterface().getIndex();
}
}
/**
* The callback interface for the netlink monitor messages.
@@ -597,8 +602,10 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) {
mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket,
socketInfo.mAddresses);
final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
final SocketKey socketKey = new SocketKey(network, ifaceIndex);
mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
socketInfo.mTransports);
}
@@ -609,7 +616,9 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) {
mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket);
final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
mCallbacksToRequestedNetworks.keyAt(i)
.onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
}
}
}
@@ -619,8 +628,9 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) {
final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
mCallbacksToRequestedNetworks.keyAt(i)
.onAddressesChanged(network, socket, addresses);
.onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
}
}
}
@@ -637,7 +647,9 @@ public class MdnsSocketProvider {
createSocket(new NetworkAsKey(network), lp);
} else {
// Notify the socket for requested network.
cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
final SocketKey socketKey = new SocketKey(network, ifaceIndex);
cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
socketInfo.mTransports);
}
@@ -652,8 +664,9 @@ public class MdnsSocketProvider {
createLPForTetheredInterface(interfaceName, ifaceIndex));
} else {
// Notify the socket for requested network.
cb.onSocketCreated(
null /* network */, socketInfo.mSocket, socketInfo.mAddresses);
final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
final SocketKey socketKey = new SocketKey(ifaceIndex);
cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
socketInfo.mSocket, socketInfo.mTransports);
}
@@ -741,21 +754,21 @@ public class MdnsSocketProvider {
* This may be called immediately when the request is registered with an existing socket,
* if it had been created previously for other requests.
*/
default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> addresses) {}
default void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
/**
* Notify that the interface was destroyed, so the provided socket cannot be used anymore.
*
* This indicates that although the socket was still requested, it had to be destroyed.
*/
default void onInterfaceDestroyed(@Nullable Network network,
default void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) {}
/**
* Notify the interface addresses have changed for the network.
*/
default void onAddressesChanged(@Nullable Network network,
default void onAddressesChanged(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
}

View File

@@ -0,0 +1,72 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import android.net.Network;
import java.util.Objects;
/**
* A class that identifies a socket.
*
* <p> A socket is typically created with an associated network. However, tethering interfaces do
* not have an associated network, only an interface index. This means that the socket cannot be
* identified in some places. Therefore, this class is necessary for identifying a socket. It
* includes both the network and interface index.
*/
public class SocketKey {
@Nullable
private final Network mNetwork;
private final int mInterfaceIndex;
SocketKey(int interfaceIndex) {
this(null /* network */, interfaceIndex);
}
SocketKey(@Nullable Network network, int interfaceIndex) {
mNetwork = network;
mInterfaceIndex = interfaceIndex;
}
public Network getNetwork() {
return mNetwork;
}
public int getInterfaceIndex() {
return mInterfaceIndex;
}
@Override
public int hashCode() {
return Objects.hash(mNetwork, mInterfaceIndex);
}
@Override
public boolean equals(@Nullable Object other) {
if (!(other instanceof SocketKey)) {
return false;
}
return Objects.equals(mNetwork, ((SocketKey) other).mNetwork)
&& mInterfaceIndex == ((SocketKey) other).mInterfaceIndex;
}
@Override
public String toString() {
return "SocketKey{ network=" + mNetwork + " interfaceIndex=" + mInterfaceIndex + " }";
}
}

View File

@@ -56,7 +56,8 @@ private const val TIMEOUT_MS = 10_000L
private val TEST_ADDR = parseNumericAddress("2001:db8::123")
private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
private val TEST_NETWORK_1 = mock(Network::class.java)
private val TEST_NETWORK_2 = mock(Network::class.java)
private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SUBTYPE = "_subtype"
@@ -145,7 +146,7 @@ class MdnsAdvertiserTest {
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@ class MdnsAdvertiserTest {
mockInterfaceAdvertiser1, SERVICE_ID_1) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
verify(mockInterfaceAdvertiser1).destroyNow()
}
@@ -177,8 +178,8 @@ class MdnsAdvertiserTest {
socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@ class MdnsAdvertiserTest {
// Callbacks for matching network and all networks both get the socket
postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
}
val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@ class MdnsAdvertiserTest {
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) })
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
// destroyNow can be called multiple times
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()

View File

@@ -68,12 +68,15 @@ public class MdnsMultinetworkSocketClientTest {
@Mock private MdnsServiceBrowserListener mListener;
@Mock private MdnsSocketClientBase.Callback mCallback;
@Mock private SocketCreationCallback mSocketCreationCallback;
@Mock private SocketKey mSocketKey;
private MdnsMultinetworkSocketClient mSocketClient;
private Handler mHandler;
@Before
public void setUp() throws SocketException {
MockitoAnnotations.initMocks(this);
doReturn(mNetwork).when(mSocketKey).getNetwork();
final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
thread.start();
mHandler = new Handler(thread.getLooper());
@@ -123,12 +126,16 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
}
final SocketKey tetherSocketKey1 = mock(SocketKey.class);
final SocketKey tetherSocketKey2 = mock(SocketKey.class);
doReturn(null).when(tetherSocketKey1).getNetwork();
doReturn(null).when(tetherSocketKey2).getNetwork();
// Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(null, tetherIfaceSock1, List.of());
callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null, tetherIfaceSock2, List.of());
callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
// Send packet to IPv4 with target network and verify sending has been called.
@@ -164,7 +171,7 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
// Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
final ArgumentCaptor<PacketHandler> handlerCaptor =
@@ -214,9 +221,11 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
callback.onSocketCreated(mNetwork, mSocket, List.of());
callback.onSocketCreated(null, socket2, List.of());
callback.onSocketCreated(null, socket3, List.of());
final SocketKey socketKey2 = mock(SocketKey.class);
doReturn(null).when(socketKey2).getNetwork();
callback.onSocketCreated(mSocketKey, mSocket, List.of());
callback.onSocketCreated(socketKey2, socket2, List.of());
callback.onSocketCreated(socketKey2, socket3, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
@@ -241,9 +250,9 @@ public class MdnsMultinetworkSocketClientTest {
final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
// Notify socket created for all networks.
callback2.onSocketCreated(mNetwork, mSocket, List.of());
callback2.onSocketCreated(null, socket2, List.of());
callback2.onSocketCreated(null, socket3, List.of());
callback2.onSocketCreated(mSocketKey, mSocket, List.of());
callback2.onSocketCreated(socketKey2, socket2, List.of());
callback2.onSocketCreated(socketKey2, socket3, List.of());
verify(socketCreationCb2).onSocketCreated(mNetwork);
verify(socketCreationCb2, times(2)).onSocketCreated(null);
@@ -286,17 +295,17 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null /* network */, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */);
verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mProvider).unrequestSocket(callback);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
}
@Test
@@ -306,15 +315,15 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null /* network */, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
// Notify socket destroyed
callback.onInterfaceDestroyed(null /* network */, mSocket);
callback.onInterfaceDestroyed(mSocketKey, mSocket);
verifyNoMoreInteractions(mSocketCreationCallback);
callback.onInterfaceDestroyed(null /* network */, otherSocket);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
callback.onInterfaceDestroyed(mSocketKey, otherSocket);
verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
}
}

View File

@@ -157,6 +157,7 @@ public class MdnsSocketProviderTest {
TETHERED_IFACE_NAME);
doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
WIFI_P2P_IFACE_NAME);
doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
thread.start();
mHandler = new Handler(thread.getLooper());
@@ -227,30 +228,30 @@ public class MdnsSocketProviderTest {
private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
private class SocketEvent {
public final Network mNetwork;
public final SocketKey mSocketKey;
public final List<LinkAddress> mAddresses;
SocketEvent(Network network, List<LinkAddress> addresses) {
mNetwork = network;
SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
mSocketKey = socketKey;
mAddresses = Collections.unmodifiableList(addresses);
}
}
private class SocketCreatedEvent extends SocketEvent {
SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
private class InterfaceDestroyedEvent extends SocketEvent {
InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
private class AddressesChangedEvent extends SocketEvent {
AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
super(nw, addresses);
AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(socketKey, addresses);
}
}
@@ -258,27 +259,27 @@ public class MdnsSocketProviderTest {
new ArrayTrackRecord<SocketEvent>().newReadHead();
@Override
public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) {
mHistory.add(new SocketCreatedEvent(network, addresses));
mHistory.add(new SocketCreatedEvent(socketKey, addresses));
}
@Override
public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
}
@Override
public void onAddressesChanged(Network network, MdnsInterfaceSocket socket,
public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) {
mHistory.add(new AddressesChangedEvent(network, addresses));
mHistory.add(new AddressesChangedEvent(socketKey, addresses));
}
public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof SocketCreatedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(addresses, event.mAddresses);
}
@@ -286,7 +287,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof InterfaceDestroyedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
}
public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +295,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event);
assertTrue(event instanceof AddressesChangedEvent);
assertEquals(network, event.mNetwork);
assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(event.mAddresses, addresses);
}