diff --git a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java index fbb342dd48..846abcb867 100644 --- a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java +++ b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java @@ -295,8 +295,7 @@ public class OffloadHardwareInterface { NF_NETLINK_CONNTRACK_NEW | NF_NETLINK_CONNTRACK_DESTROY); if (h1 == null) return false; - sendIpv4NfGenMsg(h1, (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET), - (short) (NLM_F_REQUEST | NLM_F_DUMP)); + requestSocketDump(h1); final NativeHandle h2 = mDeps.createConntrackSocket( NF_NETLINK_CONNTRACK_UPDATE | NF_NETLINK_CONNTRACK_DESTROY); @@ -325,7 +324,7 @@ public class OffloadHardwareInterface { } @VisibleForTesting - public void sendIpv4NfGenMsg(@NonNull NativeHandle handle, short type, short flags) { + void sendIpv4NfGenMsg(@NonNull NativeHandle handle, short type, short flags) { final int length = StructNlMsgHdr.STRUCT_SIZE + StructNfGenMsg.STRUCT_SIZE; final byte[] msg = new byte[length]; final ByteBuffer byteBuffer = ByteBuffer.wrap(msg); @@ -350,6 +349,12 @@ public class OffloadHardwareInterface { } } + @VisibleForTesting + void requestSocketDump(NativeHandle handle) { + sendIpv4NfGenMsg(handle, (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET), + (short) (NLM_F_REQUEST | NLM_F_DUMP)); + } + private void closeFdInNativeHandle(final NativeHandle h) { try { h.close(); diff --git a/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java b/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java index d38a7c3206..23fb60c6f0 100644 --- a/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java +++ b/Tethering/tests/privileged/src/com/android/networkstack/tethering/ConntrackSocketTest.java @@ -16,28 +16,32 @@ package com.android.networkstack.tethering; +import static android.system.OsConstants.EAGAIN; +import static android.system.OsConstants.IPPROTO_TCP; +import static android.system.OsConstants.NETLINK_NETFILTER; + import static com.android.net.module.util.netlink.NetlinkSocket.DEFAULT_RECV_BUFSIZE; -import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_DUMP; -import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST; -import static com.android.networkstack.tethering.OffloadHardwareInterface.IPCTNL_MSG_CT_GET; import static com.android.networkstack.tethering.OffloadHardwareInterface.IPCTNL_MSG_CT_NEW; import static com.android.networkstack.tethering.OffloadHardwareInterface.NFNL_SUBSYS_CTNETLINK; import static com.android.networkstack.tethering.OffloadHardwareInterface.NF_NETLINK_CONNTRACK_DESTROY; import static com.android.networkstack.tethering.OffloadHardwareInterface.NF_NETLINK_CONNTRACK_NEW; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import android.os.Handler; import android.os.HandlerThread; import android.os.Looper; import android.os.NativeHandle; -import android.system.Os; +import android.system.ErrnoException; +import android.util.Log; import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; import com.android.net.module.util.SharedLog; +import com.android.net.module.util.netlink.ConntrackMessage; +import com.android.net.module.util.netlink.NetlinkMessage; +import com.android.net.module.util.netlink.NetlinkSocket; import com.android.net.module.util.netlink.StructNlMsgHdr; import org.junit.Before; @@ -45,18 +49,18 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.MockitoAnnotations; +import java.io.FileDescriptor; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; -import java.net.SocketAddress; import java.nio.ByteBuffer; -import java.nio.ByteOrder; @RunWith(AndroidJUnit4.class) @SmallTest public class ConntrackSocketTest { private static final long TIMEOUT = 500; + private static final String TAG = ConntrackSocketTest.class.getSimpleName(); private HandlerThread mHandlerThread; private Handler mHandler; @@ -80,51 +84,72 @@ public class ConntrackSocketTest { mOffloadHw = new OffloadHardwareInterface(mHandler, mLog, mDeps); } + void findConnectionOrThrow(FileDescriptor fd, InetSocketAddress local, InetSocketAddress remote) + throws Exception { + Log.d(TAG, "Looking for socket " + local + " -> " + remote); + + // Loop until the socket is found (and return) or recvMessage throws an exception. + while (true) { + final ByteBuffer buffer = NetlinkSocket.recvMessage(fd, DEFAULT_RECV_BUFSIZE, TIMEOUT); + + // Parse all the netlink messages in the dump. + // NetlinkMessage#parse returns null if the message is truncated or invalid. + while (buffer.remaining() > 0) { + NetlinkMessage nlmsg = NetlinkMessage.parse(buffer, NETLINK_NETFILTER); + Log.d(TAG, "Got netlink message: " + nlmsg); + if (!(nlmsg instanceof ConntrackMessage)) { + continue; + } + + StructNlMsgHdr nlmsghdr = nlmsg.getHeader(); + ConntrackMessage ctmsg = (ConntrackMessage) nlmsg; + ConntrackMessage.Tuple tuple = ctmsg.tupleOrig; + + if (nlmsghdr.nlmsg_type == (NFNL_SUBSYS_CTNETLINK << 8 | IPCTNL_MSG_CT_NEW) + && tuple.protoNum == IPPROTO_TCP + && tuple.srcIp.equals(local.getAddress()) + && tuple.dstIp.equals(remote.getAddress()) + && tuple.srcPort == (short) local.getPort() + && tuple.dstPort == (short) remote.getPort()) { + return; + } + } + } + } + @Test public void testIpv4ConntrackSocket() throws Exception { // Set up server and connect. - final InetSocketAddress anyAddress = new InetSocketAddress( - InetAddress.getByName("127.0.0.1"), 0); + final InetAddress localhost = InetAddress.getByName("127.0.0.1"); + final InetSocketAddress anyAddress = new InetSocketAddress(localhost, 0); final ServerSocket serverSocket = new ServerSocket(); serverSocket.bind(anyAddress); - final SocketAddress theAddress = serverSocket.getLocalSocketAddress(); + final InetSocketAddress theAddress = + (InetSocketAddress) serverSocket.getLocalSocketAddress(); // Make a connection to the server. final Socket socket = new Socket(); socket.connect(theAddress); + final InetSocketAddress localAddress = (InetSocketAddress) socket.getLocalSocketAddress(); final Socket acceptedSocket = serverSocket.accept(); final NativeHandle handle = mDeps.createConntrackSocket( NF_NETLINK_CONNTRACK_NEW | NF_NETLINK_CONNTRACK_DESTROY); - mOffloadHw.sendIpv4NfGenMsg(handle, - (short) ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET), - (short) (NLM_F_REQUEST | NLM_F_DUMP)); - - boolean foundConntrackEntry = false; - ByteBuffer buffer = ByteBuffer.allocate(DEFAULT_RECV_BUFSIZE); - buffer.order(ByteOrder.nativeOrder()); + mOffloadHw.requestSocketDump(handle); try { - while (Os.read(handle.getFileDescriptor(), buffer) > 0) { - buffer.flip(); - - // TODO: ConntrackMessage should get a parse API like StructNlMsgHdr - // so we can confirm that the conntrack added is for the TCP connection above. - final StructNlMsgHdr nlmsghdr = StructNlMsgHdr.parse(buffer); - assertNotNull(nlmsghdr); - - // As long as 1 conntrack entry is found test case will pass, even if it's not - // the from the TCP connection above. - if (nlmsghdr.nlmsg_type == ((NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_NEW)) { - foundConntrackEntry = true; - break; - } + findConnectionOrThrow(handle.getFileDescriptor(), localAddress, theAddress); + // No exceptions? Socket was found, test passes. + } catch (ErrnoException e) { + if (e.errno == EAGAIN) { + fail("Did not find socket " + localAddress + "->" + theAddress + " in dump"); + } else { + throw e; } } finally { socket.close(); serverSocket.close(); + acceptedSocket.close(); } - assertTrue("Did not receive any NFNL_SUBSYS_CTNETLINK/IPCTNL_MSG_CT_NEW message", - foundConntrackEntry); } }