Merge "Add support to update the registered service in place" into main

This commit is contained in:
Yuyang Huang
2023-12-02 08:49:23 +00:00
committed by Gerrit Code Review
9 changed files with 372 additions and 62 deletions

View File

@@ -91,6 +91,7 @@ import com.android.net.module.util.PermissionUtils;
import com.android.net.module.util.SharedLog; import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.ExecutorProvider; import com.android.server.connectivity.mdns.ExecutorProvider;
import com.android.server.connectivity.mdns.MdnsAdvertiser; import com.android.server.connectivity.mdns.MdnsAdvertiser;
import com.android.server.connectivity.mdns.MdnsAdvertisingOptions;
import com.android.server.connectivity.mdns.MdnsDiscoveryManager; import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
import com.android.server.connectivity.mdns.MdnsFeatureFlags; import com.android.server.connectivity.mdns.MdnsFeatureFlags;
import com.android.server.connectivity.mdns.MdnsInterfaceSocket; import com.android.server.connectivity.mdns.MdnsInterfaceSocket;
@@ -850,7 +851,9 @@ public class NsdService extends INsdManager.Stub {
// service type would generate service instance names like // service type would generate service instance names like
// Name._subtype._sub._type._tcp, which is incorrect // Name._subtype._sub._type._tcp, which is incorrect
// (it should be Name._type._tcp). // (it should be Name._type._tcp).
mAdvertiser.addService(transactionId, serviceInfo, typeSubtype.second); mAdvertiser.addOrUpdateService(transactionId, serviceInfo,
typeSubtype.second,
MdnsAdvertisingOptions.newBuilder().build());
storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo, storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo,
serviceInfo.getNetwork()); serviceInfo.getNetwork());
} else { } else {

View File

@@ -43,6 +43,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.function.BiPredicate; import java.util.function.BiPredicate;
import java.util.function.Consumer; import java.util.function.Consumer;
@@ -342,16 +343,16 @@ public class MdnsAdvertiser {
} }
/** /**
* Add a service. * Add a service to advertise.
* *
* Conflicts must be checked via {@link #getConflictingService} before attempting to add. * Conflicts must be checked via {@link #getConflictingService} before attempting to add.
*/ */
void addService(int id, Registration registration) { void addService(int id, @NonNull Registration registration) {
mPendingRegistrations.put(id, registration); mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) { for (int i = 0; i < mAdvertisers.size(); i++) {
try { try {
mAdvertisers.valueAt(i).addService( mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo(),
id, registration.getServiceInfo(), registration.getSubtype()); registration.getSubtype());
} catch (NameConflictException e) { } catch (NameConflictException e) {
mSharedLog.wtf("Name conflict adding services that should have unique names", mSharedLog.wtf("Name conflict adding services that should have unique names",
e); e);
@@ -359,6 +360,17 @@ public class MdnsAdvertiser {
} }
} }
/**
* Update an already registered service.
* The caller is expected to check that the service being updated doesn't change its name
*/
void updateService(int id, @NonNull Registration registration) {
mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) {
mAdvertisers.valueAt(i).updateService(id, registration.getSubtype());
}
}
void removeService(int id) { void removeService(int id) {
mPendingRegistrations.remove(id); mPendingRegistrations.remove(id);
for (int i = 0; i < mAdvertisers.size(); i++) { for (int i = 0; i < mAdvertisers.size(); i++) {
@@ -474,7 +486,8 @@ public class MdnsAdvertiser {
@NonNull @NonNull
private NsdServiceInfo mServiceInfo; private NsdServiceInfo mServiceInfo;
@Nullable @Nullable
private final String mSubtype; private String mSubtype;
int mConflictDuringProbingCount; int mConflictDuringProbingCount;
int mConflictAfterProbingCount; int mConflictAfterProbingCount;
@@ -484,6 +497,22 @@ public class MdnsAdvertiser {
this.mSubtype = subtype; this.mSubtype = subtype;
} }
/**
* Matches between the NsdServiceInfo in the Registration and the provided argument.
*/
public boolean matches(@Nullable NsdServiceInfo newInfo) {
return Objects.equals(newInfo.getServiceName(), mOriginalName) && Objects.equals(
newInfo.getServiceType(), mServiceInfo.getServiceType()) && Objects.equals(
newInfo.getNetwork(), mServiceInfo.getNetwork());
}
/**
* Update subType for the registration.
*/
public void updateSubtype(@Nullable String subtype) {
this.mSubtype = subtype;
}
/** /**
* Update the registration to use a different service name, after a conflict was found. * Update the registration to use a different service name, after a conflict was found.
* *
@@ -632,42 +661,68 @@ public class MdnsAdvertiser {
} }
/** /**
* Add a service to advertise. * Add or update a service to advertise.
*
* @param id A unique ID for the service. * @param id A unique ID for the service.
* @param service The service info to advertise. * @param service The service info to advertise.
* @param subtype An optional subtype to advertise the service with. * @param subtype An optional subtype to advertise the service with.
* @param advertisingOptions The advertising options.
*/ */
public void addService(int id, NsdServiceInfo service, @Nullable String subtype) { public void addOrUpdateService(int id, NsdServiceInfo service, @Nullable String subtype,
MdnsAdvertisingOptions advertisingOptions) {
checkThread(); checkThread();
if (mRegistrations.get(id) != null) { final Registration existingRegistration = mRegistrations.get(id);
final Network network = service.getNetwork();
Registration registration;
if (advertisingOptions.isOnlyUpdate()) {
if (existingRegistration == null) {
mSharedLog.e("Update non existing registration for " + service);
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return;
}
if (!(existingRegistration.matches(service))) {
mSharedLog.e("Update request can only update subType, serviceInfo: " + service
+ ", existing serviceInfo: " + existingRegistration.getServiceInfo());
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return;
}
mSharedLog.i("Update service " + service + " with ID " + id + " and subtype " + subtype
+ " advertisingOptions " + advertisingOptions);
registration = existingRegistration;
registration.updateSubtype(subtype);
} else {
if (existingRegistration != null) {
mSharedLog.e("Adding duplicate registration for " + service); mSharedLog.e("Adding duplicate registration for " + service);
// TODO (b/264986328): add a more specific error code // TODO (b/264986328): add a more specific error code
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR); mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return; return;
} }
mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype
mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype); + " advertisingOptions " + advertisingOptions);
registration = new Registration(service, subtype);
final Network network = service.getNetwork();
final Registration registration = new Registration(service, subtype);
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter; final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
if (network == null) { if (network == null) {
// If registering on all networks, no advertiser must have conflicts // If registering on all networks, no advertiser must have conflicts
checkConflictFilter = (net, adv) -> true; checkConflictFilter = (net, adv) -> true;
} else { } else {
// If registering on one network, the matching network advertiser and the one for all // If registering on one network, the matching network advertiser and the one
// networks must not have conflicts // for all networks must not have conflicts
checkConflictFilter = (net, adv) -> net == null || network.equals(net); checkConflictFilter = (net, adv) -> net == null || network.equals(net);
} }
updateRegistrationUntilNoConflict(checkConflictFilter, registration); updateRegistrationUntilNoConflict(checkConflictFilter, registration);
}
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network); InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
if (advertiser == null) { if (advertiser == null) {
advertiser = new InterfaceAdvertiserRequest(network); advertiser = new InterfaceAdvertiserRequest(network);
mAdvertiserRequests.put(network, advertiser); mAdvertiserRequests.put(network, advertiser);
} }
if (advertisingOptions.isOnlyUpdate()) {
advertiser.updateService(id, registration);
} else {
advertiser.addService(id, registration); advertiser.addService(id, registration);
}
mRegistrations.put(id, registration); mRegistrations.put(id, registration);
} }

