diff --git a/core/java/android/app/usage/NetworkStatsManager.java b/core/java/android/app/usage/NetworkStatsManager.java index 13aeef0058..2e3aca46f9 100644 --- a/core/java/android/app/usage/NetworkStatsManager.java +++ b/core/java/android/app/usage/NetworkStatsManager.java @@ -25,9 +25,15 @@ import android.net.ConnectivityManager; import android.net.DataUsageRequest; import android.net.NetworkIdentity; import android.net.NetworkTemplate; +import android.net.INetworkStatsService; +import android.os.Binder; import android.os.Build; +import android.os.Message; +import android.os.Messenger; import android.os.Handler; +import android.os.Looper; import android.os.RemoteException; +import android.os.ServiceManager; import android.util.Log; /** @@ -75,16 +81,26 @@ import android.util.Log; * not included. */ public class NetworkStatsManager { - private final static String TAG = "NetworkStatsManager"; + private static final String TAG = "NetworkStatsManager"; + private static final boolean DBG = false; + + /** @hide */ + public static final int CALLBACK_LIMIT_REACHED = 0; + /** @hide */ + public static final int CALLBACK_RELEASED = 1; private final Context mContext; + private final INetworkStatsService mService; /** * {@hide} */ public NetworkStatsManager(Context context) { mContext = context; + mService = INetworkStatsService.Stub.asInterface( + ServiceManager.getService(Context.NETWORK_STATS_SERVICE)); } + /** * Query network usage statistics summaries. Result is summarised data usage for the whole * device. Result is a single Bucket aggregated over time, state, uid, tag and roaming. This @@ -322,7 +338,40 @@ public class NetworkStatsManager { checkNotNull(policy, "DataUsagePolicy cannot be null"); checkNotNull(callback, "DataUsageCallback cannot be null"); - // TODO: Implement stub. + final Looper looper; + if (handler == null) { + looper = Looper.myLooper(); + } else { + looper = handler.getLooper(); + } + + if (DBG) Log.d(TAG, "registerDataUsageCallback called with " + policy); + + NetworkTemplate[] templates; + if (policy.subscriberIds == null || policy.subscriberIds.length == 0) { + templates = new NetworkTemplate[1]; + templates[0] = createTemplate(policy.networkType, null /* subscriberId */); + } else { + templates = new NetworkTemplate[policy.subscriberIds.length]; + for (int i = 0; i < policy.subscriberIds.length; i++) { + templates[i] = createTemplate(policy.networkType, policy.subscriberIds[i]); + } + } + DataUsageRequest request = new DataUsageRequest(DataUsageRequest.REQUEST_ID_UNSET, + templates, policy.uids, policy.thresholdInBytes); + try { + CallbackHandler callbackHandler = new CallbackHandler(looper, callback); + callback.request = mService.registerDataUsageCallback( + mContext.getOpPackageName(), request, new Messenger(callbackHandler), + new Binder()); + if (DBG) Log.d(TAG, "registerDataUsageCallback returned " + callback.request); + + if (callback.request == null) { + Log.e(TAG, "Request from callback is null; should not happen"); + } + } catch (RemoteException e) { + if (DBG) Log.d(TAG, "Remote exception when registering callback"); + } } /** @@ -331,9 +380,15 @@ public class NetworkStatsManager { * @param callback The {@link DataUsageCallback} used when registering. */ public void unregisterDataUsageCallback(DataUsageCallback callback) { - checkNotNull(callback, "DataUsageCallback cannot be null"); - - // TODO: Implement stub. + if (callback == null || callback.request == null + || callback.request.requestId == DataUsageRequest.REQUEST_ID_UNSET) { + throw new IllegalArgumentException("Invalid DataUsageCallback"); + } + try { + mService.unregisterDataUsageRequest(callback.request); + } catch (RemoteException e) { + if (DBG) Log.d(TAG, "Remote exception when unregistering callback"); + } } /** @@ -366,4 +421,38 @@ public class NetworkStatsManager { } return template; } + + private static class CallbackHandler extends Handler { + private DataUsageCallback mCallback; + CallbackHandler(Looper looper, DataUsageCallback callback) { + super(looper); + mCallback = callback; + } + + @Override + public void handleMessage(Message message) { + DataUsageRequest request = + (DataUsageRequest) getObject(message, DataUsageRequest.PARCELABLE_KEY); + + switch (message.what) { + case CALLBACK_LIMIT_REACHED: { + if (mCallback != null) { + mCallback.onLimitReached(); + } else { + Log.e(TAG, "limit reached with released callback for " + request); + } + break; + } + case CALLBACK_RELEASED: { + if (DBG) Log.d(TAG, "callback released for " + request); + mCallback = null; + break; + } + } + } + + private static Object getObject(Message msg, String key) { + return msg.getData().getParcelable(key); + } + } } diff --git a/core/java/android/net/DataUsageRequest.java b/core/java/android/net/DataUsageRequest.java index 0e46f4c0cb..5e96cc1fe0 100644 --- a/core/java/android/net/DataUsageRequest.java +++ b/core/java/android/net/DataUsageRequest.java @@ -31,6 +31,11 @@ import java.util.Objects; */ public class DataUsageRequest implements Parcelable { + /** + * @hide + */ + public static final String PARCELABLE_KEY = "DataUsageRequest"; + /** * @hide */ diff --git a/core/java/android/net/INetworkStatsService.aidl b/core/java/android/net/INetworkStatsService.aidl index 6436e42676..2eea9408f9 100644 --- a/core/java/android/net/INetworkStatsService.aidl +++ b/core/java/android/net/INetworkStatsService.aidl @@ -16,10 +16,13 @@ package android.net; +import android.net.DataUsageRequest; import android.net.INetworkStatsSession; import android.net.NetworkStats; import android.net.NetworkStatsHistory; import android.net.NetworkTemplate; +import android.os.IBinder; +import android.os.Messenger; /** {@hide} */ interface INetworkStatsService { @@ -57,4 +60,11 @@ interface INetworkStatsService { /** Advise persistance threshold; may be overridden internally. */ void advisePersistThreshold(long thresholdBytes); + /** Registers a callback on data usage. */ + DataUsageRequest registerDataUsageCallback(String callingPackage, + in DataUsageRequest request, in Messenger messenger, in IBinder binder); + + /** Unregisters a callback on data usage. */ + void unregisterDataUsageRequest(in DataUsageRequest request); + } diff --git a/services/core/java/com/android/server/net/NetworkStatsAccess.java b/services/core/java/com/android/server/net/NetworkStatsAccess.java index 479b065f5c..98fe770774 100644 --- a/services/core/java/com/android/server/net/NetworkStatsAccess.java +++ b/services/core/java/com/android/server/net/NetworkStatsAccess.java @@ -17,6 +17,7 @@ package com.android.server.net; import static android.Manifest.permission.READ_NETWORK_USAGE_HISTORY; +import static android.net.NetworkStats.UID_ALL; import static android.net.TrafficStats.UID_REMOVED; import static android.net.TrafficStats.UID_TETHERING; @@ -48,6 +49,7 @@ public final class NetworkStatsAccess { @IntDef({ Level.DEFAULT, Level.USER, + Level.DEVICESUMMARY, Level.DEVICE, }) @Retention(RetentionPolicy.SOURCE) @@ -147,6 +149,12 @@ public final class NetworkStatsAccess { // Device-level access - can access usage for any uid. return true; case NetworkStatsAccess.Level.DEVICESUMMARY: + // Can access usage for any app running in the same user, along + // with some special uids (system, removed, or tethering) and + // anonymized uids + return uid == android.os.Process.SYSTEM_UID || uid == UID_REMOVED + || uid == UID_TETHERING || uid == UID_ALL + || UserHandle.getUserId(uid) == UserHandle.getUserId(callerUid); case NetworkStatsAccess.Level.USER: // User-level access - can access usage for any app running in the same user, along // with some special uids (system, removed, or tethering). diff --git a/services/core/java/com/android/server/net/NetworkStatsCollection.java b/services/core/java/com/android/server/net/NetworkStatsCollection.java index eec7d931d4..d986e94b02 100644 --- a/services/core/java/com/android/server/net/NetworkStatsCollection.java +++ b/services/core/java/com/android/server/net/NetworkStatsCollection.java @@ -135,7 +135,11 @@ public class NetworkStatsCollection implements FileRotator.Reader { } public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel) { - final int callerUid = Binder.getCallingUid(); + return getRelevantUids(accessLevel, Binder.getCallingUid()); + } + + public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel, + final int callerUid) { IntArray uids = new IntArray(); for (int i = 0; i < mStats.size(); i++) { final Key key = mStats.keyAt(i); @@ -169,7 +173,17 @@ public class NetworkStatsCollection implements FileRotator.Reader { public NetworkStatsHistory getHistory( NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end, @NetworkStatsAccess.Level int accessLevel) { - final int callerUid = Binder.getCallingUid(); + return getHistory(template, uid, set, tag, fields, start, end, accessLevel, + Binder.getCallingUid()); + } + + /** + * Combine all {@link NetworkStatsHistory} in this collection which match + * the requested parameters. + */ + public NetworkStatsHistory getHistory( + NetworkTemplate template, int uid, int set, int tag, int fields, long start, long end, + @NetworkStatsAccess.Level int accessLevel, int callerUid) { if (!NetworkStatsAccess.isAccessibleToUser(uid, callerUid, accessLevel)) { throw new SecurityException("Network stats history of uid " + uid + " is forbidden for caller " + callerUid); @@ -198,6 +212,15 @@ public class NetworkStatsCollection implements FileRotator.Reader { */ public NetworkStats getSummary(NetworkTemplate template, long start, long end, @NetworkStatsAccess.Level int accessLevel) { + return getSummary(template, start, end, accessLevel, Binder.getCallingUid()); + } + + /** + * Summarize all {@link NetworkStatsHistory} in this collection which match + * the requested parameters. + */ + public NetworkStats getSummary(NetworkTemplate template, long start, long end, + @NetworkStatsAccess.Level int accessLevel, int callerUid) { final long now = System.currentTimeMillis(); final NetworkStats stats = new NetworkStats(end - start, 24); @@ -207,7 +230,6 @@ public class NetworkStatsCollection implements FileRotator.Reader { final NetworkStats.Entry entry = new NetworkStats.Entry(); NetworkStatsHistory.Entry historyEntry = null; - final int callerUid = Binder.getCallingUid(); for (int i = 0; i < mStats.size(); i++) { final Key key = mStats.keyAt(i); if (templateMatches(template, key.ident) diff --git a/services/core/java/com/android/server/net/NetworkStatsObservers.java b/services/core/java/com/android/server/net/NetworkStatsObservers.java new file mode 100644 index 0000000000..2f55562bc0 --- /dev/null +++ b/services/core/java/com/android/server/net/NetworkStatsObservers.java @@ -0,0 +1,493 @@ +/* + * Copyright (C) 2016 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.server.net; + +import static android.net.TrafficStats.MB_IN_BYTES; +import static com.android.internal.util.Preconditions.checkArgument; + +import android.app.usage.NetworkStatsManager; +import android.net.DataUsageRequest; +import android.net.NetworkStats; +import android.net.NetworkStats.NonMonotonicObserver; +import android.net.NetworkStatsHistory; +import android.net.NetworkTemplate; +import android.os.Binder; +import android.os.Bundle; +import android.os.Looper; +import android.os.Message; +import android.os.Messenger; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.IBinder; +import android.os.Process; +import android.os.RemoteException; +import android.util.ArrayMap; +import android.util.IntArray; +import android.util.SparseArray; +import android.util.Slog; + +import com.android.internal.annotations.VisibleForTesting; +import com.android.internal.net.VpnInfo; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Manages observers of {@link NetworkStats}. Allows observers to be notified when + * data usage has been reported in {@link NetworkStatsService}. An observer can set + * a threshold of how much data it cares about to be notified. + */ +class NetworkStatsObservers { + private static final String TAG = "NetworkStatsObservers"; + private static final boolean LOGV = true; + + private static final long MIN_THRESHOLD_BYTES = 2 * MB_IN_BYTES; + + private static final int MSG_REGISTER = 1; + private static final int MSG_UNREGISTER = 2; + private static final int MSG_UPDATE_STATS = 3; + + // All access to this map must be done from the handler thread. + // indexed by DataUsageRequest#requestId + private final SparseArray mDataUsageRequests = new SparseArray<>(); + + // Sequence number of DataUsageRequests + private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger(); + + // Lazily instantiated when an observer is registered. + private Handler mHandler; + + /** + * Creates a wrapper that contains the caller context and a normalized request. + * The request should be returned to the caller app, and the wrapper should be sent to this + * object through #addObserver by the service handler. + * + *

It will register the observer asynchronously, so it is safe to call from any thread. + * + * @return the normalized request wrapped within {@link RequestInfo}. + */ + public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger, + IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) { + checkVisibilityUids(callingUid, accessLevel, inputRequest.uids); + + DataUsageRequest request = buildRequest(inputRequest); + RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid, + accessLevel); + + if (LOGV) Slog.v(TAG, "Registering observer for " + request); + getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo)); + return request; + } + + /** + * Unregister a data usage observer. + * + *

It will unregister the observer asynchronously, so it is safe to call from any thread. + */ + public void unregister(DataUsageRequest request, int callingUid) { + getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */, + request)); + } + + /** + * Updates data usage statistics of registered observers and notifies if limits are reached. + * + *

It will update stats asynchronously, so it is safe to call from any thread. + */ + public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, + ArrayMap activeIfaces, + ArrayMap activeUidIfaces, + VpnInfo[] vpnArray, long currentTime) { + StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces, + activeUidIfaces, vpnArray, currentTime); + getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext)); + } + + private Handler getHandler() { + if (mHandler == null) { + synchronized (this) { + if (mHandler == null) { + if (LOGV) Slog.v(TAG, "Creating handler"); + mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback); + } + } + } + return mHandler; + } + + @VisibleForTesting + protected Looper getHandlerLooperLocked() { + HandlerThread handlerThread = new HandlerThread(TAG); + handlerThread.start(); + return handlerThread.getLooper(); + } + + private Handler.Callback mHandlerCallback = new Handler.Callback() { + @Override + public boolean handleMessage(Message msg) { + switch (msg.what) { + case MSG_REGISTER: { + handleRegister((RequestInfo) msg.obj); + return true; + } + case MSG_UNREGISTER: { + handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */); + return true; + } + case MSG_UPDATE_STATS: { + handleUpdateStats((StatsContext) msg.obj); + return true; + } + default: { + return false; + } + } + } + }; + + /** + * Adds a {@link RequestInfo} as an observer. + * Should only be called from the handler thread otherwise there will be a race condition + * on mDataUsageRequests. + */ + private void handleRegister(RequestInfo requestInfo) { + mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo); + } + + /** + * Removes a {@link DataUsageRequest} if the calling uid is authorized. + * Should only be called from the handler thread otherwise there will be a race condition + * on mDataUsageRequests. + */ + private void handleUnregister(DataUsageRequest request, int callingUid) { + RequestInfo requestInfo; + requestInfo = mDataUsageRequests.get(request.requestId); + if (requestInfo == null) { + if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request); + return; + } + if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) { + Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request); + return; + } + + if (LOGV) Slog.v(TAG, "Unregistering " + request); + mDataUsageRequests.remove(request.requestId); + requestInfo.unlinkDeathRecipient(); + requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED); + } + + private void handleUpdateStats(StatsContext statsContext) { + if (mDataUsageRequests.size() == 0) { + if (LOGV) Slog.v(TAG, "No registered listeners of data usage"); + return; + } + + if (LOGV) Slog.v(TAG, "Checking if any registered observer needs to be notified"); + for (int i = 0; i < mDataUsageRequests.size(); i++) { + RequestInfo requestInfo = mDataUsageRequests.valueAt(i); + requestInfo.updateStats(statsContext); + } + } + + private DataUsageRequest buildRequest(DataUsageRequest request) { + // Cap the minimum threshold to a safe default to avoid too many callbacks + long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes); + if (thresholdInBytes < request.thresholdInBytes) { + Slog.w(TAG, "Threshold was too low for " + request + + ". Overriding to a safer default of " + thresholdInBytes + " bytes"); + } + return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(), + request.templates, request.uids, thresholdInBytes); + } + + private RequestInfo buildRequestInfo(DataUsageRequest request, + Messenger messenger, IBinder binder, int callingUid, + @NetworkStatsAccess.Level int accessLevel) { + if (accessLevel <= NetworkStatsAccess.Level.USER + || request.uids != null && request.uids.length > 0) { + return new UserUsageRequestInfo(this, request, messenger, binder, callingUid, + accessLevel); + } else { + // Safety check in case a new access level is added and we forgot to update this + checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY); + return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid, + accessLevel); + } + } + + private void checkVisibilityUids(int callingUid, @NetworkStatsAccess.Level int accessLevel, + int[] uids) { + if (uids == null) { + return; + } + for (int i = 0; i < uids.length; i++) { + if (!NetworkStatsAccess.isAccessibleToUser(uids[i], callingUid, accessLevel)) { + throw new SecurityException("Caller " + callingUid + " cannot monitor network stats" + + " for uid " + uids[i] + " with accessLevel " + accessLevel); + } + } + } + + /** + * Tracks information relevant to a data usage observer. + * It will notice when the calling process dies so we can self-expire. + */ + private abstract static class RequestInfo implements IBinder.DeathRecipient { + private final NetworkStatsObservers mStatsObserver; + protected final DataUsageRequest mRequest; + private final Messenger mMessenger; + private final IBinder mBinder; + protected final int mCallingUid; + protected final @NetworkStatsAccess.Level int mAccessLevel; + protected NetworkStatsRecorder mRecorder; + protected NetworkStatsCollection mCollection; + + RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, + Messenger messenger, IBinder binder, int callingUid, + @NetworkStatsAccess.Level int accessLevel) { + mStatsObserver = statsObserver; + mRequest = request; + mMessenger = messenger; + mBinder = binder; + mCallingUid = callingUid; + mAccessLevel = accessLevel; + + try { + mBinder.linkToDeath(this, 0); + } catch (RemoteException e) { + binderDied(); + } + } + + @Override + public void binderDied() { + if (LOGV) Slog.v(TAG, "RequestInfo binderDied(" + + mRequest + ", " + mBinder + ")"); + mStatsObserver.unregister(mRequest, Process.SYSTEM_UID); + callCallback(NetworkStatsManager.CALLBACK_RELEASED); + } + + @Override + public String toString() { + return "RequestInfo from uid:" + mCallingUid + + " for " + mRequest + " accessLevel:" + mAccessLevel; + } + + private void unlinkDeathRecipient() { + if (mBinder != null) { + mBinder.unlinkToDeath(this, 0); + } + } + + /** + * Update stats given the samples and interface to identity mappings. + */ + private void updateStats(StatsContext statsContext) { + if (mRecorder == null) { + // First run; establish baseline stats + resetRecorder(); + recordSample(statsContext); + return; + } + recordSample(statsContext); + + if (checkStats()) { + resetRecorder(); + callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED); + } + } + + private void callCallback(int callbackType) { + Bundle bundle = new Bundle(); + bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest); + Message msg = Message.obtain(); + msg.what = callbackType; + msg.setData(bundle); + try { + if (LOGV) { + Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType) + + " for " + mRequest); + } + mMessenger.send(msg); + } catch (RemoteException e) { + // May occur naturally in the race of binder death. + Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest); + } + } + + private void resetRecorder() { + mRecorder = new NetworkStatsRecorder(); + mCollection = mRecorder.getSinceBoot(); + } + + protected abstract boolean checkStats(); + + protected abstract void recordSample(StatsContext statsContext); + + private String callbackTypeToName(int callbackType) { + switch (callbackType) { + case NetworkStatsManager.CALLBACK_LIMIT_REACHED: + return "LIMIT_REACHED"; + case NetworkStatsManager.CALLBACK_RELEASED: + return "RELEASED"; + default: + return "UNKNOWN"; + } + } + } + + private static class NetworkUsageRequestInfo extends RequestInfo { + NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, + Messenger messenger, IBinder binder, int callingUid, + @NetworkStatsAccess.Level int accessLevel) { + super(statsObserver, request, messenger, binder, callingUid, accessLevel); + } + + @Override + protected boolean checkStats() { + for (int i = 0; i < mRequest.templates.length; i++) { + long bytesSoFar = getTotalBytesForNetwork(mRequest.templates[i]); + if (LOGV) { + Slog.v(TAG, bytesSoFar + " bytes so far since notification for " + + mRequest.templates[i]); + } + if (bytesSoFar > mRequest.thresholdInBytes) { + return true; + } + } + return false; + } + + @Override + protected void recordSample(StatsContext statsContext) { + // Recorder does not need to be locked in this context since only the handler + // thread will update it + mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces, + statsContext.mVpnArray, statsContext.mCurrentTime); + } + + /** + * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate + * over all buckets, which in this case should be only one since we built it big enough + * that it will outlive the caller. If it doesn't, then there will be multiple buckets. + */ + private long getTotalBytesForNetwork(NetworkTemplate template) { + NetworkStats stats = mCollection.getSummary(template, + Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, + mAccessLevel, mCallingUid); + if (LOGV) { + Slog.v(TAG, "Netstats for " + template + ": " + stats); + } + return stats.getTotalBytes(); + } + } + + private static class UserUsageRequestInfo extends RequestInfo { + UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, + Messenger messenger, IBinder binder, int callingUid, + @NetworkStatsAccess.Level int accessLevel) { + super(statsObserver, request, messenger, binder, callingUid, accessLevel); + } + + @Override + protected boolean checkStats() { + int[] uidsToMonitor = getUidsToMonitor(); + + for (int i = 0; i < mRequest.templates.length; i++) { + for (int j = 0; j < uidsToMonitor.length; j++) { + long bytesSoFar = getTotalBytesForNetworkUid(mRequest.templates[i], + uidsToMonitor[j]); + + if (LOGV) { + Slog.v(TAG, bytesSoFar + " bytes so far since notification for " + + mRequest.templates[i] + " for uid=" + uidsToMonitor[j]); + } + if (bytesSoFar > mRequest.thresholdInBytes) { + return true; + } + } + } + return false; + } + + @Override + protected void recordSample(StatsContext statsContext) { + // Recorder does not need to be locked in this context since only the handler + // thread will update it + mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces, + statsContext.mVpnArray, statsContext.mCurrentTime); + } + + /** + * Reads all stats matching the given template and uid. Ther history will likely only + * contain one bucket per ident since we build it big enough that it will outlive the + * caller lifetime. + */ + private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) { + try { + NetworkStatsHistory history = mCollection.getHistory(template, uid, + NetworkStats.SET_ALL, NetworkStats.TAG_NONE, + NetworkStatsHistory.FIELD_ALL, + Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, + mAccessLevel, mCallingUid); + return history.getTotalBytes(); + } catch (SecurityException e) { + if (LOGV) { + Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid " + + uid); + } + return 0; + } + } + + private int[] getUidsToMonitor() { + if (mRequest.uids == null || mRequest.uids.length == 0) { + return mCollection.getRelevantUids(mAccessLevel, mCallingUid); + } + // Pick only uids from the request that are currently accessible to the user + IntArray accessibleUids = new IntArray(mRequest.uids.length); + for (int i = 0; i < mRequest.uids.length; i++) { + int uid = mRequest.uids[i]; + if (NetworkStatsAccess.isAccessibleToUser(uid, mCallingUid, mAccessLevel)) { + accessibleUids.add(uid); + } + } + return accessibleUids.toArray(); + } + } + + private static class StatsContext { + NetworkStats mXtSnapshot; + NetworkStats mUidSnapshot; + ArrayMap mActiveIfaces; + ArrayMap mActiveUidIfaces; + VpnInfo[] mVpnArray; + long mCurrentTime; + + StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, + ArrayMap activeIfaces, + ArrayMap activeUidIfaces, + VpnInfo[] vpnArray, long currentTime) { + mXtSnapshot = xtSnapshot; + mUidSnapshot = uidSnapshot; + mActiveIfaces = activeIfaces; + mActiveUidIfaces = activeUidIfaces; + mVpnArray = vpnArray; + mCurrentTime = currentTime; + } + } +} diff --git a/services/core/java/com/android/server/net/NetworkStatsRecorder.java b/services/core/java/com/android/server/net/NetworkStatsRecorder.java index c09196006b..04dc917ec3 100644 --- a/services/core/java/com/android/server/net/NetworkStatsRecorder.java +++ b/services/core/java/com/android/server/net/NetworkStatsRecorder.java @@ -19,6 +19,7 @@ package com.android.server.net; import static android.net.NetworkStats.TAG_NONE; import static android.net.TrafficStats.KB_IN_BYTES; import static android.net.TrafficStats.MB_IN_BYTES; +import static android.text.format.DateUtils.YEAR_IN_MILLIS; import static com.android.internal.util.Preconditions.checkNotNull; import android.net.NetworkStats; @@ -54,7 +55,7 @@ import libcore.io.IoUtils; * Logic to record deltas between periodic {@link NetworkStats} snapshots into * {@link NetworkStatsHistory} that belong to {@link NetworkStatsCollection}. * Keeps pending changes in memory until they pass a specific threshold, in - * bytes. Uses {@link FileRotator} for persistence logic. + * bytes. Uses {@link FileRotator} for persistence logic if present. *

* Not inherently thread safe. */ @@ -86,6 +87,29 @@ public class NetworkStatsRecorder { private WeakReference mComplete; + /** + * Non-persisted recorder, with only one bucket. Used by {@link NetworkStatsObservers}. + */ + public NetworkStatsRecorder() { + mRotator = null; + mObserver = null; + mDropBox = null; + mCookie = null; + + // set the bucket big enough to have all data in one bucket, but allow some + // slack to avoid overflow + mBucketDuration = YEAR_IN_MILLIS; + mOnlyTags = false; + + mPending = null; + mSinceBoot = new NetworkStatsCollection(mBucketDuration); + + mPendingRewriter = null; + } + + /** + * Persisted recorder. + */ public NetworkStatsRecorder(FileRotator rotator, NonMonotonicObserver observer, DropBoxManager dropBox, String cookie, long bucketDuration, boolean onlyTags) { mRotator = checkNotNull(rotator, "missing FileRotator"); @@ -110,9 +134,15 @@ public class NetworkStatsRecorder { public void resetLocked() { mLastSnapshot = null; - mPending.reset(); - mSinceBoot.reset(); - mComplete.clear(); + if (mPending != null) { + mPending.reset(); + } + if (mSinceBoot != null) { + mSinceBoot.reset(); + } + if (mComplete != null) { + mComplete.clear(); + } } public NetworkStats.Entry getTotalSinceBootLocked(NetworkTemplate template) { @@ -120,6 +150,10 @@ public class NetworkStatsRecorder { NetworkStatsAccess.Level.DEVICE).getTotal(null); } + public NetworkStatsCollection getSinceBoot() { + return mSinceBoot; + } + /** * Load complete history represented by {@link FileRotator}. Caches * internally as a {@link WeakReference}, and updated with future @@ -127,6 +161,7 @@ public class NetworkStatsRecorder { * as reference is valid. */ public NetworkStatsCollection getOrLoadCompleteLocked() { + checkNotNull(mRotator, "missing FileRotator"); NetworkStatsCollection res = mComplete != null ? mComplete.get() : null; if (res == null) { res = loadLocked(Long.MIN_VALUE, Long.MAX_VALUE); @@ -136,6 +171,7 @@ public class NetworkStatsRecorder { } public NetworkStatsCollection getOrLoadPartialLocked(long start, long end) { + checkNotNull(mRotator, "missing FileRotator"); NetworkStatsCollection res = mComplete != null ? mComplete.get() : null; if (res == null) { res = loadLocked(start, end); @@ -205,7 +241,9 @@ public class NetworkStatsRecorder { // only record tag data when requested if ((entry.tag == TAG_NONE) != mOnlyTags) { - mPending.recordData(ident, entry.uid, entry.set, entry.tag, start, end, entry); + if (mPending != null) { + mPending.recordData(ident, entry.uid, entry.set, entry.tag, start, end, entry); + } // also record against boot stats when present if (mSinceBoot != null) { @@ -231,6 +269,7 @@ public class NetworkStatsRecorder { * {@link #mPersistThresholdBytes}. */ public void maybePersistLocked(long currentTimeMillis) { + checkNotNull(mRotator, "missing FileRotator"); final long pendingBytes = mPending.getTotalBytes(); if (pendingBytes >= mPersistThresholdBytes) { forcePersistLocked(currentTimeMillis); @@ -243,6 +282,7 @@ public class NetworkStatsRecorder { * Force persisting any pending deltas. */ public void forcePersistLocked(long currentTimeMillis) { + checkNotNull(mRotator, "missing FileRotator"); if (mPending.isDirty()) { if (LOGD) Slog.d(TAG, "forcePersistLocked() writing for " + mCookie); try { @@ -264,20 +304,26 @@ public class NetworkStatsRecorder { * to {@link TrafficStats#UID_REMOVED}. */ public void removeUidsLocked(int[] uids) { - try { - // Rewrite all persisted data to migrate UID stats - mRotator.rewriteAll(new RemoveUidRewriter(mBucketDuration, uids)); - } catch (IOException e) { - Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e); - recoverFromWtf(); - } catch (OutOfMemoryError e) { - Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e); - recoverFromWtf(); + if (mRotator != null) { + try { + // Rewrite all persisted data to migrate UID stats + mRotator.rewriteAll(new RemoveUidRewriter(mBucketDuration, uids)); + } catch (IOException e) { + Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e); + recoverFromWtf(); + } catch (OutOfMemoryError e) { + Log.wtf(TAG, "problem removing UIDs " + Arrays.toString(uids), e); + recoverFromWtf(); + } } // Remove any pending stats - mPending.removeUids(uids); - mSinceBoot.removeUids(uids); + if (mPending != null) { + mPending.removeUids(uids); + } + if (mSinceBoot != null) { + mSinceBoot.removeUids(uids); + } // Clear UID from current stats snapshot if (mLastSnapshot != null) { @@ -361,6 +407,8 @@ public class NetworkStatsRecorder { } public void importLegacyNetworkLocked(File file) throws IOException { + checkNotNull(mRotator, "missing FileRotator"); + // legacy file still exists; start empty to avoid double importing mRotator.deleteAll(); @@ -379,6 +427,8 @@ public class NetworkStatsRecorder { } public void importLegacyUidLocked(File file) throws IOException { + checkNotNull(mRotator, "missing FileRotator"); + // legacy file still exists; start empty to avoid double importing mRotator.deleteAll(); @@ -397,7 +447,9 @@ public class NetworkStatsRecorder { } public void dumpLocked(IndentingPrintWriter pw, boolean fullHistory) { - pw.print("Pending bytes: "); pw.println(mPending.getTotalBytes()); + if (mPending != null) { + pw.print("Pending bytes: "); pw.println(mPending.getTotalBytes()); + } if (fullHistory) { pw.println("Complete history:"); getOrLoadCompleteLocked().dump(pw); diff --git a/services/core/java/com/android/server/net/NetworkStatsService.java b/services/core/java/com/android/server/net/NetworkStatsService.java index 3aeceef51d..2c2e9b91b5 100644 --- a/services/core/java/com/android/server/net/NetworkStatsService.java +++ b/services/core/java/com/android/server/net/NetworkStatsService.java @@ -57,6 +57,7 @@ import static android.text.format.DateUtils.DAY_IN_MILLIS; import static android.text.format.DateUtils.HOUR_IN_MILLIS; import static android.text.format.DateUtils.MINUTE_IN_MILLIS; import static android.text.format.DateUtils.SECOND_IN_MILLIS; +import static com.android.internal.util.Preconditions.checkArgument; import static com.android.internal.util.Preconditions.checkNotNull; import static com.android.server.NetworkManagementService.LIMIT_GLOBAL_ALERT; import static com.android.server.NetworkManagementSocketTagger.resetKernelUidStats; @@ -72,6 +73,7 @@ import android.content.Intent; import android.content.IntentFilter; import android.content.pm.ApplicationInfo; import android.content.pm.PackageManager; +import android.net.DataUsageRequest; import android.net.IConnectivityManager; import android.net.INetworkManagementEventObserver; import android.net.INetworkStatsService; @@ -90,8 +92,10 @@ import android.os.DropBoxManager; import android.os.Environment; import android.os.Handler; import android.os.HandlerThread; +import android.os.IBinder; import android.os.INetworkManagementService; import android.os.Message; +import android.os.Messenger; import android.os.PowerManager; import android.os.RemoteException; import android.os.SystemClock; @@ -152,6 +156,7 @@ public class NetworkStatsService extends INetworkStatsService.Stub { private final TrustedTime mTime; private final TelephonyManager mTeleManager; private final NetworkStatsSettings mSettings; + private final NetworkStatsObservers mStatsObservers; private final File mSystemDir; private final File mBaseDir; @@ -233,43 +238,65 @@ public class NetworkStatsService extends INetworkStatsService.Stub { /** Data layer operation counters for splicing into other structures. */ private NetworkStats mUidOperations = new NetworkStats(0L, 10); - private final Handler mHandler; + /** Must be set in factory by calling #setHandler. */ + private Handler mHandler; + private Handler.Callback mHandlerCallback; private boolean mSystemReady; private long mPersistThreshold = 2 * MB_IN_BYTES; private long mGlobalAlertBytes; - public NetworkStatsService( - Context context, INetworkManagementService networkManager, IAlarmManager alarmManager) { - this(context, networkManager, alarmManager, NtpTrustedTime.getInstance(context), - getDefaultSystemDir(), new DefaultNetworkStatsSettings(context)); - } - private static File getDefaultSystemDir() { return new File(Environment.getDataDirectory(), "system"); } - public NetworkStatsService(Context context, INetworkManagementService networkManager, - IAlarmManager alarmManager, TrustedTime time, File systemDir, - NetworkStatsSettings settings) { + private static File getDefaultBaseDir() { + File baseDir = new File(getDefaultSystemDir(), "netstats"); + baseDir.mkdirs(); + return baseDir; + } + + public static NetworkStatsService create(Context context, + INetworkManagementService networkManager) { + AlarmManager alarmManager = (AlarmManager) context.getSystemService(Context.ALARM_SERVICE); + PowerManager powerManager = (PowerManager) context.getSystemService(Context.POWER_SERVICE); + PowerManager.WakeLock wakeLock = + powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG); + + NetworkStatsService service = new NetworkStatsService(context, networkManager, alarmManager, + wakeLock, NtpTrustedTime.getInstance(context), TelephonyManager.getDefault(), + new DefaultNetworkStatsSettings(context), new NetworkStatsObservers(), + getDefaultSystemDir(), getDefaultBaseDir()); + + HandlerThread handlerThread = new HandlerThread(TAG); + Handler.Callback callback = new HandlerCallback(service); + handlerThread.start(); + Handler handler = new Handler(handlerThread.getLooper(), callback); + service.setHandler(handler, callback); + return service; + } + + @VisibleForTesting + NetworkStatsService(Context context, INetworkManagementService networkManager, + AlarmManager alarmManager, PowerManager.WakeLock wakeLock, TrustedTime time, + TelephonyManager teleManager, NetworkStatsSettings settings, + NetworkStatsObservers statsObservers, File systemDir, File baseDir) { mContext = checkNotNull(context, "missing Context"); mNetworkManager = checkNotNull(networkManager, "missing INetworkManagementService"); + mAlarmManager = checkNotNull(alarmManager, "missing AlarmManager"); mTime = checkNotNull(time, "missing TrustedTime"); - mTeleManager = checkNotNull(TelephonyManager.getDefault(), "missing TelephonyManager"); mSettings = checkNotNull(settings, "missing NetworkStatsSettings"); - mAlarmManager = (AlarmManager) context.getSystemService(Context.ALARM_SERVICE); + mTeleManager = checkNotNull(teleManager, "missing TelephonyManager"); + mWakeLock = checkNotNull(wakeLock, "missing WakeLock"); + mStatsObservers = checkNotNull(statsObservers, "missing NetworkStatsObservers"); + mSystemDir = checkNotNull(systemDir, "missing systemDir"); + mBaseDir = checkNotNull(baseDir, "missing baseDir"); + } - final PowerManager powerManager = (PowerManager) context.getSystemService( - Context.POWER_SERVICE); - mWakeLock = powerManager.newWakeLock(PowerManager.PARTIAL_WAKE_LOCK, TAG); - - HandlerThread thread = new HandlerThread(TAG); - thread.start(); - mHandler = new Handler(thread.getLooper(), mHandlerCallback); - - mSystemDir = checkNotNull(systemDir); - mBaseDir = new File(systemDir, "netstats"); - mBaseDir.mkdirs(); + @VisibleForTesting + void setHandler(Handler handler, Handler.Callback callback) { + mHandler = handler; + mHandlerCallback = callback; } public void bindConnectivityManager(IConnectivityManager connManager) { @@ -733,6 +760,46 @@ public class NetworkStatsService extends INetworkStatsService.Stub { registerGlobalAlert(); } + @Override + public DataUsageRequest registerDataUsageCallback(String callingPackage, + DataUsageRequest request, Messenger messenger, IBinder binder) { + checkNotNull(callingPackage, "calling package is null"); + checkNotNull(request, "DataUsageRequest is null"); + checkNotNull(request.templates, "NetworkTemplate is null"); + checkArgument(request.templates.length > 0); + checkNotNull(messenger, "messenger is null"); + checkNotNull(binder, "binder is null"); + + int callingUid = Binder.getCallingUid(); + @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(callingPackage); + DataUsageRequest normalizedRequest; + final long token = Binder.clearCallingIdentity(); + try { + normalizedRequest = mStatsObservers.register(request, messenger, binder, + callingUid, accessLevel); + } finally { + Binder.restoreCallingIdentity(token); + } + + // Create baseline stats + mHandler.sendMessage(mHandler.obtainMessage(MSG_PERFORM_POLL, FLAG_PERSIST_ALL)); + + return normalizedRequest; + } + + @Override + public void unregisterDataUsageRequest(DataUsageRequest request) { + checkNotNull(request, "DataUsageRequest is null"); + + int callingUid = Binder.getCallingUid(); + final long token = Binder.clearCallingIdentity(); + try { + mStatsObservers.unregister(request, callingUid); + } finally { + Binder.restoreCallingIdentity(token); + } + } + /** * Update {@link NetworkStatsRecorder} and {@link #mGlobalAlertBytes} to * reflect current {@link #mPersistThreshold} value. Always defers to @@ -945,6 +1012,11 @@ public class NetworkStatsService extends INetworkStatsService.Stub { mXtRecorder.recordSnapshotLocked(xtSnapshot, mActiveIfaces, null, currentTime); mUidRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, vpnArray, currentTime); mUidTagRecorder.recordSnapshotLocked(uidSnapshot, mActiveUidIfaces, vpnArray, currentTime); + + // We need to make copies of member fields that are sent to the observer to avoid + // a race condition between the service handler thread and the observer's + mStatsObservers.updateStats(xtSnapshot, uidSnapshot, new ArrayMap<>(mActiveIfaces), + new ArrayMap<>(mActiveUidIfaces), vpnArray, currentTime); } /** @@ -1243,21 +1315,28 @@ public class NetworkStatsService extends INetworkStatsService.Stub { } } - private Handler.Callback mHandlerCallback = new Handler.Callback() { + @VisibleForTesting + static class HandlerCallback implements Handler.Callback { + private final NetworkStatsService mService; + + HandlerCallback(NetworkStatsService service) { + this.mService = service; + } + @Override public boolean handleMessage(Message msg) { switch (msg.what) { case MSG_PERFORM_POLL: { final int flags = msg.arg1; - performPoll(flags); + mService.performPoll(flags); return true; } case MSG_UPDATE_IFACES: { - updateIfaces(); + mService.updateIfaces(); return true; } case MSG_REGISTER_GLOBAL_ALERT: { - registerGlobalAlert(); + mService.registerGlobalAlert(); return true; } default: { @@ -1265,7 +1344,7 @@ public class NetworkStatsService extends INetworkStatsService.Stub { } } } - }; + } private void assertBandwidthControlEnabled() { if (!isBandwidthControlEnabled()) {