Add support to update the registered service in place

The current implementation of MdnsAdvertiser doesn't support updating an
existing registration. For an update request, the client needs to
unregister it first, which will trigger an exit message and then
register again, which will trigger an announcement message. There are
some clients that don't want to trigger the exit and announcement
message every time. This CL adds the API to support that use case.

Bug: 309372239
Test: TH
Change-Id: Iabe69a987a11104090082e01969e7595f05504e8
This commit is contained in:
Yuyang Huang
2023-11-02 18:05:47 +09:00
parent 5db2089717
commit e5cba9cb87
9 changed files with 372 additions and 62 deletions

View File

@@ -90,6 +90,7 @@ import com.android.net.module.util.PermissionUtils;
import com.android.net.module.util.SharedLog; import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.ExecutorProvider; import com.android.server.connectivity.mdns.ExecutorProvider;
import com.android.server.connectivity.mdns.MdnsAdvertiser; import com.android.server.connectivity.mdns.MdnsAdvertiser;
import com.android.server.connectivity.mdns.MdnsAdvertisingOptions;
import com.android.server.connectivity.mdns.MdnsDiscoveryManager; import com.android.server.connectivity.mdns.MdnsDiscoveryManager;
import com.android.server.connectivity.mdns.MdnsFeatureFlags; import com.android.server.connectivity.mdns.MdnsFeatureFlags;
import com.android.server.connectivity.mdns.MdnsInterfaceSocket; import com.android.server.connectivity.mdns.MdnsInterfaceSocket;
@@ -849,7 +850,9 @@ public class NsdService extends INsdManager.Stub {
// service type would generate service instance names like // service type would generate service instance names like
// Name._subtype._sub._type._tcp, which is incorrect // Name._subtype._sub._type._tcp, which is incorrect
// (it should be Name._type._tcp). // (it should be Name._type._tcp).
mAdvertiser.addService(transactionId, serviceInfo, typeSubtype.second); mAdvertiser.addOrUpdateService(transactionId, serviceInfo,
typeSubtype.second,
MdnsAdvertisingOptions.newBuilder().build());
storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo, storeAdvertiserRequestMap(clientRequestId, transactionId, clientInfo,
serviceInfo.getNetwork()); serviceInfo.getNetwork());
} else { } else {

View File

@@ -43,6 +43,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.function.BiPredicate; import java.util.function.BiPredicate;
import java.util.function.Consumer; import java.util.function.Consumer;
@@ -342,16 +343,16 @@ public class MdnsAdvertiser {
} }
/** /**
* Add a service. * Add a service to advertise.
* *
* Conflicts must be checked via {@link #getConflictingService} before attempting to add. * Conflicts must be checked via {@link #getConflictingService} before attempting to add.
*/ */
void addService(int id, Registration registration) { void addService(int id, @NonNull Registration registration) {
mPendingRegistrations.put(id, registration); mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) { for (int i = 0; i < mAdvertisers.size(); i++) {
try { try {
mAdvertisers.valueAt(i).addService( mAdvertisers.valueAt(i).addService(id, registration.getServiceInfo(),
id, registration.getServiceInfo(), registration.getSubtype()); registration.getSubtype());
} catch (NameConflictException e) { } catch (NameConflictException e) {
mSharedLog.wtf("Name conflict adding services that should have unique names", mSharedLog.wtf("Name conflict adding services that should have unique names",
e); e);
@@ -359,6 +360,17 @@ public class MdnsAdvertiser {
} }
} }
/**
* Update an already registered service.
* The caller is expected to check that the service being updated doesn't change its name
*/
void updateService(int id, @NonNull Registration registration) {
mPendingRegistrations.put(id, registration);
for (int i = 0; i < mAdvertisers.size(); i++) {
mAdvertisers.valueAt(i).updateService(id, registration.getSubtype());
}
}
void removeService(int id) { void removeService(int id) {
mPendingRegistrations.remove(id); mPendingRegistrations.remove(id);
for (int i = 0; i < mAdvertisers.size(); i++) { for (int i = 0; i < mAdvertisers.size(); i++) {
@@ -474,7 +486,8 @@ public class MdnsAdvertiser {
@NonNull @NonNull
private NsdServiceInfo mServiceInfo; private NsdServiceInfo mServiceInfo;
@Nullable @Nullable
private final String mSubtype; private String mSubtype;
int mConflictDuringProbingCount; int mConflictDuringProbingCount;
int mConflictAfterProbingCount; int mConflictAfterProbingCount;
@@ -484,6 +497,22 @@ public class MdnsAdvertiser {
this.mSubtype = subtype; this.mSubtype = subtype;
} }
/**
* Matches between the NsdServiceInfo in the Registration and the provided argument.
*/
public boolean matches(@Nullable NsdServiceInfo newInfo) {
return Objects.equals(newInfo.getServiceName(), mOriginalName) && Objects.equals(
newInfo.getServiceType(), mServiceInfo.getServiceType()) && Objects.equals(
newInfo.getNetwork(), mServiceInfo.getNetwork());
}
/**
* Update subType for the registration.
*/
public void updateSubtype(@Nullable String subtype) {
this.mSubtype = subtype;
}
/** /**
* Update the registration to use a different service name, after a conflict was found. * Update the registration to use a different service name, after a conflict was found.
* *
@@ -632,42 +661,68 @@ public class MdnsAdvertiser {
} }
/** /**
* Add a service to advertise. * Add or update a service to advertise.
*
* @param id A unique ID for the service. * @param id A unique ID for the service.
* @param service The service info to advertise. * @param service The service info to advertise.
* @param subtype An optional subtype to advertise the service with. * @param subtype An optional subtype to advertise the service with.
* @param advertisingOptions The advertising options.
*/ */
public void addService(int id, NsdServiceInfo service, @Nullable String subtype) { public void addOrUpdateService(int id, NsdServiceInfo service, @Nullable String subtype,
MdnsAdvertisingOptions advertisingOptions) {
checkThread(); checkThread();
if (mRegistrations.get(id) != null) { final Registration existingRegistration = mRegistrations.get(id);
mSharedLog.e("Adding duplicate registration for " + service);
// TODO (b/264986328): add a more specific error code
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return;
}
mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype);
final Network network = service.getNetwork(); final Network network = service.getNetwork();
final Registration registration = new Registration(service, subtype); Registration registration;
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter; if (advertisingOptions.isOnlyUpdate()) {
if (network == null) { if (existingRegistration == null) {
// If registering on all networks, no advertiser must have conflicts mSharedLog.e("Update non existing registration for " + service);
checkConflictFilter = (net, adv) -> true; mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
} else { return;
// If registering on one network, the matching network advertiser and the one for all }
// networks must not have conflicts if (!(existingRegistration.matches(service))) {
checkConflictFilter = (net, adv) -> net == null || network.equals(net); mSharedLog.e("Update request can only update subType, serviceInfo: " + service
} + ", existing serviceInfo: " + existingRegistration.getServiceInfo());
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return;
updateRegistrationUntilNoConflict(checkConflictFilter, registration); }
mSharedLog.i("Update service " + service + " with ID " + id + " and subtype " + subtype
+ " advertisingOptions " + advertisingOptions);
registration = existingRegistration;
registration.updateSubtype(subtype);
} else {
if (existingRegistration != null) {
mSharedLog.e("Adding duplicate registration for " + service);
// TODO (b/264986328): add a more specific error code
mCb.onRegisterServiceFailed(id, NsdManager.FAILURE_INTERNAL_ERROR);
return;
}
mSharedLog.i("Adding service " + service + " with ID " + id + " and subtype " + subtype
+ " advertisingOptions " + advertisingOptions);
registration = new Registration(service, subtype);
final BiPredicate<Network, InterfaceAdvertiserRequest> checkConflictFilter;
if (network == null) {
// If registering on all networks, no advertiser must have conflicts
checkConflictFilter = (net, adv) -> true;
} else {
// If registering on one network, the matching network advertiser and the one
// for all networks must not have conflicts
checkConflictFilter = (net, adv) -> net == null || network.equals(net);
}
updateRegistrationUntilNoConflict(checkConflictFilter, registration);
}
InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network); InterfaceAdvertiserRequest advertiser = mAdvertiserRequests.get(network);
if (advertiser == null) { if (advertiser == null) {
advertiser = new InterfaceAdvertiserRequest(network); advertiser = new InterfaceAdvertiserRequest(network);
mAdvertiserRequests.put(network, advertiser); mAdvertiserRequests.put(network, advertiser);
} }
advertiser.addService(id, registration); if (advertisingOptions.isOnlyUpdate()) {
advertiser.updateService(id, registration);
} else {
advertiser.addService(id, registration);
}
mRegistrations.put(id, registration); mRegistrations.put(id, registration);
} }

View File

@@ -0,0 +1,92 @@
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
/**
* API configuration parameters for advertising the mDNS service.
*
* <p>Use {@link MdnsAdvertisingOptions.Builder} to create {@link MdnsAdvertisingOptions}.
*
* @hide
*/
public class MdnsAdvertisingOptions {
private static MdnsAdvertisingOptions sDefaultOptions;
private final boolean mIsOnlyUpdate;
/**
* Parcelable constructs for a {@link MdnsAdvertisingOptions}.
*/
MdnsAdvertisingOptions(
boolean isOnlyUpdate) {
this.mIsOnlyUpdate = isOnlyUpdate;
}
/**
* Returns a {@link Builder} for {@link MdnsAdvertisingOptions}.
*/
public static Builder newBuilder() {
return new Builder();
}
/**
* Returns a default search options.
*/
public static synchronized MdnsAdvertisingOptions getDefaultOptions() {
if (sDefaultOptions == null) {
sDefaultOptions = newBuilder().build();
}
return sDefaultOptions;
}
/**
* @return {@code true} if the advertising request is an update request.
*/
public boolean isOnlyUpdate() {
return mIsOnlyUpdate;
}
@Override
public String toString() {
return "MdnsAdvertisingOptions{" + "mIsOnlyUpdate=" + mIsOnlyUpdate + '}';
}
/**
* A builder to create {@link MdnsAdvertisingOptions}.
*/
public static final class Builder {
private boolean mIsOnlyUpdate = false;
private Builder() {
}
/**
* Sets if the advertising request is an update request.
*/
public Builder setIsOnlyUpdate(boolean isOnlyUpdate) {
this.mIsOnlyUpdate = isOnlyUpdate;
return this;
}
/**
* Builds a {@link MdnsAdvertisingOptions} with the arguments supplied to this builder.
*/
public MdnsAdvertisingOptions build() {
return new MdnsAdvertisingOptions(mIsOnlyUpdate);
}
}
}

View File

@@ -228,6 +228,18 @@ public class MdnsInterfaceAdvertiser implements MulticastPacketReader.PacketHand
mSocket.addPacketHandler(this); mSocket.addPacketHandler(this);
} }
/**
* Update an already registered service without sending exit/re-announcement packet.
*
* @param id An exiting service id
* @param subtype A new subtype
*/
public void updateService(int id, @Nullable String subtype) {
// The current implementation is intended to be used in cases where subtypes don't get
// announced.
mRecordRepository.updateService(id, subtype);
}
/** /**
* Start advertising a service. * Start advertising a service.
* *

View File

@@ -167,7 +167,7 @@ public class MdnsRecordRepository {
/** /**
* Whether the service is sending exit announcements and will be destroyed soon. * Whether the service is sending exit announcements and will be destroyed soon.
*/ */
public boolean exiting = false; public boolean exiting;
/** /**
* The replied query packet count of this service. * The replied query packet count of this service.
@@ -184,14 +184,21 @@ public class MdnsRecordRepository {
*/ */
private boolean isProbing; private boolean isProbing;
/**
* Create a ServiceRegistration with only update the subType
*/
ServiceRegistration withSubtype(String newSubType) {
return new ServiceRegistration(srvRecord.record.getServiceHost(), serviceInfo,
newSubType, repliedServiceCount, sentPacketCount, exiting, isProbing);
}
/** /**
* Create a ServiceRegistration for dns-sd service registration (RFC6763). * Create a ServiceRegistration for dns-sd service registration (RFC6763).
*
* @param deviceHostname Hostname of the device (for the interface used)
* @param serviceInfo Service to advertise
*/ */
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo, ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
@Nullable String subtype, int repliedServiceCount, int sentPacketCount) { @Nullable String subtype, int repliedServiceCount, int sentPacketCount,
boolean exiting, boolean isProbing) {
this.serviceInfo = serviceInfo; this.serviceInfo = serviceInfo;
this.subtype = subtype; this.subtype = subtype;
@@ -266,7 +273,20 @@ public class MdnsRecordRepository {
this.allRecords = Collections.unmodifiableList(allRecords); this.allRecords = Collections.unmodifiableList(allRecords);
this.repliedServiceCount = repliedServiceCount; this.repliedServiceCount = repliedServiceCount;
this.sentPacketCount = sentPacketCount; this.sentPacketCount = sentPacketCount;
this.isProbing = true; this.isProbing = isProbing;
this.exiting = exiting;
}
/**
* Create a ServiceRegistration for dns-sd service registration (RFC6763).
*
* @param deviceHostname Hostname of the device (for the interface used)
* @param serviceInfo Service to advertise
*/
ServiceRegistration(@NonNull String[] deviceHostname, @NonNull NsdServiceInfo serviceInfo,
@Nullable String subtype, int repliedServiceCount, int sentPacketCount) {
this(deviceHostname, serviceInfo, subtype, repliedServiceCount, sentPacketCount,
false /* exiting */, true /* isProbing */);
} }
void setProbing(boolean probing) { void setProbing(boolean probing) {
@@ -304,6 +324,24 @@ public class MdnsRecordRepository {
} }
} }
/**
* Update a service that already registered in the repository.
*
* @param serviceId An existing service ID.
* @param subtype A new subtype
* @return
*/
public void updateService(int serviceId, @Nullable String subtype) {
final ServiceRegistration existingRegistration = mServices.get(serviceId);
if (existingRegistration == null) {
throw new IllegalArgumentException(
"Service ID must already exist for an update request: " + serviceId);
}
final ServiceRegistration updatedRegistration = existingRegistration.withSubtype(
subtype);
mServices.put(serviceId, updatedRegistration);
}
/** /**
* Add a service to the repository. * Add a service to the repository.
* *

View File

@@ -1114,9 +1114,9 @@ public class NsdServiceTest {
final RegistrationListener regListener = mock(RegistrationListener.class); final RegistrationListener regListener = mock(RegistrationListener.class);
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mAdvertiser).addService(anyInt(), argThat(s -> verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(s ->
"Instance".equals(s.getServiceName()) "Instance".equals(s.getServiceName())
&& SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype")); && SERVICE_TYPE.equals(s.getServiceType())), eq("_subtype"), any());
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener); client.discoverServices(typeWithSubtype, PROTOCOL, network, Runnable::run, discListener);
@@ -1221,8 +1221,8 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(serviceIdCaptor.capture(), verify(mAdvertiser).addOrUpdateService(serviceIdCaptor.capture(),
argThat(info -> matches(info, regInfo)), eq(null) /* subtype */); argThat(info -> matches(info, regInfo)), eq(null) /* subtype */, any());
client.unregisterService(regListenerWithoutFeature); client.unregisterService(regListenerWithoutFeature);
waitForIdle(); waitForIdle();
@@ -1281,10 +1281,10 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
// The advertiser is enabled for _type2 but not _type1 // The advertiser is enabled for _type2 but not _type1
verify(mAdvertiser, never()).addService( verify(mAdvertiser, never()).addOrUpdateService(anyInt(),
anyInt(), argThat(info -> matches(info, service1)), eq(null) /* subtype */); argThat(info -> matches(info, service1)), eq(null) /* subtype */, any());
verify(mAdvertiser).addService( verify(mAdvertiser).addOrUpdateService(anyInt(), argThat(info -> matches(info, service2)),
anyInt(), argThat(info -> matches(info, service2)), eq(null) /* subtype */); eq(null) /* subtype */, any());
} }
@Test @Test
@@ -1308,8 +1308,8 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(mAdvertiser).addService(idCaptor.capture(), argThat(info -> verify(mAdvertiser).addOrUpdateService(idCaptor.capture(), argThat(info ->
matches(info, regInfo)), eq(null) /* subtype */); matches(info, regInfo)), eq(null) /* subtype */, any());
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1357,7 +1357,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mAdvertiser, never()).addService(anyInt(), any(), any()); verify(mAdvertiser, never()).addOrUpdateService(anyInt(), any(), any(), any());
verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed( verify(regListener, timeout(TIMEOUT_MS)).onRegistrationFailed(
argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR)); argThat(info -> matches(info, regInfo)), eq(FAILURE_INTERNAL_ERROR));
@@ -1386,9 +1386,9 @@ public class NsdServiceTest {
waitForIdle(); waitForIdle();
final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class); final ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
// Service name is truncated to 63 characters // Service name is truncated to 63 characters
verify(mAdvertiser).addService(idCaptor.capture(), verify(mAdvertiser).addOrUpdateService(idCaptor.capture(),
argThat(info -> info.getServiceName().equals("a".repeat(63))), argThat(info -> info.getServiceName().equals("a".repeat(63))),
eq(null) /* subtype */); eq(null) /* subtype */, any());
// Verify onServiceRegistered callback // Verify onServiceRegistered callback
final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue(); final MdnsAdvertiser.AdvertiserCallback cb = cbCaptor.getValue();
@@ -1478,7 +1478,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any()); verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
// Verify the discovery uses MdnsDiscoveryManager // Verify the discovery uses MdnsDiscoveryManager
final DiscoveryListener discListener = mock(DiscoveryListener.class); final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -1511,7 +1511,7 @@ public class NsdServiceTest {
client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener); client.registerService(regInfo, NsdManager.PROTOCOL_DNS_SD, Runnable::run, regListener);
waitForIdle(); waitForIdle();
verify(mSocketProvider).startMonitoringSockets(); verify(mSocketProvider).startMonitoringSockets();
verify(mAdvertiser).addService(anyInt(), any(), any()); verify(mAdvertiser).addOrUpdateService(anyInt(), any(), any(), any());
final Network wifiNetwork1 = new Network(123); final Network wifiNetwork1 = new Network(123);
final Network wifiNetwork2 = new Network(124); final Network wifiNetwork2 = new Network(124);