View File

@@ -0,0 +1,92 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
/**
* API configuration parameters for advertising the mDNS service.
*
* <p>Use {@link MdnsAdvertisingOptions.Builder} to create {@link MdnsAdvertisingOptions}.
*
* @hide
*/
public class MdnsAdvertisingOptions {
private static MdnsAdvertisingOptions sDefaultOptions;
private final boolean mIsOnlyUpdate;
/**
* Parcelable constructs for a {@link MdnsAdvertisingOptions}.
*/
MdnsAdvertisingOptions(
boolean isOnlyUpdate) {
this.mIsOnlyUpdate = isOnlyUpdate;
}
/**
* Returns a {@link Builder} for {@link MdnsAdvertisingOptions}.
*/
public static Builder newBuilder() {
return new Builder();
}
/**
* Returns a default search options.
*/
public static synchronized MdnsAdvertisingOptions getDefaultOptions() {
if (sDefaultOptions == null) {
sDefaultOptions = newBuilder().build();
}
return sDefaultOptions;
}
/**
* @return {@code true} if the advertising request is an update request.
*/
public boolean isOnlyUpdate() {
return mIsOnlyUpdate;
}
@Override
public String toString() {
return "MdnsAdvertisingOptions{" + "mIsOnlyUpdate=" + mIsOnlyUpdate + '}';
}
/**
* A builder to create {@link MdnsAdvertisingOptions}.
*/
public static final class Builder {
private boolean mIsOnlyUpdate = false;
private Builder() {
}
/**
* Sets if the advertising request is an update request.
*/
public Builder setIsOnlyUpdate(boolean isOnlyUpdate) {
this.mIsOnlyUpdate = isOnlyUpdate;
return this;
}
/**
* Builds a {@link MdnsAdvertisingOptions} with the arguments supplied to this builder.
*/
public MdnsAdvertisingOptions build() {
return new MdnsAdvertisingOptions(mIsOnlyUpdate);
}
}
}

