diff --git a/core/java/android/net/NetworkUtils.java b/core/java/android/net/NetworkUtils.java index 333fcc67bc..b24d3969f7 100644 --- a/core/java/android/net/NetworkUtils.java +++ b/core/java/android/net/NetworkUtils.java @@ -103,6 +103,11 @@ public class NetworkUtils { */ public native static String getDhcpError(); + /** + * Set the SO_MARK of {@code socketfd} to {@code mark} + */ + public native static void markSocket(int socketfd, int mark); + /** * Convert a IPv4 address from an integer to an InetAddress. * @param hostAddress an int corresponding to the IPv4 address in network byte order diff --git a/core/jni/android_net_NetUtils.cpp b/core/jni/android_net_NetUtils.cpp index faae11ec3b..526159f4f1 100644 --- a/core/jni/android_net_NetUtils.cpp +++ b/core/jni/android_net_NetUtils.cpp @@ -17,6 +17,7 @@ #define LOG_TAG "NetUtils" #include "jni.h" +#include "JNIHelp.h" #include #include #include @@ -239,6 +240,13 @@ static jstring android_net_utils_getDhcpError(JNIEnv* env, jobject clazz) return env->NewStringUTF(::dhcp_get_errmsg()); } +static void android_net_utils_markSocket(JNIEnv *env, jobject thiz, jint socket, jint mark) +{ + if (setsockopt(socket, SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) { + jniThrowException(env, "java/lang/IllegalStateException", "Error marking socket"); + } +} + // ---------------------------------------------------------------------------- /* @@ -255,6 +263,7 @@ static JNINativeMethod gNetworkUtilMethods[] = { { "stopDhcp", "(Ljava/lang/String;)Z", (void *)android_net_utils_stopDhcp }, { "releaseDhcpLease", "(Ljava/lang/String;)Z", (void *)android_net_utils_releaseDhcpLease }, { "getDhcpError", "()Ljava/lang/String;", (void*) android_net_utils_getDhcpError }, + { "markSocket", "(II)V", (void*) android_net_utils_markSocket }, }; int register_android_net_NetworkUtils(JNIEnv* env) diff --git a/services/java/com/android/server/ConnectivityService.java b/services/java/com/android/server/ConnectivityService.java index b148b91219..a6344cafab 100644 --- a/services/java/com/android/server/ConnectivityService.java +++ b/services/java/com/android/server/ConnectivityService.java @@ -97,6 +97,7 @@ import android.telephony.TelephonyManager; import android.text.TextUtils; import android.util.Slog; import android.util.SparseIntArray; +import android.util.SparseArray; import com.android.internal.R; import com.android.internal.net.LegacyVpnInfo; @@ -116,6 +117,8 @@ import com.android.server.net.LockdownVpnTracker; import com.google.android.collect.Lists; import com.google.android.collect.Sets; +import com.android.internal.annotations.GuardedBy; + import dalvik.system.DexClassLoader; import java.io.FileDescriptor; @@ -171,7 +174,8 @@ public class ConnectivityService extends IConnectivityManager.Stub { private KeyStore mKeyStore; - private Vpn mVpn; + @GuardedBy("mVpns") + private final SparseArray mVpns = new SparseArray(); private VpnCallback mVpnCallback = new VpnCallback(); private boolean mLockdownEnabled; @@ -583,10 +587,13 @@ public class ConnectivityService extends IConnectivityManager.Stub { mTethering.getTetherableWifiRegexs().length != 0 || mTethering.getTetherableBluetoothRegexs().length != 0) && mTethering.getUpstreamIfaceTypes().length != 0); + //set up the listener for user state for creating user VPNs - mVpn = new Vpn(mContext, mVpnCallback, mNetd, this); - mVpn.startMonitoring(mContext, mTrackerHandler); - + IntentFilter intentFilter = new IntentFilter(); + intentFilter.addAction(Intent.ACTION_USER_STARTING); + intentFilter.addAction(Intent.ACTION_USER_STOPPING); + mContext.registerReceiverAsUser( + mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null); mClat = new Nat464Xlat(mContext, mNetd, this, mTrackerHandler); try { @@ -2313,7 +2320,11 @@ public class ConnectivityService extends IConnectivityManager.Stub { // Tell VPN the interface is down. It is a temporary // but effective fix to make VPN aware of the change. if ((resetMask & NetworkUtils.RESET_IPV4_ADDRESSES) != 0) { - mVpn.interfaceStatusChanged(iface, false); + synchronized(mVpns) { + for (int i = 0; i < mVpns.size(); i++) { + mVpns.valueAt(i).interfaceStatusChanged(iface, false); + } + } } } if (resetDns) { @@ -2570,7 +2581,6 @@ public class ConnectivityService extends IConnectivityManager.Stub { try { mNetd.setDnsServersForInterface(iface, NetworkUtils.makeStrings(dnses), domains); - mNetd.setDefaultInterfaceForDns(iface); for (InetAddress dns : dnses) { ++last; String key = "net.dns" + last; @@ -3305,8 +3315,12 @@ public class ConnectivityService extends IConnectivityManager.Stub { throwIfLockdownEnabled(); try { int type = mActiveDefaultNetwork; + int user = UserHandle.getUserId(Binder.getCallingUid()); if (ConnectivityManager.isNetworkTypeValid(type) && mNetTrackers[type] != null) { - mVpn.protect(socket, mNetTrackers[type].getLinkProperties().getInterfaceName()); + synchronized(mVpns) { + mVpns.get(user).protect(socket, + mNetTrackers[type].getLinkProperties().getInterfaceName()); + } return true; } } catch (Exception e) { @@ -3330,7 +3344,10 @@ public class ConnectivityService extends IConnectivityManager.Stub { @Override public boolean prepareVpn(String oldPackage, String newPackage) { throwIfLockdownEnabled(); - return mVpn.prepare(oldPackage, newPackage); + int user = UserHandle.getUserId(Binder.getCallingUid()); + synchronized(mVpns) { + return mVpns.get(user).prepare(oldPackage, newPackage); + } } /** @@ -3343,7 +3360,10 @@ public class ConnectivityService extends IConnectivityManager.Stub { @Override public ParcelFileDescriptor establishVpn(VpnConfig config) { throwIfLockdownEnabled(); - return mVpn.establish(config); + int user = UserHandle.getUserId(Binder.getCallingUid()); + synchronized(mVpns) { + return mVpns.get(user).establish(config); + } } /** @@ -3357,7 +3377,10 @@ public class ConnectivityService extends IConnectivityManager.Stub { if (egress == null) { throw new IllegalStateException("Missing active network connection"); } - mVpn.startLegacyVpn(profile, mKeyStore, egress); + int user = UserHandle.getUserId(Binder.getCallingUid()); + synchronized(mVpns) { + mVpns.get(user).startLegacyVpn(profile, mKeyStore, egress); + } } /** @@ -3369,7 +3392,10 @@ public class ConnectivityService extends IConnectivityManager.Stub { @Override public LegacyVpnInfo getLegacyVpnInfo() { throwIfLockdownEnabled(); - return mVpn.getLegacyVpnInfo(); + int user = UserHandle.getUserId(Binder.getCallingUid()); + synchronized(mVpns) { + return mVpns.get(user).getLegacyVpnInfo(); + } } /** @@ -3390,7 +3416,7 @@ public class ConnectivityService extends IConnectivityManager.Stub { mHandler.obtainMessage(EVENT_VPN_STATE_CHANGED, info).sendToTarget(); } - public void override(List dnsServers, List searchDomains) { + public void override(String iface, List dnsServers, List searchDomains) { if (dnsServers == null) { restore(); return; @@ -3422,7 +3448,7 @@ public class ConnectivityService extends IConnectivityManager.Stub { // Apply DNS changes. synchronized (mDnsLock) { - updateDnsLocked("VPN", "VPN", addresses, domains); + updateDnsLocked("VPN", iface, addresses, domains); mDnsOverridden = true; } @@ -3451,6 +3477,67 @@ public class ConnectivityService extends IConnectivityManager.Stub { } } } + + public void protect(ParcelFileDescriptor socket) { + try { + final int mark = mNetd.getMarkForProtect(); + NetworkUtils.markSocket(socket.getFd(), mark); + } catch (RemoteException e) { + } + } + + public void setRoutes(String interfaze, List routes) { + for (RouteInfo route : routes) { + try { + mNetd.setMarkedForwardingRoute(interfaze, route); + } catch (RemoteException e) { + } + } + } + + public void setMarkedForwarding(String interfaze) { + try { + mNetd.setMarkedForwarding(interfaze); + } catch (RemoteException e) { + } + } + + public void clearMarkedForwarding(String interfaze) { + try { + mNetd.clearMarkedForwarding(interfaze); + } catch (RemoteException e) { + } + } + + public void addUserForwarding(String interfaze, int uid) { + int uidStart = uid * UserHandle.PER_USER_RANGE; + int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1; + addUidForwarding(interfaze, uidStart, uidEnd); + } + + public void clearUserForwarding(String interfaze, int uid) { + int uidStart = uid * UserHandle.PER_USER_RANGE; + int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1; + clearUidForwarding(interfaze, uidStart, uidEnd); + } + + public void addUidForwarding(String interfaze, int uidStart, int uidEnd) { + try { + mNetd.setUidRangeRoute(interfaze,uidStart, uidEnd); + mNetd.setDnsInterfaceForUidRange(interfaze, uidStart, uidEnd); + } catch (RemoteException e) { + } + + } + + public void clearUidForwarding(String interfaze, int uidStart, int uidEnd) { + try { + mNetd.clearUidRangeRoute(interfaze, uidStart, uidEnd); + mNetd.clearDnsInterfaceForUidRange(uidStart, uidEnd); + } catch (RemoteException e) { + } + + } } @Override @@ -3471,7 +3558,11 @@ public class ConnectivityService extends IConnectivityManager.Stub { final String profileName = new String(mKeyStore.get(Credentials.LOCKDOWN_VPN)); final VpnProfile profile = VpnProfile.decode( profileName, mKeyStore.get(Credentials.VPN + profileName)); - setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpn, profile)); + int user = UserHandle.getUserId(Binder.getCallingUid()); + synchronized(mVpns) { + setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpns.get(user), + profile)); + } } else { setLockdownTracker(null); } @@ -4002,4 +4093,43 @@ public class ConnectivityService extends IConnectivityManager.Stub { return url; } + + private void onUserStart(int userId) { + synchronized(mVpns) { + Vpn userVpn = mVpns.get(userId); + if (userVpn != null) { + loge("Starting user already has a VPN"); + return; + } + userVpn = new Vpn(mContext, mVpnCallback, mNetd, this, userId); + mVpns.put(userId, userVpn); + userVpn.startMonitoring(mContext, mTrackerHandler); + } + } + + private void onUserStop(int userId) { + synchronized(mVpns) { + Vpn userVpn = mVpns.get(userId); + if (userVpn == null) { + loge("Stopping user has no VPN"); + return; + } + mVpns.delete(userId); + } + } + + private BroadcastReceiver mUserIntentReceiver = new BroadcastReceiver() { + @Override + public void onReceive(Context context, Intent intent) { + final String action = intent.getAction(); + final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE, UserHandle.USER_NULL); + if (userId == UserHandle.USER_NULL) return; + + if (Intent.ACTION_USER_STARTING.equals(action)) { + onUserStart(userId); + } else if (Intent.ACTION_USER_STOPPING.equals(action)) { + onUserStop(userId); + } + } + }; }