Merge "Process all responses if the network is null in MdnsServiceTypeClient" into udc-dev

This commit is contained in:
Paul Hu
2023-05-11 07:00:24 +00:00
committed by Android (Google) Code Review
2 changed files with 84 additions and 33 deletions

View File

@@ -82,7 +82,11 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
final List<MdnsServiceTypeClient> list = new ArrayList<>();
for (int i = 0; i < clients.size(); i++) {
final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
if (isNetworkMatched(network, perNetworkServiceType.second)) {
final Network serviceTypeNetwork = perNetworkServiceType.second;
// The serviceTypeNetwork would be null if the MdnsSocketClient is being used. This
// is also the case if the socket is for a tethering interface. In either of these
// cases, the client is expected to process any responses.
if (serviceTypeNetwork == null || isNetworkMatched(network, serviceTypeNetwork)) {
list.add(clients.valueAt(i));
}
}
@@ -167,12 +171,12 @@ public class MdnsDiscoveryManager implements MdnsSocketClientBase.Callback {
// No listener is registered for the service type anymore, remove it from the list
// of the service type clients.
perNetworkServiceTypeClients.remove(serviceTypeClient);
}
}
if (perNetworkServiceTypeClients.isEmpty()) {
// No discovery request. Stops the socket client.
socketClient.stopDiscovery();
}
}
}
// Unrequested the network.
socketClient.notifyNetworkUnrequested(listener);
}

View File

