Make a copy of TunUtils and PacketUtils

Temporarily make copies. Eventually will statically include the
source files in CtsIkeTestCases

Bug: 148689509
Test: atest CtsIkeTestCases
Change-Id: I7dd5c8b849f0d987fa6d76bc5a0bc1a7eed49b0d
Merged-In: I7dd5c8b849f0d987fa6d76bc5a0bc1a7eed49b0d
(cherry picked from commit 52716ec0e6)
This commit is contained in:
Yan Yan
2020-04-23 23:20:28 +00:00
parent 35b779fd5a
commit 2b2db7a57f
2 changed files with 731 additions and 0 deletions

View File

@@ -0,0 +1,467 @@
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.net.ipsec.ike.cts;
import static android.system.OsConstants.IPPROTO_IPV6;
import static android.system.OsConstants.IPPROTO_UDP;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.nio.ShortBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.Mac;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/**
* This code is a exact copy of {@link PacketUtils} in
* cts/tests/tests/net/src/android/net/cts/PacketUtils.java.
*
* <p>TODO(b/148689509): Statically include the PacketUtils source file instead of copying it.
*/
public class PacketUtils {
private static final String TAG = PacketUtils.class.getSimpleName();
private static final int DATA_BUFFER_LEN = 4096;
static final int IP4_HDRLEN = 20;
static final int IP6_HDRLEN = 40;
static final int UDP_HDRLEN = 8;
static final int TCP_HDRLEN = 20;
static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12;
// Not defined in OsConstants
static final int IPPROTO_IPV4 = 4;
static final int IPPROTO_ESP = 50;
// Encryption parameters
static final int AES_GCM_IV_LEN = 8;
static final int AES_CBC_IV_LEN = 16;
static final int AES_GCM_BLK_SIZE = 4;
static final int AES_CBC_BLK_SIZE = 16;
// Encryption algorithms
static final String AES = "AES";
static final String AES_CBC = "AES/CBC/NoPadding";
static final String HMAC_SHA_256 = "HmacSHA256";
public interface Payload {
byte[] getPacketBytes(IpHeader header) throws Exception;
void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception;
short length();
int getProtocolId();
}
public abstract static class IpHeader {
public final byte proto;
public final InetAddress srcAddr;
public final InetAddress dstAddr;
public final Payload payload;
public IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload) {
this.proto = (byte) proto;
this.srcAddr = src;
this.dstAddr = dst;
this.payload = payload;
}
public abstract byte[] getPacketBytes() throws Exception;
public abstract int getProtocolId();
}
public static class Ip4Header extends IpHeader {
private short checksum;
public Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload) {
super(proto, src, dst, payload);
}
public byte[] getPacketBytes() throws Exception {
ByteBuffer resultBuffer = buildHeader();
payload.addPacketBytes(this, resultBuffer);
return getByteArrayFromBuffer(resultBuffer);
}
public ByteBuffer buildHeader() {
ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
// Version, IHL
bb.put((byte) (0x45));
// DCSP, ECN
bb.put((byte) 0);
// Total Length
bb.putShort((short) (IP4_HDRLEN + payload.length()));
// Empty for Identification, Flags and Fragment Offset
bb.putShort((short) 0);
bb.put((byte) 0x40);
bb.put((byte) 0x00);
// TTL
bb.put((byte) 64);
// Protocol
bb.put(proto);
// Header Checksum
final int ipChecksumOffset = bb.position();
bb.putShort((short) 0);
// Src/Dst addresses
bb.put(srcAddr.getAddress());
bb.put(dstAddr.getAddress());
bb.putShort(ipChecksumOffset, calculateChecksum(bb));
return bb;
}
private short calculateChecksum(ByteBuffer bb) {
int checksum = 0;
// Calculate sum of 16-bit values, excluding checksum. IPv4 headers are always 32-bit
// aligned, so no special cases needed for unaligned values.
ShortBuffer shortBuffer = ByteBuffer.wrap(getByteArrayFromBuffer(bb)).asShortBuffer();
while (shortBuffer.hasRemaining()) {
short val = shortBuffer.get();
// Wrap as needed
checksum = addAndWrapForChecksum(checksum, val);
}
return onesComplement(checksum);
}
public int getProtocolId() {
return IPPROTO_IPV4;
}
}
public static class Ip6Header extends IpHeader {
public Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload) {
super(nextHeader, src, dst, payload);
}
public byte[] getPacketBytes() throws Exception {
ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
// Version | Traffic Class (First 4 bits)
bb.put((byte) 0x60);
// Traffic class (Last 4 bits), Flow Label
bb.put((byte) 0);
bb.put((byte) 0);
bb.put((byte) 0);
// Payload Length
bb.putShort((short) payload.length());
// Next Header
bb.put(proto);
// Hop Limit
bb.put((byte) 64);
// Src/Dst addresses
bb.put(srcAddr.getAddress());
bb.put(dstAddr.getAddress());
// Payload
payload.addPacketBytes(this, bb);
return getByteArrayFromBuffer(bb);
}
public int getProtocolId() {
return IPPROTO_IPV6;
}
}
public static class BytePayload implements Payload {
public final byte[] payload;
public BytePayload(byte[] payload) {
this.payload = payload;
}
public int getProtocolId() {
return -1;
}
public byte[] getPacketBytes(IpHeader header) {
ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
addPacketBytes(header, bb);
return getByteArrayFromBuffer(bb);
}
public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) {
resultBuffer.put(payload);
}
public short length() {
return (short) payload.length;
}
}
public static class UdpHeader implements Payload {
public final short srcPort;
public final short dstPort;
public final Payload payload;
public UdpHeader(int srcPort, int dstPort, Payload payload) {
this.srcPort = (short) srcPort;
this.dstPort = (short) dstPort;
this.payload = payload;
}
public int getProtocolId() {
return IPPROTO_UDP;
}
public short length() {
return (short) (payload.length() + 8);
}
public byte[] getPacketBytes(IpHeader header) throws Exception {
ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
addPacketBytes(header, bb);
return getByteArrayFromBuffer(bb);
}
public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
// Source, Destination port
resultBuffer.putShort(srcPort);
resultBuffer.putShort(dstPort);
// Payload Length
resultBuffer.putShort(length());
// Get payload bytes for checksum + payload
ByteBuffer payloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
payload.addPacketBytes(header, payloadBuffer);
byte[] payloadBytes = getByteArrayFromBuffer(payloadBuffer);
// Checksum
resultBuffer.putShort(calculateChecksum(header, payloadBytes));
// Payload
resultBuffer.put(payloadBytes);
}
private short calculateChecksum(IpHeader header, byte[] payloadBytes) throws Exception {
int newChecksum = 0;
ShortBuffer srcBuffer = ByteBuffer.wrap(header.srcAddr.getAddress()).asShortBuffer();
ShortBuffer dstBuffer = ByteBuffer.wrap(header.dstAddr.getAddress()).asShortBuffer();
while (srcBuffer.hasRemaining() || dstBuffer.hasRemaining()) {
short val = srcBuffer.hasRemaining() ? srcBuffer.get() : dstBuffer.get();
// Wrap as needed
newChecksum = addAndWrapForChecksum(newChecksum, val);
}
// Add pseudo-header values. Proto is 0-padded, so just use the byte.
newChecksum = addAndWrapForChecksum(newChecksum, header.proto);
newChecksum = addAndWrapForChecksum(newChecksum, length());
newChecksum = addAndWrapForChecksum(newChecksum, srcPort);
newChecksum = addAndWrapForChecksum(newChecksum, dstPort);
newChecksum = addAndWrapForChecksum(newChecksum, length());
ShortBuffer payloadShortBuffer = ByteBuffer.wrap(payloadBytes).asShortBuffer();
while (payloadShortBuffer.hasRemaining()) {
newChecksum = addAndWrapForChecksum(newChecksum, payloadShortBuffer.get());
}
if (payload.length() % 2 != 0) {
newChecksum =
addAndWrapForChecksum(
newChecksum, (payloadBytes[payloadBytes.length - 1] << 8));
}
return onesComplement(newChecksum);
}
}
public static class EspHeader implements Payload {
public final int nextHeader;
public final int spi;
public final int seqNum;
public final byte[] key;
public final byte[] payload;
/**
* Generic constructor for ESP headers.
*
* <p>For Tunnel mode, payload will be a full IP header + attached payloads
*
* <p>For Transport mode, payload will be only the attached payloads, but with the checksum
* calculated using the pre-encryption IP header
*/
public EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload) {
this.nextHeader = nextHeader;
this.spi = spi;
this.seqNum = seqNum;
this.key = key;
this.payload = payload;
}
public int getProtocolId() {
return IPPROTO_ESP;
}
public short length() {
// ALWAYS uses AES-CBC, HMAC-SHA256 (128b trunc len)
return (short)
calculateEspPacketSize(payload.length, AES_CBC_IV_LEN, AES_CBC_BLK_SIZE, 128);
}
public byte[] getPacketBytes(IpHeader header) throws Exception {
ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
addPacketBytes(header, bb);
return getByteArrayFromBuffer(bb);
}
public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
ByteBuffer espPayloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
espPayloadBuffer.putInt(spi);
espPayloadBuffer.putInt(seqNum);
espPayloadBuffer.put(getCiphertext(key));
espPayloadBuffer.put(getIcv(getByteArrayFromBuffer(espPayloadBuffer)), 0, 16);
resultBuffer.put(getByteArrayFromBuffer(espPayloadBuffer));
}
private byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
Mac sha256HMAC = Mac.getInstance(HMAC_SHA_256);
SecretKeySpec authKey = new SecretKeySpec(key, HMAC_SHA_256);
sha256HMAC.init(authKey);
return sha256HMAC.doFinal(authenticatedSection);
}
/**
* Encrypts and builds ciphertext block. Includes the IV, Padding and Next-Header blocks
*
* <p>The ciphertext does NOT include the SPI/Sequence numbers, or the ICV.
*/
private byte[] getCiphertext(byte[] key) throws GeneralSecurityException {
int paddedLen = calculateEspEncryptedLength(payload.length, AES_CBC_BLK_SIZE);
ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
paddedPayload.put(payload);
// Add padding - consecutive integers from 0x01
int pad = 1;
while (paddedPayload.position() < paddedPayload.limit()) {
paddedPayload.put((byte) pad++);
}
paddedPayload.position(paddedPayload.limit() - 2);
paddedPayload.put((byte) (paddedLen - 2 - payload.length)); // Pad length
paddedPayload.put((byte) nextHeader);
// Generate Initialization Vector
byte[] iv = new byte[AES_CBC_IV_LEN];
new SecureRandom().nextBytes(iv);
IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
SecretKeySpec secretKeySpec = new SecretKeySpec(key, AES);
// Encrypt payload
Cipher cipher = Cipher.getInstance(AES_CBC);
cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
byte[] encrypted = cipher.doFinal(getByteArrayFromBuffer(paddedPayload));
// Build ciphertext
ByteBuffer cipherText = ByteBuffer.allocate(AES_CBC_IV_LEN + encrypted.length);
cipherText.put(iv);
cipherText.put(encrypted);
return getByteArrayFromBuffer(cipherText);
}
}
private static int addAndWrapForChecksum(int currentChecksum, int value) {
currentChecksum += value & 0x0000ffff;
// Wrap anything beyond the first 16 bits, and add to lower order bits
return (currentChecksum >>> 16) + (currentChecksum & 0x0000ffff);
}
private static short onesComplement(int val) {
val = (val >>> 16) + (val & 0xffff);
if (val == 0) return 0;
return (short) ((~val) & 0xffff);
}
public static int calculateEspPacketSize(
int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
final int ESP_HDRLEN = 4 + 4; // SPI + Seq#
final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length
payloadLen += cryptIvLength; // Initialization Vector
// Align to block size of encryption algorithm
payloadLen = calculateEspEncryptedLength(payloadLen, cryptBlockSize);
return payloadLen + ESP_HDRLEN + ICV_LEN;
}
private static int calculateEspEncryptedLength(int payloadLen, int cryptBlockSize) {
payloadLen += 2; // ESP trailer
// Align to block size of encryption algorithm
return payloadLen + calculateEspPadLen(payloadLen, cryptBlockSize);
}
private static int calculateEspPadLen(int payloadLen, int cryptBlockSize) {
return (cryptBlockSize - (payloadLen % cryptBlockSize)) % cryptBlockSize;
}
private static byte[] getByteArrayFromBuffer(ByteBuffer buffer) {
return Arrays.copyOfRange(buffer.array(), 0, buffer.position());
}
/*
* Debug printing
*/
private static final char[] hexArray = "0123456789ABCDEF".toCharArray();
public static String bytesToHex(byte[] bytes) {
StringBuilder sb = new StringBuilder();
for (byte b : bytes) {
sb.append(hexArray[b >>> 4]);
sb.append(hexArray[b & 0x0F]);
sb.append(' ');
}
return sb.toString();
}
}

