Merge "Notify socket changes using a SoketKey"

This commit is contained in:
Paul Hu
2023-06-15 01:01:24 +00:00
committed by Gerrit Code Review
7 changed files with 184 additions and 89 deletions

View File

@@ -287,7 +287,7 @@ public class MdnsAdvertiser {
} }
@Override @Override
public void onSocketCreated(@NonNull Network network, public void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull MdnsInterfaceSocket socket,
@NonNull List<LinkAddress> addresses) { @NonNull List<LinkAddress> addresses) {
MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket); MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
@@ -311,14 +311,14 @@ public class MdnsAdvertiser {
} }
@Override @Override
public void onInterfaceDestroyed(@NonNull Network network, public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) { @NonNull MdnsInterfaceSocket socket) {
final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket); final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
if (advertiser != null) advertiser.destroyNow(); if (advertiser != null) advertiser.destroyNow();
} }
@Override @Override
public void onAddressesChanged(@NonNull Network network, public void onAddressesChanged(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) { @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket); final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
if (advertiser != null) advertiser.updateAddresses(addresses); if (advertiser != null) advertiser.updateAddresses(addresses);

View File

@@ -64,7 +64,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
@NonNull @NonNull
private final SocketCreationCallback mSocketCreationCallback; private final SocketCreationCallback mSocketCreationCallback;
@NonNull @NonNull
private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets = private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
new ArrayMap<>(); new ArrayMap<>();
InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) { InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -72,32 +72,32 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
} }
@Override @Override
public void onSocketCreated(@Nullable Network network, public void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) { @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
// The socket may be already created by other request before, try to get the stored // The socket may be already created by other request before, try to get the stored
// ReadPacketHandler. // ReadPacketHandler.
ReadPacketHandler handler = mSocketPacketHandlers.get(socket); ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
if (handler == null) { if (handler == null) {
// First request to create this socket. Initial a ReadPacketHandler for this socket. // First request to create this socket. Initial a ReadPacketHandler for this socket.
handler = new ReadPacketHandler(network, socket.getInterface().getIndex()); handler = new ReadPacketHandler(socketKey);
mSocketPacketHandlers.put(socket, handler); mSocketPacketHandlers.put(socket, handler);
} }
socket.addPacketHandler(handler); socket.addPacketHandler(handler);
mActiveNetworkSockets.put(socket, network); mActiveNetworkSockets.put(socket, socketKey);
mSocketCreationCallback.onSocketCreated(network); mSocketCreationCallback.onSocketCreated(socketKey.getNetwork());
} }
@Override @Override
public void onInterfaceDestroyed(@Nullable Network network, public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) { @NonNull MdnsInterfaceSocket socket) {
notifySocketDestroyed(socket); notifySocketDestroyed(socket);
maybeCleanupPacketHandler(socket); maybeCleanupPacketHandler(socket);
} }
private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) { private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
final Network network = mActiveNetworkSockets.remove(socket); final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
if (!isAnySocketActive(network)) { if (!isAnySocketActive(socketKey)) {
mSocketCreationCallback.onAllSocketsDestroyed(network); mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork());
} }
} }
@@ -121,18 +121,18 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
return false; return false;
} }
private boolean isAnySocketActive(@Nullable Network network) { private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
for (int i = 0; i < mRequestedNetworks.size(); i++) { for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i); final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
if (isc.mActiveNetworkSockets.containsValue(network)) { if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
return true; return true;
} }
} }
return false; return false;
} }
private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() { private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>(); final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
for (int i = 0; i < mRequestedNetworks.size(); i++) { for (int i = 0; i < mRequestedNetworks.size(); i++) {
final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i); final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
sockets.putAll(isc.mActiveNetworkSockets); sockets.putAll(isc.mActiveNetworkSockets);
@@ -146,17 +146,15 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
} }
private class ReadPacketHandler implements MulticastPacketReader.PacketHandler { private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
private final Network mNetwork; @NonNull private final SocketKey mSocketKey;
private final int mInterfaceIndex;
ReadPacketHandler(@NonNull Network network, int interfaceIndex) { ReadPacketHandler(@NonNull SocketKey socketKey) {
mNetwork = network; mSocketKey = socketKey;
mInterfaceIndex = interfaceIndex;
} }
@Override @Override
public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) { public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork); processResponsePacket(recvbuf, length, mSocketKey);
} }
} }
@@ -220,10 +218,10 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
instanceof Inet6Address; instanceof Inet6Address;
final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress() final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
instanceof Inet4Address; instanceof Inet4Address;
final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets(); final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
for (int i = 0; i < activeSockets.size(); i++) { for (int i = 0; i < activeSockets.size(); i++) {
final MdnsInterfaceSocket socket = activeSockets.keyAt(i); final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
final Network network = activeSockets.valueAt(i); final Network network = activeSockets.valueAt(i).getNetwork();
// Check ip capability and network before sending packet // Check ip capability and network before sending packet
if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4())) if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
// Contrary to MdnsUtils.isNetworkMatched, only send packets targeting // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
@@ -239,8 +237,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
} }
} }
private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex, private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
@NonNull Network network) {
int packetNumber = ++mReceivedPacketNumber; int packetNumber = ++mReceivedPacketNumber;
final MdnsPacket response; final MdnsPacket response;
@@ -250,14 +247,16 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase {
if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) { if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
Log.e(TAG, e.getMessage(), e); Log.e(TAG, e.getMessage(), e);
if (mCallback != null) { if (mCallback != null) {
mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network); mCallback.onFailedToParseMdnsResponse(
packetNumber, e.code, socketKey.getNetwork());
} }
} }
return; return;
} }
if (mCallback != null) { if (mCallback != null) {
mCallback.onResponseReceived(response, interfaceIndex, network); mCallback.onResponseReceived(
response, socketKey.getInterfaceIndex(), socketKey.getNetwork());
} }
} }