View File

@@ -19,6 +19,7 @@ package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress import android.net.LinkAddress
import android.net.Network import android.net.Network
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo import android.net.nsd.NsdServiceInfo
import android.net.nsd.OffloadEngine import android.net.nsd.OffloadEngine
import android.net.nsd.OffloadServiceInfo import android.net.nsd.OffloadServiceInfo
@@ -71,6 +72,7 @@ private val TEST_INTERFACE1 = "test_iface1"
private val TEST_INTERFACE2 = "test_iface2" private val TEST_INTERFACE2 = "test_iface2"
private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03) private val TEST_OFFLOAD_PACKET1 = byteArrayOf(0x01, 0x02, 0x03)
private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04) private val TEST_OFFLOAD_PACKET2 = byteArrayOf(0x02, 0x03, 0x04)
private val DEFAULT_ADVERTISING_OPTION = MdnsAdvertisingOptions.getDefaultOptions()
private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply { private val SERVICE_1 = NsdServiceInfo("TestServiceName", "_advertisertest._tcp").apply {
port = 12345 port = 12345
@@ -186,7 +188,8 @@ class MdnsAdvertiserTest {
fun testAddService_OneNetwork() { fun testAddService_OneNetwork() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
@@ -247,7 +250,8 @@ class MdnsAdvertiserTest {
fun testAddService_AllNetworks() { fun testAddService_AllNetworks() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
TEST_SUBTYPE, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network), verify(socketProvider).requestSocket(eq(ALL_NETWORKS_SERVICE.network),
@@ -318,24 +322,27 @@ class MdnsAdvertiserTest {
fun testAddService_Conflicts() { fun testAddService_Conflicts() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val oneNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), oneNetSocketCbCaptor.capture())
val oneNetSocketCb = oneNetSocketCbCaptor.value val oneNetSocketCb = oneNetSocketCbCaptor.value
// Register a service with the same name on all networks (name conflict) // Register a service with the same name on all networks (name conflict)
postSync { advertiser.addService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java) val allNetSocketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture()) verify(socketProvider).requestSocket(eq(null), allNetSocketCbCaptor.capture())
val allNetSocketCb = allNetSocketCbCaptor.value val allNetSocketCb = allNetSocketCbCaptor.value
postSync { advertiser.addService(LONG_SERVICE_ID_1, LONG_SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_1, LONG_SERVICE_1,
postSync { advertiser.addService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
null /* subtype */) } postSync { advertiser.addOrUpdateService(LONG_SERVICE_ID_2, LONG_ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.addService(CASE_INSENSITIVE_TEST_SERVICE_ID, ALL_NETWORKS_SERVICE_2, postSync { advertiser.addOrUpdateService(CASE_INSENSITIVE_TEST_SERVICE_ID,
null /* subtype */) } ALL_NETWORKS_SERVICE_2, null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
// Callbacks for matching network and all networks both get the socket // Callbacks for matching network and all networks both get the socket
postSync { postSync {
@@ -399,12 +406,52 @@ class MdnsAdvertiserTest {
verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow() verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
} }
@Test
fun testAddOrUpdateService_Updates() {
val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
val socketCbCaptor = ArgumentCaptor.forClass(SocketCallback::class.java)
verify(socketProvider).requestSocket(eq(null), socketCbCaptor.capture())
val socketCb = socketCbCaptor.value
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
verify(mockInterfaceAdvertiser1).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(null))
val updateOptions = MdnsAdvertisingOptions.newBuilder().setIsOnlyUpdate(true).build()
// Update with serviceId that is not registered yet should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_2, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_2, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with different NsdServiceInfo should fail
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1, TEST_SUBTYPE,
updateOptions) }
verify(cb).onRegisterServiceFailed(SERVICE_ID_1, NsdManager.FAILURE_INTERNAL_ERROR)
// Update service with same NsdServiceInfo but different subType should succeed
postSync { advertiser.addOrUpdateService(SERVICE_ID_1, ALL_NETWORKS_SERVICE, TEST_SUBTYPE,
updateOptions) }
verify(mockInterfaceAdvertiser1).updateService(eq(SERVICE_ID_1), eq(TEST_SUBTYPE))
// Newly created MdnsInterfaceAdvertiser will get addService() call.
postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR2)) }
verify(mockInterfaceAdvertiser2).addService(eq(SERVICE_ID_1),
argThat { it.matches(ALL_NETWORKS_SERVICE) }, eq(TEST_SUBTYPE))
}
@Test @Test
fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() { fun testRemoveService_whenAllServiceRemoved_thenUpdateHostName() {
val advertiser = val advertiser =
MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags) MdnsAdvertiser(thread.looper, socketProvider, cb, mockDeps, sharedlog, flags)
verify(mockDeps, times(1)).generateHostname() verify(mockDeps, times(1)).generateHostname()
postSync { advertiser.addService(SERVICE_ID_1, SERVICE_1, null /* subtype */) } postSync { advertiser.addOrUpdateService(SERVICE_ID_1, SERVICE_1,
null /* subtype */, DEFAULT_ADVERTISING_OPTION) }
postSync { advertiser.removeService(SERVICE_ID_1) } postSync { advertiser.removeService(SERVICE_ID_1) }
verify(mockDeps, times(2)).generateHostname() verify(mockDeps, times(2)).generateHostname()
} }

View File

@@ -48,6 +48,7 @@ import org.mockito.Mockito.doAnswer
import org.mockito.Mockito.doReturn import org.mockito.Mockito.doReturn
import org.mockito.Mockito.eq import org.mockito.Mockito.eq
import org.mockito.Mockito.mock import org.mockito.Mockito.mock
import org.mockito.Mockito.never
import org.mockito.Mockito.times import org.mockito.Mockito.times
import org.mockito.Mockito.verify import org.mockito.Mockito.verify
@@ -59,6 +60,7 @@ private val TEST_BUFFER = ByteArray(1300)
private val TEST_HOSTNAME = arrayOf("Android_test", "local") private val TEST_HOSTNAME = arrayOf("Android_test", "local")
private const val TEST_SERVICE_ID_1 = 42 private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_DUPLICATE = 43
private val TEST_SERVICE_1 = NsdServiceInfo().apply { private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp" serviceType = "_testservice._tcp"
serviceName = "MyTestService" serviceName = "MyTestService"
@@ -272,6 +274,28 @@ class MdnsInterfaceAdvertiserTest {
verify(prober).restartForConflict(mockProbingInfo) verify(prober).restartForConflict(mockProbingInfo)
} }
@Test
fun testReplaceExitingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.addService(TEST_SERVICE_ID_DUPLICATE, TEST_SERVICE_1, subType)
verify(repository).addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
verify(announcer).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober).startProbing(any())
}
@Test
fun testUpdateExistingService() {
doReturn(TEST_SERVICE_ID_DUPLICATE).`when`(repository)
.addService(eq(TEST_SERVICE_ID_DUPLICATE), any(), any())
val subType = "_sub"
advertiser.updateService(TEST_SERVICE_ID_DUPLICATE, subType)
verify(repository).updateService(eq(TEST_SERVICE_ID_DUPLICATE), any())
verify(announcer, never()).stop(TEST_SERVICE_ID_DUPLICATE)
verify(prober, never()).startProbing(any())
}
private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo): private fun addServiceAndFinishProbing(serviceId: Int, serviceInfo: NsdServiceInfo):
AnnouncementInfo { AnnouncementInfo {
val testProbingInfo = mock(ProbingInfo::class.java) val testProbingInfo = mock(ProbingInfo::class.java)

View File

@@ -129,7 +129,7 @@ class MdnsRecordRepositoryTest {
@Test @Test
fun testAddAndConflicts() { fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1, null /* subtype */) repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(NameConflictException::class) { assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */) repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1, null /* subtype */)
} }
@@ -138,6 +138,45 @@ class MdnsRecordRepositoryTest {
} }
} }
@Test
fun testAddAndUpdates() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
repository.initWithService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(IllegalArgumentException::class) {
repository.updateService(TEST_SERVICE_ID_2, null /* subtype */)
}
repository.updateService(TEST_SERVICE_ID_1, TEST_SUBTYPE)
val queriedName = arrayOf(TEST_SUBTYPE, "_sub", "_testservice", "_tcp", "local")
val questions = listOf(MdnsPointerRecord(queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
// TTL and data is empty for a question
0L /* ttlMillis */,
null /* pointer */))
val query = MdnsPacket(0 /* flags */, questions, listOf() /* answers */,
listOf() /* authorityRecords */, listOf() /* additionalRecords */)
val src = InetSocketAddress(parseNumericAddress("192.0.2.123"), 5353)
val reply = repository.getReply(query, src)
assertNotNull(reply)
// TTLs as per RFC6762 10.
val longTtl = 4_500_000L
val serviceName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(listOf(
MdnsPointerRecord(
queriedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
longTtl,
serviceName),
), reply.answers)
}
@Test @Test
fun testInvalidReuseOfServiceId() { fun testInvalidReuseOfServiceId() {
val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags) val repository = MdnsRecordRepository(thread.looper, deps, TEST_HOSTNAME, flags)
@@ -758,7 +797,7 @@ class MdnsRecordRepositoryTest {
private fun MdnsRecordRepository.initWithService( private fun MdnsRecordRepository.initWithService(
serviceId: Int, serviceId: Int,
serviceInfo: NsdServiceInfo, serviceInfo: NsdServiceInfo,
subtype: String? = null subtype: String? = null,
): AnnouncementInfo { ): AnnouncementInfo {
updateAddresses(TEST_ADDRESSES) updateAddresses(TEST_ADDRESSES)
addService(serviceId, serviceInfo, subtype) addService(serviceId, serviceInfo, subtype)