View File

@@ -228,6 +228,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
mSocket.addPacketHandler(this); mSocket.addPacketHandler(this);
} }
/**
* Update an already registered service without sending exit/re-announcement packet.
*
* @param id An exiting service id
* @param subtype A new subtype
*/
public void updateService(int id, @Nullable String subtype) {
// The current implementation is intended to be used in cases where subtypes don't get
// announced.
mRecordRepository.updateService(id, subtype);
}
/** /**
* Start advertising a service. * Start advertising a service.
* *

View File

@@ -167,7 +167,7 @@ public class MdnsRecordRepository {
/** /**
* Whether the service is sending exit announcements and will be destroyed soon. * Whether the service is sending exit announcements and will be destroyed soon.
*/ */
public boolean exiting = false; public boolean exiting;
/** /**
* The replied query packet count of this service. * The replied query packet count of this service.
@@ -184,14 +184,21 @@ public class MdnsRecordRepository {
*/ */
private boolean isProbing; private boolean isProbing;
/**
* Create a ServiceRegistration with only update the subType
*/
ServiceRegistration withSubtype(String newSubType) {
return new ServiceRegistration(srvRecord.record.getServiceHost(), serviceInfo,
newSubType, repliedServiceCount, sentPacketCount, exiting, isProbing);
}
/** /**
* Create a ServiceRegistration for dns-sd service registration (RFC6763). * Create a ServiceRegistration for dns-sd service registration (RFC6763).
*
* @param deviceHostname Hostname of the device (for the interface used)
* @param serviceInfo Service to advertise
*/ */
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo, ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
@Nullable String subtype, int repliedServiceCount, int sentPacketCount) { @Nullable String subtype, int repliedServiceCount, int sentPacketCount,
boolean exiting, boolean isProbing) {
this.serviceInfo = serviceInfo; this.serviceInfo = serviceInfo;
this.subtype = subtype; this.subtype = subtype;
@@ -266,7 +273,20 @@ public class MdnsRecordRepository {
this.allRecords = Collections.unmodifiableList(allRecords); this.allRecords = Collections.unmodifiableList(allRecords);
this.repliedServiceCount = repliedServiceCount; this.repliedServiceCount = repliedServiceCount;
this.sentPacketCount = sentPacketCount; this.sentPacketCount = sentPacketCount;
this.isProbing = true; this.isProbing = isProbing;
this.exiting = exiting;
}
/**
* Create a ServiceRegistration for dns-sd service registration (RFC6763).
*
* @param deviceHostname Hostname of the device (for the interface used)
* @param serviceInfo Service to advertise
*/
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
@Nullable String subtype, int repliedServiceCount, int sentPacketCount) {
this(deviceHostname, serviceInfo, subtype, repliedServiceCount, sentPacketCount,
false /* exiting */, true /* isProbing */);
} }
void setProbing(boolean probing) { void setProbing(boolean probing) {
@@ -304,6 +324,24 @@ public class MdnsRecordRepository {
} }
} }
/**
* Update a service that already registered in the repository.
*
* @param serviceId An existing service ID.
* @param subtype A new subtype
* @return
*/
public void updateService(int serviceId, @Nullable String subtype) {
final ServiceRegistration existingRegistration = mServices.get(serviceId);
if (existingRegistration == null) {
throw new IllegalArgumentException(
"Service ID must already exist for an update request: " + serviceId);
}
final ServiceRegistration updatedRegistration = existingRegistration.withSubtype(
subtype);
mServices.put(serviceId, updatedRegistration);
}
/** /**
* Add a service to the repository. * Add a service to the repository.
* *

View File

@@ -1115,9 +1115,9 @@ public class NsdServiceTest {
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mAdvertiser).addService(anyInt(), argThat(s -> verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(s ->
"Instance".equals(s.getServiceName()) "Instance".equals(s.getServiceName())
&& SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype")); && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"), any());
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener); client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1222,8 +1222,8 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(), verify(mAdvertiser).addOrUpdateService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)), eq(null) /* subtype */); argThat(info -> matches(info, regInfo)), eq(null) /* subtype */, any());
client.unregisterService(regListenerWithoutFeature); client.unregisterService(regListenerWithoutFeature);
waitForIdle(); waitForIdle();
@@ -1282,10 +1282,10 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
// The advertiser is enabled for _type2 but not _type1 // The advertiser is enabled for _type2 but not _type1
verify(mAdvertiser, never()).addService( verify(mAdvertiser, never()).addOrUpdateService(anyInt(),
anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */); argThat(info -> matches(info, service1)), eq(null) /* subtype */, any());
verify(mAdvertiser).addService( verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(info -> matches(info, service2)),
anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */); eq(null) /* subtype */, any());
} }
@Test @Test
@@ -1309,8 +1309,8 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(idCaptor.capture(), argThat(info -> verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), argThat(info ->
matches(info, regInfo)), eq(null) /* subtype */); matches(info, regInfo)), eq(null) /* subtype */, any());
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1358,7 +1358,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mAdvertiser, never()).addService(anyInt(), any(), any()); verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any(), any());
verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed( verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR)); argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1387,9 +1387,9 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
// Service name is truncated to 63 characters // Service name is truncated to 63 characters
verify(mAdvertiser).addService(idCaptor.capture(), verify(mAdvertiser).addOrUpdateService(idCaptor.capture(),
argThat(info -> info.getServiceName().equals("a".repeat(63))), argThat(info -> info.getServiceName().equals("a".repeat(63))),
eq(null) /* subtype */); eq(null) /* subtype */, any());
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1479,7 +1479,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any()); verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
// Verify the discovery uses MdnsDiscoveryManager // Verify the discovery uses MdnsDiscoveryManager
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -1512,7 +1512,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any()); verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
final Network wifiNetwork1 = new Network(123); final Network wifiNetwork1 = new Network(123);
final Network wifiNetwork2 = new Network(124); final Network wifiNetwork2 = new Network(124);