@@ -18,8 +18,8 @@ package com.android.server.connectivity.mdns;
import static com.android.testutils.DevSdkIgnoreRuleKt.SC_V2;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -38,6 +38,7 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import java.io.IOException;
@@ -53,15 +54,23 @@ public class MdnsDiscoveryManagerTests {
private static final String SERVICE_TYPE_1 = "_googlecast._tcp.local";
private static final String SERVICE_TYPE_2 = "_test._tcp.local";
private static final Network NETWORK_1 = Mockito.mock(Network.class);
private static final Network NETWORK_2 = Mockito.mock(Network.class);
private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1 =
Pair.create(SERVICE_TYPE_1, null);
private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_1 =
Pair.create(SERVICE_TYPE_1, NETWORK_1);
private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2 =
Pair.create(SERVICE_TYPE_2, null);
private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_2 =
Pair.create(SERVICE_TYPE_2, NETWORK_2);
@Mock private ExecutorProvider executorProvider;
@Mock private MdnsSocketClientBase socketClient;
@Mock private MdnsServiceTypeClient mockServiceTypeClientOne;
@Mock private MdnsServiceTypeClient mockServiceTypeClientOne1;
@Mock private MdnsServiceTypeClient mockServiceTypeClientTwo;
@Mock private MdnsServiceTypeClient mockServiceTypeClientTwo2;
@Mock MdnsServiceBrowserListener mockListenerOne;
@Mock MdnsServiceBrowserListener mockListenerTwo;
@@ -71,11 +80,6 @@ public class MdnsDiscoveryManagerTests {
public void setUp() {
MockitoAnnotations.initMocks(this);
when(mockServiceTypeClientOne.getServiceTypeLabels())
.thenReturn(TextUtils.split(SERVICE_TYPE_1, "\\."));
when(mockServiceTypeClientTwo.getServiceTypeLabels())
.thenReturn(TextUtils.split(SERVICE_TYPE_2, "\\."));
discoveryManager = new MdnsDiscoveryManager(executorProvider, socketClient) {
@Override
MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
@@ -84,31 +88,37 @@ public class MdnsDiscoveryManagerTests {
Pair.create(serviceType, network);
if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1)) {
return mockServiceTypeClientOne;
} else if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_1)) {
return mockServiceTypeClientOne1;
} else if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_2)) {
return mockServiceTypeClientTwo;
} else if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_2_2)) {
return mockServiceTypeClientTwo2;
}
return null;
}
};
}
private void verifyListenerRegistration(String serviceType, MdnsServiceBrowserListener listener,
MdnsServiceTypeClient client) throws IOException {
private SocketCreationCallback expectSocketCreationCallback(String serviceType,
MdnsServiceBrowserListener listener, MdnsSearchOptions options) throws IOException {
final ArgumentCaptor<SocketCreationCallback> callbackCaptor =
ArgumentCaptor.forClass(SocketCreationCallback.class);
discoveryManager.registerListener(serviceType, listener,
MdnsSearchOptions.getDefaultOptions());
discoveryManager.registerListener(serviceType, listener, options);
verify(socketClient).startDiscovery();
verify(socketClient).notifyNetworkRequested(
eq(listener), any(), callbackCaptor.capture());
final SocketCreationCallback callback = callbackCaptor.getValue();
callback.onSocketCreated(null /* network */);
verify(client).startSendAndReceive(listener, MdnsSearchOptions.getDefaultOptions());
eq(listener), eq(options.getNetwork()), callbackCaptor.capture());
return callbackCaptor.getValue();
}
@Test
public void registerListener_unregisterListener() throws IOException {
verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
final MdnsSearchOptions options =
MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
final SocketCreationCallback callback = expectSocketCreationCallback(
SERVICE_TYPE_1, mockListenerOne, options);
callback.onSocketCreated(null /* network */);
verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options);
when(mockServiceTypeClientOne.stopSendAndReceive(mockListenerOne)).thenReturn(true);
discoveryManager.unregisterListener(SERVICE_TYPE_1, mockListenerOne);
@@ -118,30 +128,67 @@ public class MdnsDiscoveryManagerTests {
@Test
public void registerMultipleListeners() throws IOException {
verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
final MdnsSearchOptions options =
MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
final SocketCreationCallback callback = expectSocketCreationCallback(
SERVICE_TYPE_1, mockListenerOne, options);
callback.onSocketCreated(null /* network */);
verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options);
callback.onSocketCreated(NETWORK_1);
verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options);
final SocketCreationCallback callback2 = expectSocketCreationCallback(
SERVICE_TYPE_2, mockListenerTwo, options);
callback2.onSocketCreated(null /* network */);
verify(mockServiceTypeClientTwo).startSendAndReceive(mockListenerTwo, options);
callback2.onSocketCreated(NETWORK_2);
verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options);
}
@Test
public void onResponseReceived() throws IOException {
verifyListenerRegistration(SERVICE_TYPE_1, mockListenerOne, mockServiceTypeClientOne);
verifyListenerRegistration(SERVICE_TYPE_2, mockListenerTwo, mockServiceTypeClientTwo);
final MdnsSearchOptions options1 =
MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
final SocketCreationCallback callback = expectSocketCreationCallback(
SERVICE_TYPE_1, mockListenerOne, options1);
callback.onSocketCreated(null /* network */);
verify(mockServiceTypeClientOne).startSendAndReceive(mockListenerOne, options1);
callback.onSocketCreated(NETWORK_1);
verify(mockServiceTypeClientOne1).startSendAndReceive(mockListenerOne, options1);
MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
final MdnsSearchOptions options2 =
MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
final SocketCreationCallback callback2 = expectSocketCreationCallback(
SERVICE_TYPE_2, mockListenerTwo, options2);
callback2.onSocketCreated(NETWORK_2);
verify(mockServiceTypeClientTwo2).startSendAndReceive(mockListenerTwo, options2);
final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
final int ifIndex = 1;
discoveryManager.onResponseReceived(responseForServiceTypeOne, ifIndex, null /* network */);
verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeOne, ifIndex,
null /* network */);
MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, null /* network */);
verify(mockServiceTypeClientTwo).processResponse(responseForServiceTypeTwo, ifIndex,
verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeOne, ifIndex,
null /* network */);
verify(mockServiceTypeClientTwo2).processResponse(responseForServiceTypeOne, ifIndex,
null /* network */);
MdnsPacket responseForSubtype = createMdnsPacket("subtype._sub._googlecast._tcp.local");
discoveryManager.onResponseReceived(responseForSubtype, ifIndex, null /* network */);
verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex,
null /* network */);
final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
discoveryManager.onResponseReceived(responseForServiceTypeTwo, ifIndex, NETWORK_1);
verify(mockServiceTypeClientOne).processResponse(responseForServiceTypeTwo, ifIndex,
NETWORK_1);
verify(mockServiceTypeClientOne1).processResponse(responseForServiceTypeTwo, ifIndex,
NETWORK_1);
verify(mockServiceTypeClientTwo2, never()).processResponse(responseForServiceTypeTwo,
ifIndex, NETWORK_1);
final MdnsPacket responseForSubtype =
createMdnsPacket("subtype._sub._googlecast._tcp.local");
discoveryManager.onResponseReceived(responseForSubtype, ifIndex, NETWORK_2);
verify(mockServiceTypeClientOne).processResponse(responseForSubtype, ifIndex, NETWORK_2);
verify(mockServiceTypeClientOne1, never()).processResponse(
responseForSubtype, ifIndex, NETWORK_2);
verify(mockServiceTypeClientTwo2).processResponse(responseForSubtype, ifIndex, NETWORK_2);
}
private MdnsPacket createMdnsPacket(String serviceType) {