View File

@@ -258,6 +258,11 @@ public class MdnsSocketProvider {
@NonNull final NetLinkMonitorCallBack cb) { @NonNull final NetLinkMonitorCallBack cb) {
return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb); return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
} }
/*** Get interface index by given socket */
public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
return socket.getInterface().getIndex();
}
} }
/** /**
* The callback interface for the netlink monitor messages. * The callback interface for the netlink monitor messages.
@@ -597,8 +602,10 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) { for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i); final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) { if (isNetworkMatched(requestedNetwork, network)) {
mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket, final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
socketInfo.mAddresses); final SocketKey socketKey = new SocketKey(network, ifaceIndex);
mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket, mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
socketInfo.mTransports); socketInfo.mTransports);
} }
@@ -609,7 +616,9 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) { for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i); final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) { if (isNetworkMatched(requestedNetwork, network)) {
mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket); final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
mCallbacksToRequestedNetworks.keyAt(i)
.onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
} }
} }
} }
@@ -619,8 +628,9 @@ public class MdnsSocketProvider {
for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) { for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i); final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
if (isNetworkMatched(requestedNetwork, network)) { if (isNetworkMatched(requestedNetwork, network)) {
final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
mCallbacksToRequestedNetworks.keyAt(i) mCallbacksToRequestedNetworks.keyAt(i)
.onAddressesChanged(network, socket, addresses); .onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
} }
} }
} }
@@ -637,7 +647,9 @@ public class MdnsSocketProvider {
createSocket(new NetworkAsKey(network), lp); createSocket(new NetworkAsKey(network), lp);
} else { } else {
// Notify the socket for requested network. // Notify the socket for requested network.
cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses); final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
final SocketKey socketKey = new SocketKey(network, ifaceIndex);
cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket, mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
socketInfo.mTransports); socketInfo.mTransports);
} }
@@ -652,8 +664,9 @@ public class MdnsSocketProvider {
createLPForTetheredInterface(interfaceName, ifaceIndex)); createLPForTetheredInterface(interfaceName, ifaceIndex));
} else { } else {
// Notify the socket for requested network. // Notify the socket for requested network.
cb.onSocketCreated( final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
null /* network */, socketInfo.mSocket, socketInfo.mAddresses); final SocketKey socketKey = new SocketKey(ifaceIndex);
cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */, mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
socketInfo.mSocket, socketInfo.mTransports); socketInfo.mSocket, socketInfo.mTransports);
} }
@@ -741,21 +754,21 @@ public class MdnsSocketProvider {
* This may be called immediately when the request is registered with an existing socket, * This may be called immediately when the request is registered with an existing socket,
* if it had been created previously for other requests. * if it had been created previously for other requests.
*/ */
default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket, default void onSocketCreated(@NonNull SocketKey socketKey,
@NonNull List<LinkAddress> addresses) {} @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
/** /**
* Notify that the interface was destroyed, so the provided socket cannot be used anymore. * Notify that the interface was destroyed, so the provided socket cannot be used anymore.
* *
* This indicates that although the socket was still requested, it had to be destroyed. * This indicates that although the socket was still requested, it had to be destroyed.
*/ */
default void onInterfaceDestroyed(@Nullable Network network, default void onInterfaceDestroyed(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket) {} @NonNull MdnsInterfaceSocket socket) {}
/** /**
* Notify the interface addresses have changed for the network. * Notify the interface addresses have changed for the network.
*/ */
default void onAddressesChanged(@Nullable Network network, default void onAddressesChanged(@NonNull SocketKey socketKey,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {} @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
} }

View File

@@ -0,0 +1,72 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
import android.annotation.Nullable;
import android.net.Network;
import java.util.Objects;
/**
* A class that identifies a socket.
*
* <p> A socket is typically created with an associated network. However, tethering interfaces do
* not have an associated network, only an interface index. This means that the socket cannot be
* identified in some places. Therefore, this class is necessary for identifying a socket. It
* includes both the network and interface index.
*/
public class SocketKey {
@Nullable
private final Network mNetwork;
private final int mInterfaceIndex;
SocketKey(int interfaceIndex) {
this(null /* network */, interfaceIndex);
}
SocketKey(@Nullable Network network, int interfaceIndex) {
mNetwork = network;
mInterfaceIndex = interfaceIndex;
}
public Network getNetwork() {
return mNetwork;
}
public int getInterfaceIndex() {
return mInterfaceIndex;
}
@Override
public int hashCode() {
return Objects.hash(mNetwork, mInterfaceIndex);
}
@Override
public boolean equals(@Nullable Object other) {
if (!(other instanceof SocketKey)) {
return false;
}
return Objects.equals(mNetwork, ((SocketKey) other).mNetwork)
&& mInterfaceIndex == ((SocketKey) other).mInterfaceIndex;
}
@Override
public String toString() {
return "SocketKey{ network=" + mNetwork + " interfaceIndex=" + mInterfaceIndex + " }";
}
}

View File

@@ -56,7 +56,8 @@ private const val TIMEOUT_MS = 10_000L
private val TEST_ADDR = parseNumericAddress("2001:db8::123") private val TEST_ADDR = parseNumericAddress("2001:db8::123")
private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */) private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
private val TEST_NETWORK_1 = mock(Network::class.java) private val TEST_NETWORK_1 = mock(Network::class.java)
private val TEST_NETWORK_2 = mock(Network::class.java) private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
private val TEST_HOSTNAME = arrayOf("Android_test", "local") private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SUBTYPE = "_subtype" private const val TEST_SUBTYPE = "_subtype"
@@ -145,7 +146,7 @@ class MdnsAdvertiserTest {
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) } postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
verify(mockDeps).makeAdvertiser( verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@ class MdnsAdvertiserTest {
mockInterfaceAdvertiser1, SERVICE_ID_1) } mockInterfaceAdvertiser1, SERVICE_ID_1) }
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) }) verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) } postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
verify(mockInterfaceAdvertiser1).destroyNow() verify(mockInterfaceAdvertiser1).destroyNow()
} }
@@ -177,8 +178,8 @@ class MdnsAdvertiserTest {
socketCbCaptor.capture()) socketCbCaptor.capture())
val socketCb = socketCbCaptor.value val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) } postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) } postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java) val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@ class MdnsAdvertiserTest {
// Callbacks for matching network and all networks both get the socket // Callbacks for matching network and all networks both get the socket
postSync { postSync {
oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
} }
val expectedRenamed = NsdServiceInfo( val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@ class MdnsAdvertiserTest {
verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2), verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
argThat { it.matches(expectedRenamed) }) argThat { it.matches(expectedRenamed) })
postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) } postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) } postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
// destroyNow can be called multiple times // destroyNow can be called multiple times
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow() verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()