View File

@@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress import android.net.LinkAddress
import android.net.Network import android.net.Network
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo import android.net.nsd.NsdServiceInfo
import android.net.nsd.OffloadEngine import android.net.nsd.OffloadEngine
import android.net.nsd.OffloadServiceInfo import android.net.nsd.OffloadServiceInfo
@@ -71,6 +72,7 @@ private val TEST_INTERFACE1 = "test_iface1"
private val TEST_INTERFACE2 = "test_iface2" private val TEST_INTERFACE2 = "test_iface2"
private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03) private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03)
private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04) private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04)
private val DEFAULT_ADVERTISING_OPTION = MdnsAdvertisingOptions.getDefaultOptions()
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345 port = 12345
@@ -186,7 +188,8 @@ class MdnsAdvertiserTest {
fun testAddService_OneNetwork() { fun testAddService_OneNetwork() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -247,7 +250,8 @@ class MdnsAdvertiserTest {
fun testAddService_AllNetworks() { fun testAddService_AllNetworks() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
TEST_SUBTYPE, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network), verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -318,24 +322,27 @@ class MdnsAdvertiserTest {
fun testAddService_Conflicts() { fun testAddService_Conflicts() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict) // Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value val allNetSocketCb = allNetSocketCbCaptor.value
postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_1, LONG_SERVICE_1,
postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
null /* subtype */) } postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.addService(CASE_INSENSITIVE_TEST_SERVICE_ID, ALL_NETWORKS_SERVICE_2, postSync { advertiser.addOrUpdateService(CASE_INSENSITIVE_TEST_SERVICE_ID,
null /* subtype */) } ALL_NETWORKS_SERVICE_2, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
// Callbacks for matching network and all networks both get the socket // Callbacks for matching network and all networks both get the socket
postSync { postSync {
@@ -399,12 +406,52 @@ class MdnsAdvertiserTest {
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow() verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
} }
@Test
fun testAddOrUpdateService_Updates() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(null))
val updateOptions = MdnsAdvertisingOptions.newBuilder().setIsOnlyUpdate(true).build()
// Update with serviceId that is not registered yet should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_2, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with different NsdServiceInfo should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_1, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with same NsdServiceInfo but different subType should succeed
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(mockInterfaceAdvertiser1).updateService(eq(SERVICE_ID_1), eq(TEST_SUBTYPE))
// Newly created MdnsInterfaceAdvertiser will get addService() call.
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR2)) }
verify(mockInterfaceAdvertiser2).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(TEST_SUBTYPE))
}
@Test @Test
fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() { fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
verify(mockDeps, times(1)).generateHostname() verify(mockDeps, times(1)).generateHostname()
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.removeService(SERVICE_ID_1) } postSync { advertiser.removeService(SERVICE_ID_1) }
verify(mockDeps, times(2)).generateHostname() verify(mockDeps, times(2)).generateHostname()
} }

