diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java index 92a26f1a2a..afad3b7c0e 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java @@ -22,7 +22,6 @@ import android.Manifest.permission; import android.annotation.NonNull; import android.annotation.Nullable; import android.annotation.RequiresPermission; -import android.net.Network; import android.os.Handler; import android.os.HandlerThread; import android.util.ArrayMap; @@ -36,7 +35,6 @@ import com.android.server.connectivity.mdns.util.MdnsUtils; import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Objects; /** * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and @@ -50,54 +48,58 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { private final MdnsSocketClientBase socketClient; @NonNull private final SharedLog sharedLog; - @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients; + @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients; @NonNull private final Handler handler; @Nullable private final HandlerThread handlerThread; - private static class PerNetworkServiceTypeClients { - private final ArrayMap, MdnsServiceTypeClient> clients = + private static class PerSocketServiceTypeClients { + private final ArrayMap, MdnsServiceTypeClient> clients = new ArrayMap<>(); - public void put(@NonNull String serviceType, @Nullable Network network, + public void put(@NonNull String serviceType, @NonNull SocketKey socketKey, @NonNull MdnsServiceTypeClient client) { final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType); - final Pair perNetworkServiceType = new Pair<>(dnsLowerServiceType, - network); - clients.put(perNetworkServiceType, client); + final Pair perSocketServiceType = new Pair<>(dnsLowerServiceType, + socketKey); + clients.put(perSocketServiceType, client); } @Nullable - public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) { + public MdnsServiceTypeClient get( + @NonNull String serviceType, @NonNull SocketKey socketKey) { final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType); - final Pair perNetworkServiceType = new Pair<>(dnsLowerServiceType, - network); - return clients.getOrDefault(perNetworkServiceType, null); + final Pair perSocketServiceType = new Pair<>(dnsLowerServiceType, + socketKey); + return clients.getOrDefault(perSocketServiceType, null); } public List getByServiceType(@NonNull String serviceType) { final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType); final List list = new ArrayList<>(); for (int i = 0; i < clients.size(); i++) { - final Pair perNetworkServiceType = clients.keyAt(i); - if (dnsLowerServiceType.equals(perNetworkServiceType.first)) { + final Pair perSocketServiceType = clients.keyAt(i); + if (dnsLowerServiceType.equals(perSocketServiceType.first)) { list.add(clients.valueAt(i)); } } return list; } - public List getByNetwork(@Nullable Network network) { + public List getBySocketKey(@NonNull SocketKey socketKey) { final List list = new ArrayList<>(); for (int i = 0; i < clients.size(); i++) { - final Pair perNetworkServiceType = clients.keyAt(i); - final Network serviceTypeNetwork = perNetworkServiceType.second; - if (Objects.equals(network, serviceTypeNetwork)) { + final Pair perSocketServiceType = clients.keyAt(i); + if (socketKey.equals(perSocketServiceType.second)) { list.add(clients.valueAt(i)); } } return list; } + public List getAllMdnsServiceTypeClient() { + return new ArrayList<>(clients.values()); + } + public void remove(@NonNull MdnsServiceTypeClient client) { final int index = clients.indexOfValue(client); clients.removeAt(index); @@ -113,7 +115,7 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { this.executorProvider = executorProvider; this.socketClient = socketClient; this.sharedLog = sharedLog; - this.perNetworkServiceTypeClients = new PerNetworkServiceTypeClients(); + this.perSocketServiceTypeClients = new PerSocketServiceTypeClients(); if (socketClient.getLooper() != null) { this.handlerThread = null; this.handler = new Handler(socketClient.getLooper()); @@ -164,7 +166,7 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { @NonNull String serviceType, @NonNull MdnsServiceBrowserListener listener, @NonNull MdnsSearchOptions searchOptions) { - if (perNetworkServiceTypeClients.isEmpty()) { + if (perSocketServiceTypeClients.isEmpty()) { // First listener. Starts the socket client. try { socketClient.startDiscovery(); @@ -177,29 +179,29 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(), new MdnsSocketClientBase.SocketCreationCallback() { @Override - public void onSocketCreated(@Nullable Network network) { + public void onSocketCreated(@NonNull SocketKey socketKey) { ensureRunningOnHandlerThread(handler); // All listeners of the same service types shares the same // MdnsServiceTypeClient. MdnsServiceTypeClient serviceTypeClient = - perNetworkServiceTypeClients.get(serviceType, network); + perSocketServiceTypeClients.get(serviceType, socketKey); if (serviceTypeClient == null) { - serviceTypeClient = createServiceTypeClient(serviceType, network); - perNetworkServiceTypeClients.put(serviceType, network, + serviceTypeClient = createServiceTypeClient(serviceType, socketKey); + perSocketServiceTypeClients.put(serviceType, socketKey, serviceTypeClient); } serviceTypeClient.startSendAndReceive(listener, searchOptions); } @Override - public void onAllSocketsDestroyed(@Nullable Network network) { + public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) { ensureRunningOnHandlerThread(handler); final MdnsServiceTypeClient serviceTypeClient = - perNetworkServiceTypeClients.get(serviceType, network); + perSocketServiceTypeClients.get(serviceType, socketKey); if (serviceTypeClient == null) return; // Notify all listeners that all services are removed from this socket. serviceTypeClient.notifySocketDestroyed(); - perNetworkServiceTypeClients.remove(serviceTypeClient); + perSocketServiceTypeClients.remove(serviceTypeClient); } }); } @@ -224,7 +226,7 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { socketClient.notifyNetworkUnrequested(listener); final List serviceTypeClients = - perNetworkServiceTypeClients.getByServiceType(serviceType); + perSocketServiceTypeClients.getByServiceType(serviceType); if (serviceTypeClients.isEmpty()) { return; } @@ -233,10 +235,10 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { if (serviceTypeClient.stopSendAndReceive(listener)) { // No listener is registered for the service type anymore, remove it from the list // of the service type clients. - perNetworkServiceTypeClients.remove(serviceTypeClient); + perSocketServiceTypeClients.remove(serviceTypeClient); } } - if (perNetworkServiceTypeClients.isEmpty()) { + if (perSocketServiceTypeClients.isEmpty()) { // No discovery request. Stops the socket client. sharedLog.i("All service type listeners unregistered; stopping discovery"); socketClient.stopDiscovery(); @@ -244,50 +246,48 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback { } @Override - public void onResponseReceived(@NonNull MdnsPacket packet, - int interfaceIndex, @Nullable Network network) { + public void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey) { checkAndRunOnHandlerThread(() -> - handleOnResponseReceived(packet, interfaceIndex, network)); + handleOnResponseReceived(packet, socketKey)); } - private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex, - @Nullable Network network) { - for (MdnsServiceTypeClient serviceTypeClient - : getMdnsServiceTypeClient(network)) { - serviceTypeClient.processResponse(packet, interfaceIndex, network); + private void handleOnResponseReceived(@NonNull MdnsPacket packet, + @NonNull SocketKey socketKey) { + for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) { + serviceTypeClient.processResponse( + packet, socketKey.getInterfaceIndex(), socketKey.getNetwork()); } } - private List getMdnsServiceTypeClient(@Nullable Network network) { + private List getMdnsServiceTypeClient(@NonNull SocketKey socketKey) { if (socketClient.supportsRequestingSpecificNetworks()) { - return perNetworkServiceTypeClients.getByNetwork(network); + return perSocketServiceTypeClients.getBySocketKey(socketKey); } else { - return perNetworkServiceTypeClients.getByNetwork(null); + return perSocketServiceTypeClients.getAllMdnsServiceTypeClient(); } } @Override public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, - @Nullable Network network) { + @NonNull SocketKey socketKey) { checkAndRunOnHandlerThread(() -> - handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network)); + handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, socketKey)); } private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, - @Nullable Network network) { - for (MdnsServiceTypeClient serviceTypeClient - : getMdnsServiceTypeClient(network)) { + @NonNull SocketKey socketKey) { + for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) { serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode); } } @VisibleForTesting MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType, - @Nullable Network network) { - sharedLog.log("createServiceTypeClient for type:" + serviceType + ", net:" + network); + @NonNull SocketKey socketKey) { + sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey); return new MdnsServiceTypeClient( serviceType, socketClient, - executorProvider.newServiceTypeClientSchedulerExecutor(), network, - sharedLog.forSubComponent(serviceType + "-" + network)); + executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey, + sharedLog.forSubComponent(serviceType + "-" + socketKey)); } } \ No newline at end of file diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java index 03be681e6b..d0ca20e4a9 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java @@ -84,7 +84,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase { } socket.addPacketHandler(handler); mActiveNetworkSockets.put(socket, socketKey); - mSocketCreationCallback.onSocketCreated(socketKey.getNetwork()); + mSocketCreationCallback.onSocketCreated(socketKey); } @Override @@ -97,7 +97,7 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase { private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) { final SocketKey socketKey = mActiveNetworkSockets.remove(socket); if (!isAnySocketActive(socketKey)) { - mSocketCreationCallback.onAllSocketsDestroyed(socketKey.getNetwork()); + mSocketCreationCallback.onAllSocketsDestroyed(socketKey); } } @@ -247,16 +247,14 @@ public class MdnsMultinetworkSocketClient implements MdnsSocketClientBase { if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) { Log.e(TAG, e.getMessage(), e); if (mCallback != null) { - mCallback.onFailedToParseMdnsResponse( - packetNumber, e.code, socketKey.getNetwork()); + mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, socketKey); } } return; } if (mCallback != null) { - mCallback.onResponseReceived( - response, socketKey.getInterfaceIndex(), socketKey.getNetwork()); + mCallback.onResponseReceived(response, socketKey); } } diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java index 0e3522c753..bdc673e1df 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java @@ -58,7 +58,7 @@ public class MdnsServiceTypeClient { private final MdnsSocketClientBase socketClient; private final MdnsResponseDecoder responseDecoder; private final ScheduledExecutorService executor; - @Nullable private final Network network; + @NonNull private final SocketKey socketKey; @NonNull private final SharedLog sharedLog; private final Object lock = new Object(); private final ArrayMap listeners = @@ -90,9 +90,9 @@ public class MdnsServiceTypeClient { @NonNull String serviceType, @NonNull MdnsSocketClientBase socketClient, @NonNull ScheduledExecutorService executor, - @Nullable Network network, + @NonNull SocketKey socketKey, @NonNull SharedLog sharedLog) { - this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network, + this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey, sharedLog); } @@ -102,7 +102,7 @@ public class MdnsServiceTypeClient { @NonNull MdnsSocketClientBase socketClient, @NonNull ScheduledExecutorService executor, @NonNull MdnsResponseDecoder.Clock clock, - @Nullable Network network, + @NonNull SocketKey socketKey, @NonNull SharedLog sharedLog) { this.serviceType = serviceType; this.socketClient = socketClient; @@ -110,7 +110,7 @@ public class MdnsServiceTypeClient { this.serviceTypeLabels = TextUtils.split(serviceType, "\\."); this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels); this.clock = clock; - this.network = network; + this.socketKey = socketKey; this.sharedLog = sharedLog; } @@ -199,7 +199,7 @@ public class MdnsServiceTypeClient { searchOptions.getSubtypes(), searchOptions.isPassiveMode(), currentSessionId, - network); + socketKey); if (hadReply) { requestTaskFuture = scheduleNextRunLocked(taskConfig); } else { @@ -437,10 +437,10 @@ public class MdnsServiceTypeClient { private int burstCounter; private int timeToRunNextTaskInMs; private boolean isFirstBurst; - @Nullable private final Network network; + @NonNull private final SocketKey socketKey; QueryTaskConfig(@NonNull Collection subtypes, boolean usePassiveMode, - long sessionId, @Nullable Network network) { + long sessionId, @NonNull SocketKey socketKey) { this.usePassiveMode = usePassiveMode; this.subtypes = new ArrayList<>(subtypes); this.queriesPerBurst = QUERIES_PER_BURST; @@ -462,7 +462,7 @@ public class MdnsServiceTypeClient { // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS. this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS; } - this.network = network; + this.socketKey = socketKey; } QueryTaskConfig getConfigForNextRun() { @@ -545,7 +545,7 @@ public class MdnsServiceTypeClient { // Only the names are used to know which queries to send, other parameters like // interfaceIndex do not matter. servicesToResolve = makeResponsesForResolve( - 0 /* interfaceIndex */, config.network); + 0 /* interfaceIndex */, config.socketKey.getNetwork()); sendDiscoveryQueries = servicesToResolve.size() < listeners.size(); } Pair> result; @@ -558,7 +558,7 @@ public class MdnsServiceTypeClient { config.subtypes, config.expectUnicastResponse, config.transactionId, - config.network, + config.socketKey.getNetwork(), sendDiscoveryQueries, servicesToResolve, clock) diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java index b982644da1..2b6e5d0d35 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java @@ -235,7 +235,7 @@ public class MdnsSocketClient implements MdnsSocketClientBase { throw new IllegalArgumentException("This socket client does not support requesting " + "specific networks"); } - socketCreationCallback.onSocketCreated(null); + socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex())); } @Override @@ -456,7 +456,8 @@ public class MdnsSocketClient implements MdnsSocketClientBase { LOGGER.w(String.format("Error while decoding %s packet (%d): %d", responseType, packetNumber, e.code)); if (callback != null) { - callback.onFailedToParseMdnsResponse(packetNumber, e.code, network); + callback.onFailedToParseMdnsResponse(packetNumber, e.code, + new SocketKey(network, interfaceIndex)); } return e.code; } @@ -466,7 +467,8 @@ public class MdnsSocketClient implements MdnsSocketClientBase { } if (callback != null) { - callback.onResponseReceived(response, interfaceIndex, network); + callback.onResponseReceived( + response, new SocketKey(network, interfaceIndex)); } return MdnsResponseErrorCode.SUCCESS; diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java index e0762f9a7c..a35925a9f9 100644 --- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java +++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java @@ -73,20 +73,19 @@ public interface MdnsSocketClientBase { /*** Callback for mdns response */ interface Callback { /*** Receive a mdns response */ - void onResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex, - @Nullable Network network); + void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey); /*** Parse a mdns response failed */ void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode, - @Nullable Network network); + @NonNull SocketKey socketKey); } /*** Callback for requested socket creation */ interface SocketCreationCallback { /*** Notify requested socket is created */ - void onSocketCreated(@Nullable Network network); + void onSocketCreated(@NonNull SocketKey socketKey); /*** Notify requested socket is destroyed */ - void onAllSocketsDestroyed(@Nullable Network network); + void onAllSocketsDestroyed(@NonNull SocketKey socketKey); } } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java index a24664e4ac..d2298fe522 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java @@ -28,7 +28,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import android.annotation.NonNull; -import android.annotation.Nullable; import android.net.Network; import android.os.Handler; import android.os.HandlerThread; @@ -65,17 +64,22 @@ public class MdnsDiscoveryManagerTests { private static final String SERVICE_TYPE_2 = "_test._tcp.local"; private static final Network NETWORK_1 = Mockito.mock(Network.class); private static final Network NETWORK_2 = Mockito.mock(Network.class); - private static final Pair PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK = - Pair.create(SERVICE_TYPE_1, null); - private static final Pair PER_NETWORK_SERVICE_TYPE_1_NETWORK_1 = - Pair.create(SERVICE_TYPE_1, NETWORK_1); - private static final Pair PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK = - Pair.create(SERVICE_TYPE_2, null); - private static final Pair PER_NETWORK_SERVICE_TYPE_2_NETWORK_1 = - Pair.create(SERVICE_TYPE_2, NETWORK_1); - private static final Pair PER_NETWORK_SERVICE_TYPE_2_NETWORK_2 = - Pair.create(SERVICE_TYPE_2, NETWORK_2); - + private static final SocketKey SOCKET_KEY_NULL_NETWORK = + new SocketKey(null /* network */, 999 /* interfaceIndex */); + private static final SocketKey SOCKET_KEY_NETWORK_1 = + new SocketKey(NETWORK_1, 998 /* interfaceIndex */); + private static final SocketKey SOCKET_KEY_NETWORK_2 = + new SocketKey(NETWORK_2, 997 /* interfaceIndex */); + private static final Pair PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK = + Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NULL_NETWORK); + private static final Pair PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK = + Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NULL_NETWORK); + private static final Pair PER_SOCKET_SERVICE_TYPE_1_NETWORK_1 = + Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NETWORK_1); + private static final Pair PER_SOCKET_SERVICE_TYPE_2_NETWORK_1 = + Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_1); + private static final Pair PER_SOCKET_SERVICE_TYPE_2_NETWORK_2 = + Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_2); @Mock private ExecutorProvider executorProvider; @Mock private MdnsSocketClientBase socketClient; @Mock private MdnsServiceTypeClient mockServiceTypeClientType1NullNetwork; @@ -104,22 +108,22 @@ public class MdnsDiscoveryManagerTests { sharedLog) { @Override MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType, - @Nullable Network network) { - final Pair perNetworkServiceType = - Pair.create(serviceType, network); - if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK)) { + @NonNull SocketKey socketKey) { + final Pair perSocketServiceType = + Pair.create(serviceType, socketKey); + if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) { return mockServiceTypeClientType1NullNetwork; - } else if (perNetworkServiceType.equals( - PER_NETWORK_SERVICE_TYPE_1_NETWORK_1)) { + } else if (perSocketServiceType.equals( + PER_SOCKET_SERVICE_TYPE_1_NETWORK_1)) { return mockServiceTypeClientType1Network1; - } else if (perNetworkServiceType.equals( - PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK)) { + } else if (perSocketServiceType.equals( + PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK)) { return mockServiceTypeClientType2NullNetwork; - } else if (perNetworkServiceType.equals( - PER_NETWORK_SERVICE_TYPE_2_NETWORK_1)) { + } else if (perSocketServiceType.equals( + PER_SOCKET_SERVICE_TYPE_2_NETWORK_1)) { return mockServiceTypeClientType2Network1; - } else if (perNetworkServiceType.equals( - PER_NETWORK_SERVICE_TYPE_2_NETWORK_2)) { + } else if (perSocketServiceType.equals( + PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) { return mockServiceTypeClientType2Network2; } return null; @@ -156,7 +160,7 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options); - runOnHandler(() -> callback.onSocketCreated(null /* network */)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options); when(mockServiceTypeClientType1NullNetwork.stopSendAndReceive(mockListenerOne)) @@ -172,16 +176,16 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options); - runOnHandler(() -> callback.onSocketCreated(null /* network */)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options); - runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options); final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, options); - runOnHandler(() -> callback2.onSocketCreated(null /* network */)); + runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType2NullNetwork).startSendAndReceive(mockListenerTwo, options); - runOnHandler(() -> callback2.onSocketCreated(NETWORK_2)); + runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2)); verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options); } @@ -191,49 +195,48 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, options1); - runOnHandler(() -> callback.onSocketCreated(null /* network */)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive( mockListenerOne, options1); - runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options1); final MdnsSearchOptions options2 = MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build(); final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, options2); - runOnHandler(() -> callback2.onSocketCreated(NETWORK_2)); + runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2)); verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options2); final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1); - final int ifIndex = 1; runOnHandler(() -> discoveryManager.onResponseReceived( - responseForServiceTypeOne, ifIndex, null /* network */)); + responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK)); // Packets for network null are only processed by the ServiceTypeClient for network null verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne, - ifIndex, null /* network */); + SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork()); verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any()); verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any()); final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2); runOnHandler(() -> discoveryManager.onResponseReceived( - responseForServiceTypeTwo, ifIndex, NETWORK_1)); + responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(), - eq(NETWORK_1)); + eq(SOCKET_KEY_NETWORK_1.getNetwork())); verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo, - ifIndex, NETWORK_1); + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), - eq(NETWORK_1)); + eq(SOCKET_KEY_NETWORK_1.getNetwork())); final MdnsPacket responseForSubtype = createMdnsPacket("subtype._sub._googlecast._tcp.local"); runOnHandler(() -> discoveryManager.onResponseReceived( - responseForSubtype, ifIndex, NETWORK_2)); - verify(mockServiceTypeClientType1NullNetwork, never()).processResponse( - any(), anyInt(), eq(NETWORK_2)); - verify(mockServiceTypeClientType1Network1, never()).processResponse( - any(), anyInt(), eq(NETWORK_2)); - verify(mockServiceTypeClientType2Network2).processResponse( - responseForSubtype, ifIndex, NETWORK_2); + responseForSubtype, SOCKET_KEY_NETWORK_2)); + verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(), + eq(SOCKET_KEY_NETWORK_2.getNetwork())); + verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), + eq(SOCKET_KEY_NETWORK_2.getNetwork())); + verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype, + SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork()); } @Test @@ -243,55 +246,53 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, network1Options); - runOnHandler(() -> callback.onSocketCreated(NETWORK_1)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType1Network1).startSendAndReceive( mockListenerOne, network1Options); // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_1 final SocketCreationCallback callback2 = expectSocketCreationCallback( SERVICE_TYPE_2, mockListenerTwo, network1Options); - runOnHandler(() -> callback2.onSocketCreated(NETWORK_1)); + runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType2Network1).startSendAndReceive( mockListenerTwo, network1Options); // Receive a response, it should be processed on both clients. final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1); - final int ifIndex = 1; - runOnHandler(() -> discoveryManager.onResponseReceived( - response, ifIndex, NETWORK_1)); - verify(mockServiceTypeClientType1Network1).processResponse(response, ifIndex, NETWORK_1); - verify(mockServiceTypeClientType2Network1).processResponse(response, ifIndex, NETWORK_1); + runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1)); + verify(mockServiceTypeClientType1Network1).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); + verify(mockServiceTypeClientType2Network1).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); // The first callback receives a notification that the network has been destroyed, // mockServiceTypeClientOne1 should send service removed notifications and remove from the // list of clients. - runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1)); + runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1)); verify(mockServiceTypeClientType1Network1).notifySocketDestroyed(); // Receive a response again, it should be processed only on // mockServiceTypeClientType2Network1. Because the mockServiceTypeClientType1Network1 is // removed from the list of clients, it is no longer able to process responses. - runOnHandler(() -> discoveryManager.onResponseReceived( - response, ifIndex, NETWORK_1)); + runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1)); // Still times(1) as a response was received once previously - verify(mockServiceTypeClientType1Network1, times(1)) - .processResponse(response, ifIndex, NETWORK_1); - verify(mockServiceTypeClientType2Network1, times(2)) - .processResponse(response, ifIndex, NETWORK_1); + verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); + verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed, // mockServiceTypeClientTwo2 shouldn't send any notifications. - runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_2)); + runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2)); verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed(); // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of // clients, it's still able to process responses. - runOnHandler(() -> discoveryManager.onResponseReceived( - response, ifIndex, NETWORK_1)); - verify(mockServiceTypeClientType1Network1, times(1)) - .processResponse(response, ifIndex, NETWORK_1); - verify(mockServiceTypeClientType2Network1, times(3)) - .processResponse(response, ifIndex, NETWORK_1); + runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1)); + verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); + verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response, + SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork()); } @Test @@ -301,27 +302,25 @@ public class MdnsDiscoveryManagerTests { MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build(); final SocketCreationCallback callback = expectSocketCreationCallback( SERVICE_TYPE_1, mockListenerOne, network1Options); - runOnHandler(() -> callback.onSocketCreated(null /* network */)); + runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive( mockListenerOne, network1Options); // Receive a response, it should be processed on the client. final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1); final int ifIndex = 1; - runOnHandler(() -> discoveryManager.onResponseReceived( - response, ifIndex, null /* network */)); - verify(mockServiceTypeClientType1NullNetwork).processResponse( - response, ifIndex, null /* network */); + runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK)); + verify(mockServiceTypeClientType1NullNetwork).processResponse(response, + SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork()); - runOnHandler(() -> callback.onAllSocketsDestroyed(null /* network */)); + runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK)); verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed(); // Receive a response again, it should not be processed. - runOnHandler(() -> discoveryManager.onResponseReceived( - response, ifIndex, null /* network */)); + runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK)); // Still times(1) as a response was received once previously - verify(mockServiceTypeClientType1NullNetwork, times(1)) - .processResponse(response, ifIndex, null /* network */); + verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response, + SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork()); // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods // won't be call because the service type client was unregistered and destroyed. But those diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java index a0a302f7be..f7ef0778cb 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java @@ -21,7 +21,6 @@ import static com.android.server.connectivity.mdns.MulticastPacketReader.PacketH import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -132,11 +131,11 @@ public class MdnsMultinetworkSocketClientTest { doReturn(null).when(tetherSocketKey2).getNetwork(); // Notify socket created callback.onSocketCreated(mSocketKey, mSocket, List.of()); - verify(mSocketCreationCallback).onSocketCreated(mNetwork); + verify(mSocketCreationCallback).onSocketCreated(mSocketKey); callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of()); - verify(mSocketCreationCallback).onSocketCreated(null); + verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey1); callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of()); - verify(mSocketCreationCallback, times(2)).onSocketCreated(null); + verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2); // Send packet to IPv4 with target network and verify sending has been called. mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork); @@ -172,7 +171,7 @@ public class MdnsMultinetworkSocketClientTest { doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface(); // Notify socket created callback.onSocketCreated(mSocketKey, mSocket, List.of()); - verify(mSocketCreationCallback).onSocketCreated(mNetwork); + verify(mSocketCreationCallback).onSocketCreated(mSocketKey); final ArgumentCaptor handlerCaptor = ArgumentCaptor.forClass(PacketHandler.class); @@ -183,7 +182,7 @@ public class MdnsMultinetworkSocketClientTest { handler.handlePacket(data, data.length, null /* src */); final ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MdnsPacket.class); - verify(mCallback).onResponseReceived(responseCaptor.capture(), anyInt(), any()); + verify(mCallback).onResponseReceived(responseCaptor.capture(), any()); final MdnsPacket response = responseCaptor.getValue(); assertEquals(0, response.questions.size()); assertEquals(0, response.additionalRecords.size()); @@ -222,12 +221,13 @@ public class MdnsMultinetworkSocketClientTest { doReturn(createEmptyNetworkInterface()).when(socket3).getInterface(); final SocketKey socketKey2 = mock(SocketKey.class); - doReturn(null).when(socketKey2).getNetwork(); + final SocketKey socketKey3 = mock(SocketKey.class); callback.onSocketCreated(mSocketKey, mSocket, List.of()); callback.onSocketCreated(socketKey2, socket2, List.of()); - callback.onSocketCreated(socketKey2, socket3, List.of()); - verify(mSocketCreationCallback).onSocketCreated(mNetwork); - verify(mSocketCreationCallback, times(2)).onSocketCreated(null); + callback.onSocketCreated(socketKey3, socket3, List.of()); + verify(mSocketCreationCallback).onSocketCreated(mSocketKey); + verify(mSocketCreationCallback).onSocketCreated(socketKey2); + verify(mSocketCreationCallback).onSocketCreated(socketKey3); // Send IPv4 packet on the non-null Network and verify sending has been called. mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork); @@ -252,9 +252,10 @@ public class MdnsMultinetworkSocketClientTest { // Notify socket created for all networks. callback2.onSocketCreated(mSocketKey, mSocket, List.of()); callback2.onSocketCreated(socketKey2, socket2, List.of()); - callback2.onSocketCreated(socketKey2, socket3, List.of()); - verify(socketCreationCb2).onSocketCreated(mNetwork); - verify(socketCreationCb2, times(2)).onSocketCreated(null); + callback2.onSocketCreated(socketKey3, socket3, List.of()); + verify(socketCreationCb2).onSocketCreated(mSocketKey); + verify(socketCreationCb2).onSocketCreated(socketKey2); + verify(socketCreationCb2).onSocketCreated(socketKey3); // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets. mSocketClient.sendMulticastPacket(ipv4Packet, null); @@ -296,16 +297,16 @@ public class MdnsMultinetworkSocketClientTest { doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface(); callback.onSocketCreated(mSocketKey, mSocket, List.of()); - verify(mSocketCreationCallback).onSocketCreated(mNetwork); + verify(mSocketCreationCallback).onSocketCreated(mSocketKey); callback.onSocketCreated(mSocketKey, otherSocket, List.of()); - verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork); + verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey); - verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mNetwork); + verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey); mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener)); HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT); verify(mProvider).unrequestSocket(callback); - verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork); + verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey); } @Test @@ -316,14 +317,14 @@ public class MdnsMultinetworkSocketClientTest { doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface(); callback.onSocketCreated(mSocketKey, mSocket, List.of()); - verify(mSocketCreationCallback).onSocketCreated(mNetwork); + verify(mSocketCreationCallback).onSocketCreated(mSocketKey); callback.onSocketCreated(mSocketKey, otherSocket, List.of()); - verify(mSocketCreationCallback, times(2)).onSocketCreated(mNetwork); + verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey); // Notify socket destroyed callback.onInterfaceDestroyed(mSocketKey, mSocket); verifyNoMoreInteractions(mSocketCreationCallback); callback.onInterfaceDestroyed(mSocketKey, otherSocket); - verify(mSocketCreationCallback).onAllSocketsDestroyed(mNetwork); + verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey); } } diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java index d1adecfb91..635a1d4240 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java @@ -118,6 +118,7 @@ public class MdnsServiceTypeClientTests { private FakeExecutor currentThreadExecutor = new FakeExecutor(); private MdnsServiceTypeClient client; + private SocketKey socketKey; @Before @SuppressWarnings("DoNotMock") @@ -128,6 +129,7 @@ public class MdnsServiceTypeClientTests { expectedIPv4Packets = new DatagramPacket[16]; expectedIPv6Packets = new DatagramPacket[16]; expectedSendFutures = new ScheduledFuture[16]; + socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX); for (int i = 0; i < expectedSendFutures.length; ++i) { expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */, @@ -174,7 +176,7 @@ public class MdnsServiceTypeClientTests { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, mockNetwork, mockSharedLog) { + mockDecoderClock, socketKey, mockSharedLog) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -325,7 +327,7 @@ public class MdnsServiceTypeClientTests { MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build(); QueryTaskConfig config = new QueryTaskConfig( - searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork); + searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey); // This is the first query. We will ask for unicast response. assertTrue(config.expectUnicastResponse); @@ -354,7 +356,7 @@ public class MdnsServiceTypeClientTests { MdnsSearchOptions searchOptions = MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build(); QueryTaskConfig config = new QueryTaskConfig( - searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork); + searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey); // This is the first query. We will ask for unicast response. assertTrue(config.expectUnicastResponse); @@ -508,9 +510,9 @@ public class MdnsServiceTypeClientTests { // Process a second response with a different port and updated text attributes. client.processResponse(createResponse( - "service-instance-1", ipV4Address, 5354, - /* subtype= */ "ABCDE", - Collections.singletonMap("key", "value"), TEST_TTL), + "service-instance-1", ipV4Address, 5354, + /* subtype= */ "ABCDE", + Collections.singletonMap("key", "value"), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork); // Verify onServiceNameDiscovered was called once for the initial response. @@ -563,9 +565,9 @@ public class MdnsServiceTypeClientTests { // Process a second response with a different port and updated text attributes. client.processResponse(createResponse( - "service-instance-1", ipV6Address, 5354, - /* subtype= */ "ABCDE", - Collections.singletonMap("key", "value"), TEST_TTL), + "service-instance-1", ipV6Address, 5354, + /* subtype= */ "ABCDE", + Collections.singletonMap("key", "value"), TEST_TTL), /* interfaceIndex= */ 20, mockNetwork); // Verify onServiceNameDiscovered was called once for the initial response. @@ -709,7 +711,7 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, mockNetwork, mockSharedLog) { + mockDecoderClock, socketKey, mockSharedLog) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -750,7 +752,7 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, mockNetwork, mockSharedLog) { + mockDecoderClock, socketKey, mockSharedLog) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -783,7 +785,7 @@ public class MdnsServiceTypeClientTests { final String serviceInstanceName = "service-instance-1"; client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, mockNetwork, mockSharedLog) { + mockDecoderClock, socketKey, mockSharedLog) { @Override MdnsPacketWriter createMdnsPacketWriter() { return mockPacketWriter; @@ -835,8 +837,8 @@ public class MdnsServiceTypeClientTests { // Process the last response which is goodbye message (with the main type, not subtype). client.processResponse(createResponse( - serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS, - Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L), + serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS, + Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L), INTERFACE_INDEX, mockNetwork); // Verify onServiceNameDiscovered was first called for the initial response. @@ -908,7 +910,7 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_Resolve() throws Exception { client = new MdnsServiceTypeClient( - SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog); + SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog); final String instanceName = "service-instance"; final String[] hostname = new String[] { "testhost "}; @@ -998,7 +1000,7 @@ public class MdnsServiceTypeClientTests { @Test public void testRenewTxtSrvInResolve() throws Exception { client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor, - mockDecoderClock, mockNetwork, mockSharedLog); + mockDecoderClock, socketKey, mockSharedLog); final String instanceName = "service-instance"; final String[] hostname = new String[] { "testhost "}; @@ -1102,7 +1104,7 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_ResolveExcludesOtherServices() { client = new MdnsServiceTypeClient( - SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog); + SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog); final String requestedInstance = "instance1"; final String otherInstance = "instance2"; @@ -1119,13 +1121,13 @@ public class MdnsServiceTypeClientTests { // Complete response from instanceName client.processResponse(createResponse( - requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS, + requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS, Collections.emptyMap() /* textAttributes */, TEST_TTL), INTERFACE_INDEX, mockNetwork); // Complete response from otherInstanceName client.processResponse(createResponse( - otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS, + otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS, Collections.emptyMap() /* textAttributes */, TEST_TTL), INTERFACE_INDEX, mockNetwork); @@ -1166,7 +1168,7 @@ public class MdnsServiceTypeClientTests { @Test public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() { client = new MdnsServiceTypeClient( - SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog); + SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog); final String matchingInstance = "instance1"; final String subtype = "_subtype"; @@ -1247,7 +1249,7 @@ public class MdnsServiceTypeClientTests { @Test public void testNotifySocketDestroyed() throws Exception { client = new MdnsServiceTypeClient( - SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog); + SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog); final String requestedInstance = "instance1"; final String otherInstance = "instance2"; diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java index abb174748b..e30c2496d5 100644 --- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java +++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; @@ -370,7 +371,7 @@ public class MdnsSocketClientTests { mdnsClient.startDiscovery(); verify(mockCallback, timeout(TIMEOUT).atLeast(1)) - .onResponseReceived(any(MdnsPacket.class), anyInt(), any()); + .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class)); } @Test @@ -379,7 +380,7 @@ public class MdnsSocketClientTests { mdnsClient.startDiscovery(); verify(mockCallback, timeout(TIMEOUT).atLeastOnce()) - .onResponseReceived(any(MdnsPacket.class), anyInt(), any()); + .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class)); mdnsClient.stopDiscovery(); } @@ -513,7 +514,7 @@ public class MdnsSocketClientTests { mdnsClient.startDiscovery(); verify(mockCallback, timeout(TIMEOUT).atLeastOnce()) - .onResponseReceived(any(), eq(21), any()); + .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == 21)); } @Test @@ -536,6 +537,7 @@ public class MdnsSocketClientTests { mdnsClient.startDiscovery(); verify(mockMulticastSocket, never()).getInterfaceIndex(); - verify(mockCallback, timeout(TIMEOUT).atLeast(1)).onResponseReceived(any(), eq(-1), any()); + verify(mockCallback, timeout(TIMEOUT).atLeast(1)) + .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1)); } } \ No newline at end of file