Merge "Add unit tests for probing"

This commit is contained in:
Remi NGUYEN VAN
2023-01-13 04:05:46 +00:00
committed by Gerrit Code Review
5 changed files with 313 additions and 9 deletions

View File

@@ -38,7 +38,8 @@ public class MdnsAnnouncer extends MdnsPacketRepeater<MdnsAnnouncer.Announcement
@NonNull @NonNull
private final String mLogTag; private final String mLogTag;
static class AnnouncementInfo implements MdnsPacketRepeater.Request { /** Announcement request to send with {@link MdnsAnnouncer}. */
public static class AnnouncementInfo implements MdnsPacketRepeater.Request {
@NonNull @NonNull
private final MdnsPacket mPacket; private final MdnsPacket mPacket;

View File

@@ -17,12 +17,16 @@
package com.android.server.connectivity.mdns; package com.android.server.connectivity.mdns;
import android.annotation.NonNull; import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.LinkAddress; import android.net.LinkAddress;
import android.net.nsd.NsdServiceInfo; import android.net.nsd.NsdServiceInfo;
import android.os.Handler; import android.os.Handler;
import android.os.Looper; import android.os.Looper;
import android.util.Log; import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@@ -31,7 +35,8 @@ import java.util.List;
*/ */
public class MdnsInterfaceAdvertiser { public class MdnsInterfaceAdvertiser {
private static final boolean DBG = MdnsAdvertiser.DBG; private static final boolean DBG = MdnsAdvertiser.DBG;
private static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L; @VisibleForTesting
public static final long EXIT_ANNOUNCEMENT_DELAY_MS = 100L;
@NonNull @NonNull
private final String mTag; private final String mTag;
@NonNull @NonNull
@@ -84,7 +89,7 @@ public class MdnsInterfaceAdvertiser {
* Callbacks from {@link MdnsProber}. * Callbacks from {@link MdnsProber}.
*/ */
private class ProbingCallback implements private class ProbingCallback implements
MdnsPacketRepeater.PacketRepeaterCallback<MdnsProber.ProbingInfo> { PacketRepeaterCallback<MdnsProber.ProbingInfo> {
@Override @Override
public void onFinished(MdnsProber.ProbingInfo info) { public void onFinished(MdnsProber.ProbingInfo info) {
final MdnsAnnouncer.AnnouncementInfo announcementInfo; final MdnsAnnouncer.AnnouncementInfo announcementInfo;
@@ -109,23 +114,64 @@ public class MdnsInterfaceAdvertiser {
* Callbacks from {@link MdnsAnnouncer}. * Callbacks from {@link MdnsAnnouncer}.
*/ */
private class AnnouncingCallback private class AnnouncingCallback
implements MdnsPacketRepeater.PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> { implements PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> {
// TODO: implement // TODO: implement
} }
/**
* Dependencies for {@link MdnsInterfaceAdvertiser}, useful for testing.
*/
@VisibleForTesting
public static class Dependencies {
/** @see MdnsRecordRepository */
@NonNull
public MdnsRecordRepository makeRecordRepository(@NonNull Looper looper) {
return new MdnsRecordRepository(looper);
}
/** @see MdnsReplySender */
@NonNull
public MdnsReplySender makeReplySender(@NonNull Looper looper,
@NonNull MdnsInterfaceSocket socket, @NonNull byte[] packetCreationBuffer) {
return new MdnsReplySender(looper, socket, packetCreationBuffer);
}
/** @see MdnsAnnouncer */
public MdnsAnnouncer makeMdnsAnnouncer(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
@Nullable PacketRepeaterCallback<MdnsAnnouncer.AnnouncementInfo> cb) {
return new MdnsAnnouncer(interfaceTag, looper, replySender, cb);
}
/** @see MdnsProber */
public MdnsProber makeMdnsProber(@NonNull String interfaceTag, @NonNull Looper looper,
@NonNull MdnsReplySender replySender,
@NonNull PacketRepeaterCallback<MdnsProber.ProbingInfo> cb) {
return new MdnsProber(interfaceTag, looper, replySender, cb);
}
}
public MdnsInterfaceAdvertiser(@NonNull String logTag, public MdnsInterfaceAdvertiser(@NonNull String logTag,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses, @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
@NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) { @NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb) {
this(logTag, socket, initialAddresses, looper, packetCreationBuffer, cb,
new Dependencies());
}
public MdnsInterfaceAdvertiser(@NonNull String logTag,
@NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> initialAddresses,
@NonNull Looper looper, @NonNull byte[] packetCreationBuffer, @NonNull Callback cb,
@NonNull Dependencies deps) {
mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag; mTag = MdnsInterfaceAdvertiser.class.getSimpleName() + "/" + logTag;
mRecordRepository = new MdnsRecordRepository(looper); mRecordRepository = deps.makeRecordRepository(looper);
mRecordRepository.updateAddresses(initialAddresses); mRecordRepository.updateAddresses(initialAddresses);
mSocket = socket; mSocket = socket;
mCb = cb; mCb = cb;
mCbHandler = new Handler(looper); mCbHandler = new Handler(looper);
mReplySender = new MdnsReplySender(looper, socket, packetCreationBuffer); mReplySender = deps.makeReplySender(looper, socket, packetCreationBuffer);
mAnnouncer = new MdnsAnnouncer(logTag, looper, mReplySender, mAnnouncer = deps.makeMdnsAnnouncer(logTag, looper, mReplySender,
mAnnouncingCallback); mAnnouncingCallback);
mProber = new MdnsProber(logTag, looper, mReplySender, mProbingCallback); mProber = deps.makeMdnsProber(logTag, looper, mReplySender, mProbingCallback);
} }
/** /**

View File

@@ -44,7 +44,8 @@ public class MdnsProber extends MdnsPacketRepeater<MdnsProber.ProbingInfo> {
mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag; mLogTag = MdnsProber.class.getSimpleName() + "/" + interfaceTag;
} }
static class ProbingInfo implements Request { /** Probing request to send with {@link MdnsProber}. */
public static class ProbingInfo implements Request {
private final int mServiceId; private final int mServiceId;
@NonNull @NonNull

View File

@@ -0,0 +1,129 @@
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
import android.net.LinkAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsAnnouncer.AnnouncementInfo
import com.android.server.connectivity.mdns.MdnsInterfaceAdvertiser.EXIT_ANNOUNCEMENT_DELAY_MS
import com.android.server.connectivity.mdns.MdnsPacketRepeater.PacketRepeaterCallback
import com.android.server.connectivity.mdns.MdnsProber.ProbingInfo
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.waitForIdle
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.any
import org.mockito.Mockito.anyInt
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.verify
private const val LOG_TAG = "testlogtag"
private const val TIMEOUT_MS = 10_000L
private val TEST_ADDRS = listOf(LinkAddress(parseNumericAddress("2001:db8::123"), 64))
private val TEST_BUFFER = ByteArray(1300)
private const val TEST_SERVICE_ID_1 = 42
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
port = 12345
}
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsInterfaceAdvertiserTest {
private val socket = mock(MdnsInterfaceSocket::class.java)
private val thread = HandlerThread(MdnsInterfaceAdvertiserTest::class.simpleName)
private val cb = mock(MdnsInterfaceAdvertiser.Callback::class.java)
private val deps = mock(MdnsInterfaceAdvertiser.Dependencies::class.java)
private val repository = mock(MdnsRecordRepository::class.java)
private val replySender = mock(MdnsReplySender::class.java)
private val announcer = mock(MdnsAnnouncer::class.java)
private val prober = mock(MdnsProber::class.java)
private val probeCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<ProbingInfo>>
private val announceCbCaptor = ArgumentCaptor.forClass(PacketRepeaterCallback::class.java)
as ArgumentCaptor<PacketRepeaterCallback<AnnouncementInfo>>
private val probeCb get() = probeCbCaptor.value
private val announceCb get() = announceCbCaptor.value
private val advertiser by lazy {
MdnsInterfaceAdvertiser(LOG_TAG, socket, TEST_ADDRS, thread.looper, TEST_BUFFER, cb, deps)
}
@Before
fun setUp() {
doReturn(repository).`when`(deps).makeRecordRepository(any())
doReturn(replySender).`when`(deps).makeReplySender(any(), any(), any())
doReturn(announcer).`when`(deps).makeMdnsAnnouncer(any(), any(), any(), any())
doReturn(prober).`when`(deps).makeMdnsProber(any(), any(), any(), any())
doReturn(-1).`when`(repository).addService(anyInt(), any())
thread.start()
advertiser.start()
verify(deps).makeMdnsProber(any(), any(), any(), probeCbCaptor.capture())
verify(deps).makeMdnsAnnouncer(any(), any(), any(), announceCbCaptor.capture())
}
@After
fun tearDown() {
thread.quitSafely()
}
@Test
fun testAddRemoveService() {
val testProbingInfo = mock(ProbingInfo::class.java)
doReturn(TEST_SERVICE_ID_1).`when`(testProbingInfo).serviceId
doReturn(testProbingInfo).`when`(repository).setServiceProbing(TEST_SERVICE_ID_1)
advertiser.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(repository).addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
verify(prober).startProbing(testProbingInfo)
// Simulate probing success: continues to announcing
val testAnnouncementInfo = mock(AnnouncementInfo::class.java)
doReturn(testAnnouncementInfo).`when`(repository).onProbingSucceeded(testProbingInfo)
probeCb.onFinished(testProbingInfo)
verify(announcer).startSending(TEST_SERVICE_ID_1, testAnnouncementInfo,
0L /* initialDelayMs */)
thread.waitForIdle(TIMEOUT_MS)
verify(cb).onRegisterServiceSucceeded(advertiser, TEST_SERVICE_ID_1)
// Remove the service: expect exit announcements
val testExitInfo = mock(AnnouncementInfo::class.java)
doReturn(testExitInfo).`when`(repository).exitService(TEST_SERVICE_ID_1)
advertiser.removeService(TEST_SERVICE_ID_1)
verify(announcer).startSending(TEST_SERVICE_ID_1, testExitInfo, EXIT_ANNOUNCEMENT_DELAY_MS)
// TODO: after exit announcements are implemented, verify that announceCb.onFinished causes
// cb.onDestroyed to be called.
}
}

View File

@@ -0,0 +1,127 @@
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns
import android.net.InetAddresses.parseNumericAddress
import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.os.HandlerThread
import com.android.server.connectivity.mdns.MdnsRecordRepository.Dependencies
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRunner
import java.net.NetworkInterface
import java.util.Collections
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
private const val TEST_SERVICE_ID_1 = 42
private const val TEST_SERVICE_ID_2 = 43
private const val TEST_PORT = 12345
private val TEST_HOSTNAME = arrayOf("Android_000102030405060708090A0B0C0D0E0F", "local")
private val TEST_ADDRESSES = arrayOf(
parseNumericAddress("192.0.2.111"),
parseNumericAddress("2001:db8::111"),
parseNumericAddress("2001:db8::222"))
private val TEST_SERVICE_1 = NsdServiceInfo().apply {
serviceType = "_testservice._tcp"
serviceName = "MyTestService"
port = TEST_PORT
}
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsRecordRepositoryTest {
private val thread = HandlerThread(MdnsRecordRepositoryTest::class.simpleName)
private val deps = object : Dependencies() {
override fun getHostname() = TEST_HOSTNAME
override fun getInterfaceInetAddresses(iface: NetworkInterface) =
Collections.enumeration(TEST_ADDRESSES.toList())
}
@Before
fun setUp() {
thread.start()
}
@After
fun tearDown() {
thread.quitSafely()
}
@Test
fun testAddServiceAndProbe() {
val repository = MdnsRecordRepository(thread.looper, deps)
assertEquals(0, repository.servicesCount)
assertEquals(-1, repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1))
assertEquals(1, repository.servicesCount)
val probingInfo = repository.setServiceProbing(TEST_SERVICE_ID_1)
assertNotNull(probingInfo)
assertTrue(repository.isProbing(TEST_SERVICE_ID_1))
assertEquals(TEST_SERVICE_ID_1, probingInfo.serviceId)
val packet = probingInfo.getPacket(0)
assertEquals(MdnsConstants.FLAGS_QUERY, packet.flags)
assertEquals(0, packet.answers.size)
assertEquals(0, packet.additionalRecords.size)
assertEquals(1, packet.questions.size)
val expectedName = arrayOf("MyTestService", "_testservice", "_tcp", "local")
assertEquals(MdnsAnyRecord(expectedName, false /* unicast */), packet.questions[0])
assertEquals(1, packet.authorityRecords.size)
assertEquals(MdnsServiceRecord(expectedName,
0L /* receiptTimeMillis */,
false /* cacheFlush */,
120_000L /* ttlMillis */,
0 /* servicePriority */, 0 /* serviceWeight */,
TEST_PORT, TEST_HOSTNAME), packet.authorityRecords[0])
assertContentEquals(intArrayOf(TEST_SERVICE_ID_1), repository.clearServices())
}
@Test
fun testAddAndConflicts() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
assertFailsWith(NameConflictException::class) {
repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1)
}
}
@Test
fun testExitingServiceReAdded() {
val repository = MdnsRecordRepository(thread.looper, deps)
repository.addService(TEST_SERVICE_ID_1, TEST_SERVICE_1)
repository.exitService(TEST_SERVICE_ID_1)
assertEquals(TEST_SERVICE_ID_1, repository.addService(TEST_SERVICE_ID_2, TEST_SERVICE_1))
assertEquals(1, repository.servicesCount)
repository.removeService(TEST_SERVICE_ID_2)
assertEquals(0, repository.servicesCount)
}
}