View File

@@ -48,6 +48,7 @@ import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq import org.mockito.Mockito.eq
import org.mockito.Mockito.mock import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.times import org.mockito.Mockito.times
import org.mockito.Mockito.verify import org.mockito.Mockito.verify
@@ -59,6 +60,7 @@ private val TEST_BUFFER = ByteArray(1300)
private val TEST_HOSTNAME = arrayOf("Android_test", "local") private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_DUPLICATE = 43
private val TEST_SERVICE_1 = NsdServiceInfo().apply { private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp" serviceType = "_testservice._tcp"
serviceName = "MyTestService" serviceName = "MyTestService"
@@ -272,6 +274,28 @@ class MdnsInterfaceAdvertiserTest {
verify(prober).restartForConflict(mockProbingInfo) verify(prober).restartForConflict(mockProbingInfo)
} }
@Test
fun testReplaceExitingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.addService(TEST_SERVICE_ID_DUPLICATE, TEST_SERVICE_1, subType)
verify(repository).addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
verify(announcer).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober).startProbing(any())
}
@Test
fun testUpdateExistingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.updateService(TEST_SERVICE_ID_DUPLICATE, subType)
verify(repository).updateService(eq(TEST_SERVICE_ID_DUPLICATE), any())
verify(announcer, never()).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober, never()).startProbing(any())
}
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo): private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo { AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java) val testProbingInfo = mock(ProbingInfo::class.java)

View File

@@ -129,7 +129,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testAddAndConflicts() { fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(NameConflictException::class) { assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
} }
@@ -138,6 +138,45 @@ class MdnsRecordRepositoryTest {
} }
} }
@Test
fun testAddAndUpdates() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(IllegalArgumentException::class) {
repository.updateService(TEST_SERVICE_ID_2, null /* subtype */)
}
repository.updateService(TEST_SERVICE_ID_1, TEST_SUBTYPE)
val queriedName = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
val questions = listOf(MdnsPointerRecord(queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
// TTL and data is empty for a question
0L /* ttlMillis */,
null /* pointer */))
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val reply = repository.getReply(query, src)
assertNotNull(reply)
// TTLs as per RFC6762 10.
val longTtl = 4_500_000L
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(listOf(
MdnsPointerRecord(
queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
longTtl,
serviceName),
), reply.answers)
}
@Test @Test
fun testInvalidReuseOfServiceId() { fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
@@ -758,7 +797,7 @@ class MdnsRecordRepositoryTest {
private fun MdnsRecordRepository.initWithService( private fun MdnsRecordRepository.initWithService(
serviceId: Int, serviceId: Int,
serviceInfo: NsdServiceInfo, serviceInfo: NsdServiceInfo,
subtype: String? = null subtype: String? = null,
): AnnouncementInfo { ): AnnouncementInfo {
updateAddresses(TEST_ADDRESSES) updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo, subtype) addService(serviceId, serviceInfo, subtype)