View File

@@ -68,12 +68,15 @@ public class MdnsMultinetworkSocketClientTest {
@Mock private MdnsServiceBrowserListener mListener; @Mock private MdnsServiceBrowserListener mListener;
@Mock private MdnsSocketClientBase.Callback mCallback; @Mock private MdnsSocketClientBase.Callback mCallback;
@Mock private SocketCreationCallback mSocketCreationCallback; @Mock private SocketCreationCallback mSocketCreationCallback;
@Mock private SocketKey mSocketKey;
private MdnsMultinetworkSocketClient mSocketClient; private MdnsMultinetworkSocketClient mSocketClient;
private Handler mHandler; private Handler mHandler;
@Before @Before
public void setUp() throws SocketException { public void setUp() throws SocketException {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
doReturn(mNetwork).when(mSocketKey).getNetwork();
final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest"); final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
thread.start(); thread.start();
mHandler = new Handler(thread.getLooper()); mHandler = new Handler(thread.getLooper());
@@ -123,12 +126,16 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket).getInterface(); doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
} }
final SocketKey tetherSocketKey1 = mock(SocketKey.class);
final SocketKey tetherSocketKey2 = mock(SocketKey.class);
doReturn(null).when(tetherSocketKey1).getNetwork();
doReturn(null).when(tetherSocketKey2).getNetwork();
// Notify socket created // Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of()); callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork); verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(null, tetherIfaceSock1, List.of()); callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
verify(mSocketCreationCallback).onSocketCreated(null); verify(mSocketCreationCallback).onSocketCreated(null);
callback.onSocketCreated(null, tetherIfaceSock2, List.of()); callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null); verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
// Send packet to IPv4 with target network and verify sending has been called. // Send packet to IPv4 with target network and verify sending has been called.
@@ -164,7 +171,7 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface(); doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
// Notify socket created // Notify socket created
callback.onSocketCreated(mNetwork, mSocket, List.of()); callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork); verify(mSocketCreationCallback).onSocketCreated(mNetwork);
final ArgumentCaptor<PacketHandler> handlerCaptor = final ArgumentCaptor<PacketHandler> handlerCaptor =
@@ -214,9 +221,11 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(socket2).getInterface(); doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
doReturn(createEmptyNetworkInterface()).when(socket3).getInterface(); doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
callback.onSocketCreated(mNetwork, mSocket, List.of()); final SocketKey socketKey2 = mock(SocketKey.class);
callback.onSocketCreated(null, socket2, List.of()); doReturn(null).when(socketKey2).getNetwork();
callback.onSocketCreated(null, socket3, List.of()); callback.onSocketCreated(mSocketKey, mSocket, List.of());
callback.onSocketCreated(socketKey2, socket2, List.of());
callback.onSocketCreated(socketKey2, socket3, List.of());
verify(mSocketCreationCallback).onSocketCreated(mNetwork); verify(mSocketCreationCallback).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, times(2)).onSocketCreated(null); verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
@@ -241,9 +250,9 @@ public class MdnsMultinetworkSocketClientTest {
final SocketCallback callback2 = callback2Captor.getAllValues().get(1); final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
// Notify socket created for all networks. // Notify socket created for all networks.
callback2.onSocketCreated(mNetwork, mSocket, List.of()); callback2.onSocketCreated(mSocketKey, mSocket, List.of());
callback2.onSocketCreated(null, socket2, List.of()); callback2.onSocketCreated(socketKey2, socket2, List.of());
callback2.onSocketCreated(null, socket3, List.of()); callback2.onSocketCreated(socketKey2, socket3, List.of());
verify(socketCreationCb2).onSocketCreated(mNetwork); verify(socketCreationCb2).onSocketCreated(mNetwork);
verify(socketCreationCb2, times(2)).onSocketCreated(null); verify(socketCreationCb2, times(2)).onSocketCreated(null);
@@ -286,17 +295,17 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface(); doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface(); doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of()); callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null); verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(null /* network */, otherSocket, List.of()); callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null); verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */); verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork);
mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener)); mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
verify(mProvider).unrequestSocket(callback); verify(mProvider).unrequestSocket(callback);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */); verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
} }
@Test @Test
@@ -306,15 +315,15 @@ public class MdnsMultinetworkSocketClientTest {
doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface(); doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface(); doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
callback.onSocketCreated(null /* network */, mSocket, List.of()); callback.onSocketCreated(mSocketKey, mSocket, List.of());
verify(mSocketCreationCallback).onSocketCreated(null); verify(mSocketCreationCallback).onSocketCreated(mNetwork);
callback.onSocketCreated(null /* network */, otherSocket, List.of()); callback.onSocketCreated(mSocketKey, otherSocket, List.of());
verify(mSocketCreationCallback, times(2)).onSocketCreated(null); verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork);
// Notify socket destroyed // Notify socket destroyed
callback.onInterfaceDestroyed(null /* network */, mSocket); callback.onInterfaceDestroyed(mSocketKey, mSocket);
verifyNoMoreInteractions(mSocketCreationCallback); verifyNoMoreInteractions(mSocketCreationCallback);
callback.onInterfaceDestroyed(null /* network */, otherSocket); callback.onInterfaceDestroyed(mSocketKey, otherSocket);
verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */); verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork);
} }
} }

