Add type information to adb module.

mypy still complains about a few things here, but that looks to be
mostly coming from code that is very, very dead in python 3.

The tests don't run, and haven't since the python 3 switch. Will try to
revive those next, but it likely requires moving files around to fix the
package structure (source needs to go in a subdirectory to make a real
package, as do the tests).

Bug: None
Test: mypy . && pylint .
Change-Id: Ide55a41babecbd6684b73787b17e7f5fdb81c090
This commit is contained in:
Dan Albert
2022-12-16 13:34:43 -08:00
parent 16a9f493c9
commit 316dbf5332
3 changed files with 119 additions and 76 deletions

View File

@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from __future__ import annotations
import atexit import atexit
import base64 import base64
import logging import logging
import os import os
import re import re
import subprocess import subprocess
from typing import Any, Callable
class FindDeviceError(RuntimeError): class FindDeviceError(RuntimeError):
@@ -26,19 +29,21 @@ class FindDeviceError(RuntimeError):
class DeviceNotFoundError(FindDeviceError): class DeviceNotFoundError(FindDeviceError):
def __init__(self, serial): def __init__(self, serial: str) -> None:
self.serial = serial self.serial = serial
super(DeviceNotFoundError, self).__init__( super(DeviceNotFoundError, self).__init__(
'No device with serial {}'.format(serial)) 'No device with serial {}'.format(serial))
class NoUniqueDeviceError(FindDeviceError): class NoUniqueDeviceError(FindDeviceError):
def __init__(self): def __init__(self) -> None:
super(NoUniqueDeviceError, self).__init__('No unique device') super(NoUniqueDeviceError, self).__init__('No unique device')
class ShellError(RuntimeError): class ShellError(RuntimeError):
def __init__(self, cmd, stdout, stderr, exit_code): def __init__(
self, cmd: list[str], stdout: str, stderr: str, exit_code: int
) -> None:
super(ShellError, self).__init__( super(ShellError, self).__init__(
'`{0}` exited with code {1}'.format(cmd, exit_code)) '`{0}` exited with code {1}'.format(cmd, exit_code))
self.cmd = cmd self.cmd = cmd
@@ -47,7 +52,7 @@ class ShellError(RuntimeError):
self.exit_code = exit_code self.exit_code = exit_code
def get_devices(adb_path='adb'): def get_devices(adb_path: str = 'adb') -> list[str]:
with open(os.devnull, 'wb') as devnull: with open(os.devnull, 'wb') as devnull:
subprocess.check_call([adb_path, 'start-server'], stdout=devnull, subprocess.check_call([adb_path, 'start-server'], stdout=devnull,
stderr=devnull) stderr=devnull)
@@ -68,21 +73,27 @@ def get_devices(adb_path='adb'):
return devices return devices
def _get_unique_device(product=None, adb_path='adb'): def _get_unique_device(
product: str | None = None, adb_path: str = 'adb'
) -> AndroidDevice:
devices = get_devices(adb_path=adb_path) devices = get_devices(adb_path=adb_path)
if len(devices) != 1: if len(devices) != 1:
raise NoUniqueDeviceError() raise NoUniqueDeviceError()
return AndroidDevice(devices[0], product, adb_path) return AndroidDevice(devices[0], product, adb_path)
def _get_device_by_serial(serial, product=None, adb_path='adb'): def _get_device_by_serial(
serial: str, product: str | None = None, adb_path: str = 'adb'
) -> AndroidDevice:
for device in get_devices(adb_path=adb_path): for device in get_devices(adb_path=adb_path):
if device == serial: if device == serial:
return AndroidDevice(serial, product, adb_path) return AndroidDevice(serial, product, adb_path)
raise DeviceNotFoundError(serial) raise DeviceNotFoundError(serial)
def get_device(serial=None, product=None, adb_path='adb'): def get_device(
serial: str | None = None, product: str | None = None, adb_path: str = 'adb'
) -> AndroidDevice:
"""Get a uniquely identified AndroidDevice if one is available. """Get a uniquely identified AndroidDevice if one is available.
Raises: Raises:
@@ -113,7 +124,7 @@ def get_device(serial=None, product=None, adb_path='adb'):
return _get_unique_device(product, adb_path=adb_path) return _get_unique_device(product, adb_path=adb_path)
def _get_device_by_type(flag, adb_path): def _get_device_by_type(flag: str, adb_path: str) -> AndroidDevice:
with open(os.devnull, 'wb') as devnull: with open(os.devnull, 'wb') as devnull:
subprocess.check_call([adb_path, 'start-server'], stdout=devnull, subprocess.check_call([adb_path, 'start-server'], stdout=devnull,
stderr=devnull) stderr=devnull)
@@ -127,7 +138,7 @@ def _get_device_by_type(flag, adb_path):
return _get_device_by_serial(serial, adb_path=adb_path) return _get_device_by_serial(serial, adb_path=adb_path)
def get_usb_device(adb_path='adb'): def get_usb_device(adb_path: str = 'adb') -> AndroidDevice:
"""Get the unique USB-connected AndroidDevice if it is available. """Get the unique USB-connected AndroidDevice if it is available.
Raises: Raises:
@@ -140,7 +151,7 @@ def get_usb_device(adb_path='adb'):
return _get_device_by_type('-d', adb_path=adb_path) return _get_device_by_type('-d', adb_path=adb_path)
def get_emulator_device(adb_path='adb'): def get_emulator_device(adb_path: str = 'adb') -> AndroidDevice:
"""Get the unique emulator AndroidDevice if it is available. """Get the unique emulator AndroidDevice if it is available.
Raises: Raises:
@@ -153,25 +164,31 @@ def get_emulator_device(adb_path='adb'):
return _get_device_by_type('-e', adb_path=adb_path) return _get_device_by_type('-e', adb_path=adb_path)
# TODO: Refactor so this invoked subprocess rather than returning arguments for it.
# This function is pretty type-resistant because it returns the arguments that should be
# passed to subprocess rather than the result of the call. Most of what's here looks
# like python2 workarounds anyway, so it might be something that can be done away with.
# For now, just return Any :(
#
# If necessary, modifies subprocess.check_output() or subprocess.Popen() args # If necessary, modifies subprocess.check_output() or subprocess.Popen() args
# to run the subprocess via Windows PowerShell to work-around an issue in # to run the subprocess via Windows PowerShell to work-around an issue in
# Python 2's subprocess class on Windows where it doesn't support Unicode. # Python 2's subprocess class on Windows where it doesn't support Unicode.
def _get_subprocess_args(args): def _get_subprocess_args(args: tuple[Any, ...]) -> tuple[Any, ...]:
# Only do this slow work-around if Unicode is in the cmd line on Windows. # Only do this slow work-around if Unicode is in the cmd line on Windows.
# PowerShell takes 600-700ms to startup on a 2013-2014 machine, which is # PowerShell takes 600-700ms to startup on a 2013-2014 machine, which is
# very slow. # very slow.
if os.name != 'nt' or all(not isinstance(arg, unicode) for arg in args[0]): if os.name != 'nt' or all(not isinstance(arg, unicode) for arg in args[0]):
return args return tuple(args)
def escape_arg(arg): def escape_arg(arg: str) -> str:
# Escape for the parsing that the C Runtime does in Windows apps. In # Escape for the parsing that the C Runtime does in Windows apps. In
# particular, this will take care of double-quotes. # particular, this will take care of double-quotes.
arg = subprocess.list2cmdline([arg]) arg = subprocess.list2cmdline([arg])
# Escape single-quote with another single-quote because we're about # Escape single-quote with another single-quote because we're about
# to... # to...
arg = arg.replace(u"'", u"''") arg = arg.replace("'", "''")
# ...put the arg in a single-quoted string for PowerShell to parse. # ...put the arg in a single-quoted string for PowerShell to parse.
arg = u"'" + arg + u"'" arg = "'" + arg + "'"
return arg return arg
# Escape command line args. # Escape command line args.
@@ -188,19 +205,19 @@ def _get_subprocess_args(args):
ps_code += u'\r\nExit $LastExitCode' ps_code += u'\r\nExit $LastExitCode'
# Encode as UTF-16LE (without Byte-Order-Mark) which Windows natively # Encode as UTF-16LE (without Byte-Order-Mark) which Windows natively
# understands. # understands.
ps_code = ps_code.encode('utf-16le') ps_code_encoded = ps_code.encode('utf-16le')
# Encode the PowerShell command as base64 and use the special # Encode the PowerShell command as base64 and use the special
# -EncodedCommand option that base64 decodes. Base64 is just plain ASCII, # -EncodedCommand option that base64 decodes. Base64 is just plain ASCII,
# so it should have no problem passing through Win32 CreateProcessA() # so it should have no problem passing through Win32 CreateProcessA()
# (which python erroneously calls instead of CreateProcessW()). # (which python erroneously calls instead of CreateProcessW()).
return (['powershell.exe', '-NoProfile', '-NonInteractive', return (['powershell.exe', '-NoProfile', '-NonInteractive',
'-EncodedCommand', base64.b64encode(ps_code)],) + args[1:] '-EncodedCommand', base64.b64encode(ps_code_encoded)],) + args[1:]
# Call this instead of subprocess.check_output() to work-around issue in Python # Call this instead of subprocess.check_output() to work-around issue in Python
# 2's subprocess class on Windows where it doesn't support Unicode. # 2's subprocess class on Windows where it doesn't support Unicode.
def _subprocess_check_output(*args, **kwargs): def _subprocess_check_output(*args: Any, **kwargs: Any) -> Any:
try: try:
return subprocess.check_output(*_get_subprocess_args(args), **kwargs) return subprocess.check_output(*_get_subprocess_args(args), **kwargs)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -210,17 +227,17 @@ def _subprocess_check_output(*args, **kwargs):
# Call this instead of subprocess.Popen(). Like _subprocess_check_output(). # Call this instead of subprocess.Popen(). Like _subprocess_check_output().
def _subprocess_Popen(*args, **kwargs): def _subprocess_Popen(*args: Any, **kwargs: Any) -> Any:
return subprocess.Popen(*_get_subprocess_args(args), **kwargs) return subprocess.Popen(*_get_subprocess_args(args), **kwargs)
def split_lines(s): def split_lines(s: str) -> list[str]:
"""Splits lines in a way that works even on Windows and old devices. """Splits lines in a way that works even on Windows and old devices.
Windows will see \r\n instead of \n, old devices do the same, old devices Windows will see \r\n instead of \n, old devices do the same, old devices
on Windows will see \r\r\n. on Windows will see \r\r\n.
""" """
# rstrip is used here to workaround a difference between splineslines and # rstrip is used here to workaround a difference between splitlines and
# re.split: # re.split:
# >>> 'foo\n'.splitlines() # >>> 'foo\n'.splitlines()
# ['foo'] # ['foo']
@@ -229,12 +246,11 @@ def split_lines(s):
return re.split(r'[\r\n]+', s.rstrip()) return re.split(r'[\r\n]+', s.rstrip())
def version(adb_path=None): def version(adb_path: list[str] | None = None) -> int:
"""Get the version of adb (in terms of ADB_SERVER_VERSION).""" """Get the version of adb (in terms of ADB_SERVER_VERSION)."""
adb_path = adb_path if adb_path is not None else ['adb'] adb_path = adb_path if adb_path is not None else ['adb']
version_output = subprocess.check_output(adb_path + ['version']) version_output = subprocess.check_output(adb_path + ['version'], encoding='utf-8')
version_output = version_output.decode('utf-8')
pattern = r'^Android Debug Bridge version 1.0.(\d+)$' pattern = r'^Android Debug Bridge version 1.0.(\d+)$'
result = re.match(pattern, version_output.splitlines()[0]) result = re.match(pattern, version_output.splitlines()[0])
if not result: if not result:
@@ -259,28 +275,30 @@ class AndroidDevice(object):
_RETURN_CODE_SEARCH_LENGTH = len( _RETURN_CODE_SEARCH_LENGTH = len(
'{0}255\r\r\n'.format(_RETURN_CODE_DELIMITER)) '{0}255\r\r\n'.format(_RETURN_CODE_DELIMITER))
def __init__(self, serial, product=None, adb_path='adb'): def __init__(
self, serial: str | None, product: str | None = None, adb_path: str = 'adb'
) -> None:
self.serial = serial self.serial = serial
self.product = product self.product = product
self.adb_path = adb_path self.adb_path = adb_path
self.adb_cmd = [adb_path] self.adb_cmd = [adb_path]
if self.serial is not None: if self.serial is not None:
self.adb_cmd.extend(['-s', serial]) self.adb_cmd.extend(['-s', self.serial])
if self.product is not None: if self.product is not None:
self.adb_cmd.extend(['-p', product]) self.adb_cmd.extend(['-p', self.product])
self._linesep = None self._linesep: str | None = None
self._features = None self._features: list[str] | None = None
@property @property
def linesep(self): def linesep(self) -> str:
if self._linesep is None: if self._linesep is None:
self._linesep = subprocess.check_output( self._linesep = subprocess.check_output(
self.adb_cmd + ['shell', 'echo']).decode('utf-8') self.adb_cmd + ['shell', 'echo'], encoding='utf-8')
return self._linesep return self._linesep
@property @property
def features(self): def features(self) -> list[str]:
if self._features is None: if self._features is None:
try: try:
self._features = split_lines(self._simple_call(['features'])) self._features = split_lines(self._simple_call(['features']))
@@ -288,16 +306,16 @@ class AndroidDevice(object):
self._features = [] self._features = []
return self._features return self._features
def has_shell_protocol(self): def has_shell_protocol(self) -> bool:
return version(self.adb_cmd) >= 35 and 'shell_v2' in self.features return version(self.adb_cmd) >= 35 and 'shell_v2' in self.features
def _make_shell_cmd(self, user_cmd): def _make_shell_cmd(self, user_cmd: list[str]) -> list[str]:
command = self.adb_cmd + ['shell'] + user_cmd command = self.adb_cmd + ['shell'] + user_cmd
if not self.has_shell_protocol(): if not self.has_shell_protocol():
command += self._RETURN_CODE_PROBE command += self._RETURN_CODE_PROBE
return command return command
def _parse_shell_output(self, out): def _parse_shell_output(self, out: str) -> tuple[int, str]:
"""Finds the exit code string from shell output. """Finds the exit code string from shell output.
Args: Args:
@@ -325,12 +343,12 @@ class AndroidDevice(object):
out = out[:-len(partition[1]) - len(partition[2])] out = out[:-len(partition[1]) - len(partition[2])]
return result, out return result, out
def _simple_call(self, cmd): def _simple_call(self, cmd: list[str]) -> str:
logging.info(' '.join(self.adb_cmd + cmd)) logging.info(' '.join(self.adb_cmd + cmd))
return _subprocess_check_output( return _subprocess_check_output(
self.adb_cmd + cmd, stderr=subprocess.STDOUT).decode('utf-8') self.adb_cmd + cmd, stderr=subprocess.STDOUT).decode('utf-8')
def shell(self, cmd): def shell(self, cmd: list[str]) -> tuple[str, str]:
"""Calls `adb shell` """Calls `adb shell`
Args: Args:
@@ -348,7 +366,7 @@ class AndroidDevice(object):
raise ShellError(cmd, stdout, stderr, exit_code) raise ShellError(cmd, stdout, stderr, exit_code)
return stdout, stderr return stdout, stderr
def shell_nocheck(self, cmd): def shell_nocheck(self, cmd: list[str]) -> tuple[int, str, str]:
"""Calls `adb shell` """Calls `adb shell`
Args: Args:
@@ -371,8 +389,14 @@ class AndroidDevice(object):
exit_code, stdout = self._parse_shell_output(stdout) exit_code, stdout = self._parse_shell_output(stdout)
return exit_code, stdout, stderr return exit_code, stdout, stderr
def shell_popen(self, cmd, kill_atexit=True, preexec_fn=None, def shell_popen(
creationflags=0, **kwargs): self,
cmd: list[str],
kill_atexit: bool = True,
preexec_fn: Callable[[], None] | None = None,
creationflags: int = 0,
**kwargs: Any,
) -> subprocess.Popen[Any]:
"""Calls `adb shell` and returns a handle to the adb process. """Calls `adb shell` and returns a handle to the adb process.
This function provides direct access to the subprocess used to run the This function provides direct access to the subprocess used to run the
@@ -400,7 +424,7 @@ class AndroidDevice(object):
preexec_fn = os.setpgrp preexec_fn = os.setpgrp
elif preexec_fn is not os.setpgrp: elif preexec_fn is not os.setpgrp:
fn = preexec_fn fn = preexec_fn
def _wrapper(): def _wrapper() -> None:
fn() fn()
os.setpgrp() os.setpgrp()
preexec_fn = _wrapper preexec_fn = _wrapper
@@ -413,14 +437,14 @@ class AndroidDevice(object):
return p return p
def install(self, filename, replace=False): def install(self, filename: str, replace: bool = False) -> str:
cmd = ['install'] cmd = ['install']
if replace: if replace:
cmd.append('-r') cmd.append('-r')
cmd.append(filename) cmd.append(filename)
return self._simple_call(cmd) return self._simple_call(cmd)
def push(self, local, remote, sync=False): def push(self, local: str | list[str], remote: str, sync: bool = False) -> str:
"""Transfer a local file or directory to the device. """Transfer a local file or directory to the device.
Args: Args:
@@ -430,7 +454,7 @@ class AndroidDevice(object):
those on the device. If False, transfers all files. those on the device. If False, transfers all files.
Returns: Returns:
Exit status of the push command. Output of the command.
""" """
cmd = ['push'] cmd = ['push']
if sync: if sync:
@@ -444,73 +468,73 @@ class AndroidDevice(object):
return self._simple_call(cmd) return self._simple_call(cmd)
def pull(self, remote, local): def pull(self, remote: str, local: str) -> str:
return self._simple_call(['pull', remote, local]) return self._simple_call(['pull', remote, local])
def sync(self, directory=None): def sync(self, directory: str | None = None) -> str:
cmd = ['sync'] cmd = ['sync']
if directory is not None: if directory is not None:
cmd.append(directory) cmd.append(directory)
return self._simple_call(cmd) return self._simple_call(cmd)
def tcpip(self, port): def tcpip(self, port: str) -> str:
return self._simple_call(['tcpip', port]) return self._simple_call(['tcpip', port])
def usb(self): def usb(self) -> str:
return self._simple_call(['usb']) return self._simple_call(['usb'])
def reboot(self): def reboot(self) -> str:
return self._simple_call(['reboot']) return self._simple_call(['reboot'])
def remount(self): def remount(self) -> str:
return self._simple_call(['remount']) return self._simple_call(['remount'])
def root(self): def root(self) -> str:
return self._simple_call(['root']) return self._simple_call(['root'])
def unroot(self): def unroot(self) -> str:
return self._simple_call(['unroot']) return self._simple_call(['unroot'])
def connect(self, host): def connect(self, host: str) -> str:
return self._simple_call(['connect', host]) return self._simple_call(['connect', host])
def disconnect(self, host): def disconnect(self, host: str) -> str:
return self._simple_call(['disconnect', host]) return self._simple_call(['disconnect', host])
def forward(self, local, remote): def forward(self, local: str, remote: str) -> str:
return self._simple_call(['forward', local, remote]) return self._simple_call(['forward', local, remote])
def forward_list(self): def forward_list(self) -> str:
return self._simple_call(['forward', '--list']) return self._simple_call(['forward', '--list'])
def forward_no_rebind(self, local, remote): def forward_no_rebind(self, local: str, remote: str) -> str:
return self._simple_call(['forward', '--no-rebind', local, remote]) return self._simple_call(['forward', '--no-rebind', local, remote])
def forward_remove(self, local): def forward_remove(self, local: str) -> str:
return self._simple_call(['forward', '--remove', local]) return self._simple_call(['forward', '--remove', local])
def forward_remove_all(self): def forward_remove_all(self) -> str:
return self._simple_call(['forward', '--remove-all']) return self._simple_call(['forward', '--remove-all'])
def reverse(self, remote, local): def reverse(self, remote: str, local: str) -> str:
return self._simple_call(['reverse', remote, local]) return self._simple_call(['reverse', remote, local])
def reverse_list(self): def reverse_list(self) -> str:
return self._simple_call(['reverse', '--list']) return self._simple_call(['reverse', '--list'])
def reverse_no_rebind(self, local, remote): def reverse_no_rebind(self, local: str, remote: str) -> str:
return self._simple_call(['reverse', '--no-rebind', local, remote]) return self._simple_call(['reverse', '--no-rebind', local, remote])
def reverse_remove_all(self): def reverse_remove_all(self) -> str:
return self._simple_call(['reverse', '--remove-all']) return self._simple_call(['reverse', '--remove-all'])
def reverse_remove(self, remote): def reverse_remove(self, remote: str) -> str:
return self._simple_call(['reverse', '--remove', remote]) return self._simple_call(['reverse', '--remove', remote])
def wait(self): def wait(self) -> str:
return self._simple_call(['wait-for-device']) return self._simple_call(['wait-for-device'])
def get_prop(self, prop_name): def get_prop(self, prop_name: str) -> str | None:
output = split_lines(self.shell(['getprop', prop_name])[0]) output = split_lines(self.shell(['getprop', prop_name])[0])
if len(output) != 1: if len(output) != 1:
raise RuntimeError('Too many lines in getprop output:\n' + raise RuntimeError('Too many lines in getprop output:\n' +
@@ -520,7 +544,7 @@ class AndroidDevice(object):
return None return None
return value return value
def set_prop(self, prop_name, value): def set_prop(self, prop_name: str, value: str) -> None:
self.shell(['setprop', prop_name, value]) self.shell(['setprop', prop_name, value])
def logcat(self) -> str: def logcat(self) -> str:

View File

@@ -0,0 +1,19 @@
[mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_any_unimported = true
disallow_subclassing_any = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
follow_imports = silent
implicit_reexport = false
namespace_packages = true
no_implicit_optional = true
show_error_codes = true
strict_equality = true
warn_redundant_casts = true
# Can't enable because mypy cannot reason about _get_subprocess_args.
# warn_return_any = true
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true

View File

@@ -20,12 +20,12 @@ import mock
import adb import adb
class GetDeviceTest(unittest.TestCase): class GetDeviceTest(unittest.TestCase):
def setUp(self): def setUp(self) -> None:
self.android_serial = os.getenv('ANDROID_SERIAL') self.android_serial = os.getenv('ANDROID_SERIAL')
if 'ANDROID_SERIAL' in os.environ: if 'ANDROID_SERIAL' in os.environ:
del os.environ['ANDROID_SERIAL'] del os.environ['ANDROID_SERIAL']
def tearDown(self): def tearDown(self) -> None:
if self.android_serial is not None: if self.android_serial is not None:
os.environ['ANDROID_SERIAL'] = self.android_serial os.environ['ANDROID_SERIAL'] = self.android_serial
else: else:
@@ -33,27 +33,27 @@ class GetDeviceTest(unittest.TestCase):
del os.environ['ANDROID_SERIAL'] del os.environ['ANDROID_SERIAL']
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_explicit(self, mock_get_devices): def test_explicit(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo', 'bar'] mock_get_devices.return_value = ['foo', 'bar']
device = adb.get_device('foo') device = adb.get_device('foo')
self.assertEqual(device.serial, 'foo') self.assertEqual(device.serial, 'foo')
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_from_env(self, mock_get_devices): def test_from_env(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo', 'bar'] mock_get_devices.return_value = ['foo', 'bar']
os.environ['ANDROID_SERIAL'] = 'foo' os.environ['ANDROID_SERIAL'] = 'foo'
device = adb.get_device() device = adb.get_device()
self.assertEqual(device.serial, 'foo') self.assertEqual(device.serial, 'foo')
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_arg_beats_env(self, mock_get_devices): def test_arg_beats_env(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo', 'bar'] mock_get_devices.return_value = ['foo', 'bar']
os.environ['ANDROID_SERIAL'] = 'bar' os.environ['ANDROID_SERIAL'] = 'bar'
device = adb.get_device('foo') device = adb.get_device('foo')
self.assertEqual(device.serial, 'foo') self.assertEqual(device.serial, 'foo')
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_no_such_device(self, mock_get_devices): def test_no_such_device(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo', 'bar'] mock_get_devices.return_value = ['foo', 'bar']
self.assertRaises(adb.DeviceNotFoundError, adb.get_device, ['baz']) self.assertRaises(adb.DeviceNotFoundError, adb.get_device, ['baz'])
@@ -61,18 +61,18 @@ class GetDeviceTest(unittest.TestCase):
self.assertRaises(adb.DeviceNotFoundError, adb.get_device) self.assertRaises(adb.DeviceNotFoundError, adb.get_device)
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_unique_device(self, mock_get_devices): def test_unique_device(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo'] mock_get_devices.return_value = ['foo']
device = adb.get_device() device = adb.get_device()
self.assertEqual(device.serial, 'foo') self.assertEqual(device.serial, 'foo')
@mock.patch('adb.device.get_devices') @mock.patch('adb.device.get_devices')
def test_no_unique_device(self, mock_get_devices): def test_no_unique_device(self, mock_get_devices: mock.Mock) -> None:
mock_get_devices.return_value = ['foo', 'bar'] mock_get_devices.return_value = ['foo', 'bar']
self.assertRaises(adb.NoUniqueDeviceError, adb.get_device) self.assertRaises(adb.NoUniqueDeviceError, adb.get_device)
def main(): def main() -> None:
suite = unittest.TestLoader().loadTestsFromName(__name__) suite = unittest.TestLoader().loadTestsFromName(__name__)
unittest.TextTestRunner(verbosity=3).run(suite) unittest.TextTestRunner(verbosity=3).run(suite)