View File

@@ -0,0 +1,264 @@
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.net.ipsec.ike.cts;
import static android.net.ipsec.ike.cts.PacketUtils.IP4_HDRLEN;
import static android.net.ipsec.ike.cts.PacketUtils.IP6_HDRLEN;
import static android.net.ipsec.ike.cts.PacketUtils.IPPROTO_ESP;
import static android.net.ipsec.ike.cts.PacketUtils.UDP_HDRLEN;
import static android.system.OsConstants.IPPROTO_UDP;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import android.os.ParcelFileDescriptor;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
/**
* This code is a exact copy of {@link TunUtils} in
* cts/tests/tests/net/src/android/net/cts/TunUtils.java, except the import path of PacketUtils is
* the path to the copy of PacktUtils.
*
* <p>TODO(b/148689509): Statically include the TunUtils source file instead of copying it.
*/
public class TunUtils {
private static final String TAG = TunUtils.class.getSimpleName();
private static final int DATA_BUFFER_LEN = 4096;
private static final int TIMEOUT = 100;
private static final int IP4_PROTO_OFFSET = 9;
private static final int IP6_PROTO_OFFSET = 6;
private static final int IP4_ADDR_OFFSET = 12;
private static final int IP4_ADDR_LEN = 4;
private static final int IP6_ADDR_OFFSET = 8;
private static final int IP6_ADDR_LEN = 16;
private final ParcelFileDescriptor mTunFd;
private final List<byte[]> mPackets = new ArrayList<>();
private final Thread mReaderThread;
public TunUtils(ParcelFileDescriptor tunFd) {
mTunFd = tunFd;
// Start background reader thread
mReaderThread =
new Thread(
() -> {
try {
// Loop will exit and thread will quit when tunFd is closed.
// Receiving either EOF or an exception will exit this reader loop.
// FileInputStream in uninterruptable, so there's no good way to
// ensure that this thread shuts down except upon FD closure.
while (true) {
byte[] intercepted = receiveFromTun();
if (intercepted == null) {
// Exit once we've hit EOF
return;
} else if (intercepted.length > 0) {
// Only save packet if we've received any bytes.
synchronized (mPackets) {
mPackets.add(intercepted);
mPackets.notifyAll();
}
}
}
} catch (IOException ignored) {
// Simply exit this reader thread
return;
}
});
mReaderThread.start();
}
private byte[] receiveFromTun() throws IOException {
FileInputStream in = new FileInputStream(mTunFd.getFileDescriptor());
byte[] inBytes = new byte[DATA_BUFFER_LEN];
int bytesRead = in.read(inBytes);
if (bytesRead < 0) {
return null; // return null for EOF
} else if (bytesRead >= DATA_BUFFER_LEN) {
throw new IllegalStateException("Too big packet. Fragmentation unsupported");
}
return Arrays.copyOf(inBytes, bytesRead);
}
private byte[] getFirstMatchingPacket(Predicate<byte[]> verifier, int startIndex) {
synchronized (mPackets) {
for (int i = startIndex; i < mPackets.size(); i++) {
byte[] pkt = mPackets.get(i);
if (verifier.test(pkt)) {
return pkt;
}
}
}
return null;
}
/**
* Checks if the specified bytes were ever sent in plaintext.
*
* <p>Only checks for known plaintext bytes to prevent triggering on ICMP/RA packets or the like
*
* @param plaintext the plaintext bytes to check for
* @param startIndex the index in the list to check for
*/
public boolean hasPlaintextPacket(byte[] plaintext, int startIndex) {
Predicate<byte[]> verifier =
(pkt) -> {
return Collections.indexOfSubList(Arrays.asList(pkt), Arrays.asList(plaintext))
!= -1;
};
return getFirstMatchingPacket(verifier, startIndex) != null;
}
public byte[] getEspPacket(int spi, boolean encap, int startIndex) {
return getFirstMatchingPacket(
(pkt) -> {
return isEsp(pkt, spi, encap);
},
startIndex);
}
public byte[] awaitEspPacketNoPlaintext(
int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize) throws Exception {
long endTime = System.currentTimeMillis() + TIMEOUT;
int startIndex = 0;
synchronized (mPackets) {
while (System.currentTimeMillis() < endTime) {
byte[] espPkt = getEspPacket(spi, useEncap, startIndex);
if (espPkt != null) {
// Validate packet size
assertEquals(expectedPacketSize, espPkt.length);
// Always check plaintext from start
assertFalse(hasPlaintextPacket(plaintext, 0));
return espPkt; // We've found the packet we're looking for.
}
startIndex = mPackets.size();
// Try to prevent waiting too long. If waitTimeout <= 0, we've already hit timeout
long waitTimeout = endTime - System.currentTimeMillis();
if (waitTimeout > 0) {
mPackets.wait(waitTimeout);
}
}
fail("No such ESP packet found with SPI " + spi);
}
return null;
}
private static boolean isSpiEqual(byte[] pkt, int espOffset, int spi) {
// Check SPI byte by byte.
return pkt[espOffset] == (byte) ((spi >>> 24) & 0xff)
&& pkt[espOffset + 1] == (byte) ((spi >>> 16) & 0xff)
&& pkt[espOffset + 2] == (byte) ((spi >>> 8) & 0xff)
&& pkt[espOffset + 3] == (byte) (spi & 0xff);
}
private static boolean isEsp(byte[] pkt, int spi, boolean encap) {
if (isIpv6(pkt)) {
// IPv6 UDP encap not supported by kernels; assume non-encap.
return pkt[IP6_PROTO_OFFSET] == IPPROTO_ESP && isSpiEqual(pkt, IP6_HDRLEN, spi);
} else {
// Use default IPv4 header length (assuming no options)
if (encap) {
return pkt[IP4_PROTO_OFFSET] == IPPROTO_UDP
&& isSpiEqual(pkt, IP4_HDRLEN + UDP_HDRLEN, spi);
} else {
return pkt[IP4_PROTO_OFFSET] == IPPROTO_ESP && isSpiEqual(pkt, IP4_HDRLEN, spi);
}
}
}
private static boolean isIpv6(byte[] pkt) {
// First nibble shows IP version. 0x60 for IPv6
return (pkt[0] & (byte) 0xF0) == (byte) 0x60;
}
private static byte[] getReflectedPacket(byte[] pkt) {
byte[] reflected = Arrays.copyOf(pkt, pkt.length);
if (isIpv6(pkt)) {
// Set reflected packet's dst to that of the original's src
System.arraycopy(
pkt, // src
IP6_ADDR_OFFSET + IP6_ADDR_LEN, // src offset
reflected, // dst
IP6_ADDR_OFFSET, // dst offset
IP6_ADDR_LEN); // len
// Set reflected packet's src IP to that of the original's dst IP
System.arraycopy(
pkt, // src
IP6_ADDR_OFFSET, // src offset
reflected, // dst
IP6_ADDR_OFFSET + IP6_ADDR_LEN, // dst offset
IP6_ADDR_LEN); // len
} else {
// Set reflected packet's dst to that of the original's src
System.arraycopy(
pkt, // src
IP4_ADDR_OFFSET + IP4_ADDR_LEN, // src offset
reflected, // dst
IP4_ADDR_OFFSET, // dst offset
IP4_ADDR_LEN); // len
// Set reflected packet's src IP to that of the original's dst IP
System.arraycopy(
pkt, // src
IP4_ADDR_OFFSET, // src offset
reflected, // dst
IP4_ADDR_OFFSET + IP4_ADDR_LEN, // dst offset
IP4_ADDR_LEN); // len
}
return reflected;
}
/** Takes all captured packets, flips the src/dst, and re-injects them. */
public void reflectPackets() throws IOException {
synchronized (mPackets) {
for (byte[] pkt : mPackets) {
injectPacket(getReflectedPacket(pkt));
}
}
}
public void injectPacket(byte[] pkt) throws IOException {
FileOutputStream out = new FileOutputStream(mTunFd.getFileDescriptor());
out.write(pkt);
out.flush();
}
/** Resets the intercepted packets. */
public void reset() throws IOException {
synchronized (mPackets) {
mPackets.clear();
}
}
}