View File

@@ -157,6 +157,7 @@ public class MdnsSocketProviderTest {
TETHERED_IFACE_NAME); TETHERED_IFACE_NAME);
doReturn(789).when(mDeps).getNetworkInterfaceIndexByName( doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
WIFI_P2P_IFACE_NAME); WIFI_P2P_IFACE_NAME);
doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest"); final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
thread.start(); thread.start();
mHandler = new Handler(thread.getLooper()); mHandler = new Handler(thread.getLooper());
@@ -227,30 +228,30 @@ public class MdnsSocketProviderTest {
private class TestSocketCallback implements MdnsSocketProvider.SocketCallback { private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
private class SocketEvent { private class SocketEvent {
public final Network mNetwork; public final SocketKey mSocketKey;
public final List<LinkAddress> mAddresses; public final List<LinkAddress> mAddresses;
SocketEvent(Network network, List<LinkAddress> addresses) { SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
mNetwork = network; mSocketKey = socketKey;
mAddresses = Collections.unmodifiableList(addresses); mAddresses = Collections.unmodifiableList(addresses);
} }
} }
private class SocketCreatedEvent extends SocketEvent { private class SocketCreatedEvent extends SocketEvent {
SocketCreatedEvent(Network nw, List<LinkAddress> addresses) { SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(nw, addresses); super(socketKey, addresses);
} }
} }
private class InterfaceDestroyedEvent extends SocketEvent { private class InterfaceDestroyedEvent extends SocketEvent {
InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) { InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(nw, addresses); super(socketKey, addresses);
} }
} }
private class AddressesChangedEvent extends SocketEvent { private class AddressesChangedEvent extends SocketEvent {
AddressesChangedEvent(Network nw, List<LinkAddress> addresses) { AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
super(nw, addresses); super(socketKey, addresses);
} }
} }
@@ -258,27 +259,27 @@ public class MdnsSocketProviderTest {
new ArrayTrackRecord<SocketEvent>().newReadHead(); new ArrayTrackRecord<SocketEvent>().newReadHead();
@Override @Override
public void onSocketCreated(Network network, MdnsInterfaceSocket socket, public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) { List<LinkAddress> addresses) {
mHistory.add(new SocketCreatedEvent(network, addresses)); mHistory.add(new SocketCreatedEvent(socketKey, addresses));
} }
@Override @Override
public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) { public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
mHistory.add(new InterfaceDestroyedEvent(network, List.of())); mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
} }
@Override @Override
public void onAddressesChanged(Network network, MdnsInterfaceSocket socket, public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
List<LinkAddress> addresses) { List<LinkAddress> addresses) {
mHistory.add(new AddressesChangedEvent(network, addresses)); mHistory.add(new AddressesChangedEvent(socketKey, addresses));
} }
public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) { public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true); final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event); assertNotNull(event);
assertTrue(event instanceof SocketCreatedEvent); assertTrue(event instanceof SocketCreatedEvent);
assertEquals(network, event.mNetwork); assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(addresses, event.mAddresses); assertEquals(addresses, event.mAddresses);
} }
@@ -286,7 +287,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true); final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event); assertNotNull(event);
assertTrue(event instanceof InterfaceDestroyedEvent); assertTrue(event instanceof InterfaceDestroyedEvent);
assertEquals(network, event.mNetwork); assertEquals(network, event.mSocketKey.getNetwork());
} }
public void expectedAddressesChangedForNetwork(Network network, public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +295,7 @@ public class MdnsSocketProviderTest {
final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true); final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
assertNotNull(event); assertNotNull(event);
assertTrue(event instanceof AddressesChangedEvent); assertTrue(event instanceof AddressesChangedEvent);
assertEquals(network, event.mNetwork); assertEquals(network, event.mSocketKey.getNetwork());
assertEquals(event.mAddresses, addresses); assertEquals(event.mAddresses, addresses);
} }