Merge "sourcedr: Add VNDK dependency check tool"

This commit is contained in:
Logan Chien
2018-02-13 04:42:22 +00:00
committed by Gerrit Code Review
9 changed files with 2039 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
*#
*.py[co]
*.sw[op]
*~
__pycache__

View File

@@ -0,0 +1,805 @@
#!/usr/bin/env python3
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module implements a Android.bp parser."""
import collections
import glob
import itertools
import os
import re
import sys
#------------------------------------------------------------------------------
# Python 2 compatibility
#------------------------------------------------------------------------------
if sys.version_info >= (3,):
py3_chr = chr # pylint: disable=invalid-name
else:
def py3_chr(codepoint):
"""Convert an integer character codepoint into a utf-8 string."""
return unichr(codepoint).encode('utf-8')
try:
from enum import Enum
except ImportError:
class _Enum(object): # pylint: disable=too-few-public-methods
"""A name-value pair for each enumeration."""
__slot__ = ('name', 'value')
def __init__(self, name, value):
"""Create a name-value pair."""
self.name = name
self.value = value
def __repr__(self):
"""Return the name of the enumeration."""
return self.name
class _EnumMeta(type): # pylint: disable=too-few-public-methods
"""Metaclass for Enum base class."""
def __new__(mcs, name, bases, attrs):
"""Collects enumerations from attributes of the derived classes."""
enums = []
new_attrs = {'_enums': enums}
for key, value in attrs.iteritems():
if key.startswith('_'):
new_attrs[key] = value
else:
item = _Enum(key, value)
enums.append(item)
new_attrs[key] = item
return type.__new__(mcs, name, bases, new_attrs)
def __iter__(cls):
"""Iterate the list of enumerations."""
return iter(cls._enums)
class Enum(object): # pylint: disable=too-few-public-methods
"""Enum base class."""
__metaclass__ = _EnumMeta
#------------------------------------------------------------------------------
# Lexer
#------------------------------------------------------------------------------
class Token(Enum): # pylint: disable=too-few-public-methods
"""Token enumerations."""
EOF = 0
IDENT = 1
LPAREN = 2
RPAREN = 3
LBRACKET = 4
RBRACKET = 5
LBRACE = 6
RBRACE = 7
COLON = 8
ASSIGN = 9
ASSIGNPLUS = 10
PLUS = 11
COMMA = 12
STRING = 13
COMMENT = 14
SPACE = 15
class LexerError(ValueError):
"""Lexer error exception class."""
def __init__(self, buf, pos, message):
"""Create a lexer error exception object."""
super(LexerError, self).__init__(message)
self.message = message
self.line, self.column = Lexer.compute_line_column(buf, pos)
def __str__(self):
"""Convert lexer error to string representation."""
return 'LexerError: {}:{}: {}'.format(
self.line, self.column, self.message)
class Lexer(object):
"""Lexer to tokenize the input string."""
def __init__(self, buf, offset=0):
"""Tokenize the source code in buf starting from offset.
Args:
buf (string) The source code to be tokenized.
offset (int) The position to start.
"""
self.buf = buf
self.start = None
self.end = offset
self.token = None
self.literal = None
self._next()
def consume(self, *tokens):
"""Consume one or more token."""
for token in tokens:
if token == self.token:
self._next()
else:
raise LexerError(self.buf, self.start,
'unexpected token ' + self.token.name)
def _next(self):
"""Read next non-comment non-space token."""
buf_len = len(self.buf)
while self.end < buf_len:
self.start = self.end
self.token, self.end, self.literal = self.lex(self.buf, self.start)
if self.token != Token.SPACE and self.token != Token.COMMENT:
return
self.start = self.end
self.token = Token.EOF
self.literal = None
@staticmethod
def compute_line_column(buf, pos):
"""Compute the line number and the column number of a given position in
the buffer."""
prior = buf[0:pos]
newline_pos = prior.rfind('\n')
if newline_pos == -1:
return (1, pos + 1)
return (prior.count('\n') + 1, pos - newline_pos)
UNICODE_CHARS_PATTERN = re.compile('[^\\\\\\n"]+')
ESCAPE_CHAR_TABLE = {
'a': '\a', 'b': '\b', 'f': '\f', 'n': '\n', 'r': '\r', 't': '\t',
'v': '\v', '\\': '\\', '\'': '\'', '\"': '\"',
}
OCT_TABLE = {str(i) for i in range(8)}
@staticmethod
def decode_oct(buf, offset, start, end):
"""Read characters from buf[start:end] and interpret them as an octal
integer."""
if end > len(buf):
raise LexerError(buf, offset, 'bad octal escape sequence')
try:
codepoint = int(buf[start:end], 8)
except ValueError:
raise LexerError(buf, offset, 'bad octal escape sequence')
if codepoint > 0xff:
raise LexerError(buf, offset, 'bad octal escape sequence')
return codepoint
@staticmethod
def decode_hex(buf, offset, start, end):
"""Read characters from buf[start:end] and interpret them as a
hexadecimal integer."""
if end > len(buf):
raise LexerError(buf, offset, 'bad hex escape sequence')
try:
return int(buf[start:end], 16)
except ValueError:
raise LexerError(buf, offset, 'bad hex escape sequence')
@classmethod
def lex_interpreted_string(cls, buf, offset):
"""Tokenize a golang interpreted string.
Args:
buf (str) The source code buffer.
offset (int) The position to find a golang interpreted string
literal.
Returns:
A tuple with the end of matched buffer and the interpreted string
literal.
"""
buf_len = len(buf)
pos = offset + 1
literal = ''
while pos < buf_len:
# Match unicode characters
match = cls.UNICODE_CHARS_PATTERN.match(buf, pos)
if match:
literal += match.group(0)
pos = match.end()
# Read the next character
try:
char = buf[pos]
except IndexError:
raise LexerError(buf, pos,
'unclosed interpreted string literal')
if char == '\\':
# Escape sequences
try:
char = buf[pos + 1]
except IndexError:
raise LexerError(buf, pos, 'bad escape sequence')
if char in cls.OCT_TABLE:
literal += chr(cls.decode_oct(buf, pos, pos + 1, pos + 4))
pos += 4
elif char == 'x':
literal += chr(cls.decode_hex(buf, pos, pos + 2, pos + 4))
pos += 4
elif char == 'u':
literal += py3_chr(
cls.decode_hex(buf, pos, pos + 2, pos + 6))
pos += 6
elif char == 'U':
literal += py3_chr(
cls.decode_hex(buf, pos, pos + 2, pos + 10))
pos += 10
else:
try:
literal += cls.ESCAPE_CHAR_TABLE[char]
pos += 2
except KeyError:
raise LexerError(buf, pos, 'bad escape sequence')
continue
if char == '"':
# End of string literal
return (pos + 1, literal)
raise LexerError(buf, pos, 'unclosed interpreted string literal')
@classmethod
def lex_string(cls, buf, offset):
"""Tokenize a golang string literal.
Args:
buf (str) The source code buffer.
offset (int) The position to find a golang string literal.
Returns:
A tuple with the end of matched buffer and the interpreted string
literal.
"""
char = buf[offset]
if char == '`':
try:
end = buf.index('`', offset + 1)
return (end + 1, buf[offset + 1 : end])
except ValueError:
raise LexerError(buf, len(buf), 'unclosed raw string literal')
if char == '"':
return cls.lex_interpreted_string(buf, offset)
raise LexerError(buf, offset, 'no string literal start character')
LEXER_PATTERNS = (
(Token.IDENT, '[A-Za-z_][0-9A-Za-z_]*'),
(Token.LPAREN, '\\('),
(Token.RPAREN, '\\)'),
(Token.LBRACKET, '\\['),
(Token.RBRACKET, '\\]'),
(Token.LBRACE, '\\{'),
(Token.RBRACE, '\\}'),
(Token.COLON, ':'),
(Token.ASSIGN, '='),
(Token.ASSIGNPLUS, '\\+='),
(Token.PLUS, '\\+'),
(Token.COMMA, ','),
(Token.STRING, '["`]'),
(Token.COMMENT,
'/(?:(?:/[^\\n]*)|(?:\\*(?:(?:[^*]*)|(?:\\*+[^/*]))*\\*+/))'),
(Token.SPACE, '\\s+'),
)
LEXER_MATCHER = re.compile('|'.join(
'(' + pattern + ')' for _, pattern in LEXER_PATTERNS))
@classmethod
def lex(cls, buf, offset):
"""Tokenize a token from buf[offset].
Args:
buf (string) The source code buffer.
offset (int) The position to find and tokenize a token.
Return:
A tuple with three elements. The first element is the token id.
The second element is the end of the token. The third element is
the value for strings or identifiers.
"""
match = cls.LEXER_MATCHER.match(buf, offset)
if not match:
raise LexerError(buf, offset, 'unknown token')
token = cls.LEXER_PATTERNS[match.lastindex - 1][0]
if token == Token.STRING:
end, literal = cls.lex_string(buf, offset)
else:
end = match.end()
literal = buf[offset:end] if token == Token.IDENT else None
return (token, end, literal)
#------------------------------------------------------------------------------
# AST
#------------------------------------------------------------------------------
class Expr(object): # pylint: disable=too-few-public-methods
"""Base class for all expressions."""
def eval(self, env):
"""Evaluate the expression under an environment."""
raise NotImplementedError()
class String(Expr, str):
"""String constant literal."""
def eval(self, env):
"""Evaluate the string expression under an environment."""
return self
class Bool(Expr): # pylint: disable=too-few-public-methods
"""Boolean constant literal."""
__slots__ = ('value',)
def __init__(self, value):
"""Create a boolean constant literal."""
self.value = value
def __repr__(self):
"""Convert a boolean constant literal to string representation."""
return repr(self.value)
def __bool__(self):
"""Convert boolean constant literal to Python bool type."""
return self.value
__nonzero__ = __bool__
def __eq__(self, rhs):
"""Compare whether two instances are equal."""
return self.value == rhs.value
def __hash__(self):
"""Compute the hashed value."""
return hash(self.value)
def eval(self, env):
"""Evaluate the boolean expression under an environment."""
return self
class VarRef(Expr): # pylint: disable=too-few-public-methods
"""A reference to a variable."""
def __init__(self, name, value):
"""Create a variable reference with a name and the value under static
scoping."""
self.name = name
self.value = value
def __repr__(self):
"""Convert a variable reference to string representation."""
return self.name
def eval(self, env):
"""Evaluate the identifier under an environment."""
if self.value is None:
return env[self.name].eval(env)
return self.value.eval(env)
class List(Expr, list):
"""List expression."""
def eval(self, env):
"""Evaluate list elements under an environment."""
return List(item.eval(env) for item in self)
class Dict(Expr, collections.OrderedDict):
"""Dictionary expression."""
def __repr__(self):
attrs = ', '.join(key + ': ' + repr(value)
for key, value in self.items())
return '{' + attrs + '}'
def eval(self, env):
"""Evaluate dictionary values under an environment."""
return Dict((key, value.eval(env)) for key, value in self.items())
class Concat(Expr): # pylint: disable=too-few-public-methods
"""List/string concatenation operator."""
__slots__ = ('lhs', 'rhs')
def __init__(self, lhs, rhs):
"""Create a list concatenation expression."""
self.lhs = lhs
self.rhs = rhs
def __repr__(self):
return '(' + repr(self.lhs) + ' + ' + repr(self.rhs) + ')'
def eval(self, env):
"""Evaluate list concatenation operator under an environment."""
lhs = self.lhs.eval(env)
rhs = self.rhs.eval(env)
if isinstance(lhs, List) and isinstance(rhs, List):
return List(itertools.chain(lhs, rhs))
if isinstance(lhs, String) and isinstance(rhs, String):
return String(lhs + rhs)
raise TypeError('bad concatenation')
#------------------------------------------------------------------------------
# Parser
#------------------------------------------------------------------------------
class ParseError(ValueError):
"""Parser error exception class."""
def __init__(self, lexer, message):
"""Create a parser error exception object."""
super(ParseError, self).__init__(message)
self.message = message
self.line, self.column = \
Lexer.compute_line_column(lexer.buf, lexer.start)
def __str__(self):
"""Convert parser error to string representation."""
return 'ParseError: {}:{}: {}'.format(
self.line, self.column, self.message)
class Parser(object):
"""Parser to parse Android.bp files."""
def __init__(self, lexer, inherited_env=None):
"""Initialize the parser with the lexer."""
self.lexer = lexer
self.var_defs = []
self.vars = {} if inherited_env is None else dict(inherited_env)
self.modules = []
def parse(self):
"""Parse AST from tokens."""
lexer = self.lexer
while lexer.token != Token.EOF:
if lexer.token == Token.IDENT:
ident = self.parse_ident_lvalue()
if lexer.token in {Token.ASSIGN, Token.ASSIGNPLUS}:
self.parse_assign(ident, lexer.token)
elif lexer.token in {Token.LBRACE, Token.LPAREN}:
self.parse_module_definition(ident)
else:
raise ParseError(lexer,
'unexpected token ' + lexer.token.name)
else:
raise ParseError(lexer, 'unexpected token ' + lexer.token.name)
lexer.consume(Token.EOF)
def create_var_ref(self, name):
"""Create a variable reference."""
return VarRef(name, self.vars.get(name))
def define_var(self, name, value):
"""Define a variable."""
self.var_defs.append((name, value))
self.vars[name] = value
def parse_assign(self, ident, assign_token):
"""Parse an assignment statement."""
lexer = self.lexer
lexer.consume(assign_token)
value = self.parse_expression()
if assign_token == Token.ASSIGNPLUS:
value = Concat(self.create_var_ref(ident), value)
self.define_var(ident, value)
def parse_module_definition(self, module_ident):
"""Parse a module definition."""
properties = self.parse_dict()
self.modules.append((module_ident, properties))
def parse_ident_lvalue(self):
"""Parse an identifier as an l-value."""
ident = self.lexer.literal
self.lexer.consume(Token.IDENT)
return ident
def parse_ident_rvalue(self):
"""Parse an identifier as a r-value.
Returns:
Returns VarRef if the literal is not 'true' nor 'false'.
Returns Bool(true/false) if the literal is either 'true' or 'false'.
"""
lexer = self.lexer
if lexer.literal in {'true', 'false'}:
result = Bool(lexer.literal == 'true')
else:
result = self.create_var_ref(lexer.literal)
lexer.consume(Token.IDENT)
return result
def parse_string(self):
"""Parse a string."""
lexer = self.lexer
string = String(lexer.literal)
lexer.consume(Token.STRING)
return string
def parse_operand(self):
"""Parse an operand."""
lexer = self.lexer
token = lexer.token
if token == Token.STRING:
return self.parse_string()
if token == Token.IDENT:
return self.parse_ident_rvalue()
if token == Token.LBRACKET:
return self.parse_list()
if token == Token.LBRACE:
return self.parse_dict()
if token == Token.LPAREN:
lexer.consume(Token.LPAREN)
operand = self.parse_expression()
lexer.consume(Token.RPAREN)
return operand
raise ParseError(lexer, 'unexpected token ' + token.name)
def parse_expression(self):
"""Parse an expression."""
lexer = self.lexer
expr = self.parse_operand()
while lexer.token == Token.PLUS:
lexer.consume(Token.PLUS)
expr = Concat(expr, self.parse_operand())
return expr
def parse_list(self):
"""Parse a list."""
result = List()
lexer = self.lexer
lexer.consume(Token.LBRACKET)
while lexer.token != Token.RBRACKET:
result.append(self.parse_expression())
if lexer.token == Token.COMMA:
lexer.consume(Token.COMMA)
lexer.consume(Token.RBRACKET)
return result
def parse_dict(self):
"""Parse a dict."""
result = Dict()
lexer = self.lexer
is_func_syntax = lexer.token == Token.LPAREN
if is_func_syntax:
lexer.consume(Token.LPAREN)
else:
lexer.consume(Token.LBRACE)
while lexer.token != Token.RBRACE and lexer.token != Token.RPAREN:
if lexer.token != Token.IDENT:
raise ParseError(lexer, 'unexpected token ' + lexer.token.name)
key = self.parse_ident_lvalue()
if lexer.token == Token.ASSIGN:
lexer.consume(Token.ASSIGN)
else:
lexer.consume(Token.COLON)
value = self.parse_expression()
result[key] = value
if lexer.token == Token.COMMA:
lexer.consume(Token.COMMA)
if is_func_syntax:
lexer.consume(Token.RPAREN)
else:
lexer.consume(Token.RBRACE)
return result
class RecursiveParser(object):
"""This is a recursive parser which will parse blueprint files
recursively."""
def __init__(self):
self.visited = set()
self.modules = []
@staticmethod
def glob_sub_files(pattern, sub_file_name):
"""List the sub file paths that match with the pattern with
wildcards."""
for path in glob.glob(pattern):
if os.path.isfile(path) and os.path.basename(path) == sub_file_name:
yield path
continue
sub_file_path = os.path.join(path, sub_file_name)
if os.path.isfile(sub_file_path):
yield sub_file_path
@classmethod
def find_sub_files_from_env(cls, rootdir, env,
default_sub_name='Android.bp'):
"""Find the sub files from the names specified in build, subdirs, and
optional_subdirs."""
subs = []
if 'build' in env:
subs.extend(os.path.join(rootdir, filename)
for filename in env['build'].eval(env))
sub_name = env['subname'] if 'subname' in env else default_sub_name
if 'subdirs' in env:
for path in env['subdirs'].eval(env):
subs.extend(
cls.glob_sub_files(os.path.join(rootdir, path), sub_name))
if 'optional_subdirs' in env:
for path in env['optional_subdirs'].eval(env):
subs.extend(
cls.glob_sub_files(os.path.join(rootdir, path), sub_name))
return subs
@staticmethod
def _parse_one_file(path, env):
with open(path, 'r') as bp_file:
content = bp_file.read()
parser = Parser(Lexer(content), env)
parser.parse()
return (parser.modules, parser.vars)
def _parse_file(self, path, env, evaluate):
"""Parse blueprint files recursively."""
self.visited.add(os.path.abspath(path))
modules, sub_env = self._parse_one_file(path, env)
if evaluate:
modules = [(ident, attrs.eval(env)) for ident, attrs in modules]
self.modules += modules
rootdir = os.path.dirname(path)
sub_file_paths = self.find_sub_files_from_env(rootdir, sub_env)
sub_env.pop('build', None)
sub_env.pop('subdirs', None)
sub_env.pop('optional_subdirs', None)
for sub_file_path in sub_file_paths:
if os.path.abspath(sub_file_path) not in self.visited:
self._parse_file(sub_file_path, sub_env, evaluate)
def parse_file(self, path, env=None, evaluate=True):
"""Parse blueprint files recursively."""
self._parse_file(path, {} if env is None else env, evaluate)
#------------------------------------------------------------------------------
# Transformation
#------------------------------------------------------------------------------
def evaluate_default(attrs, default_attrs):
"""Add default attributes if the keys do not exist."""
for key, value in default_attrs.items():
if key not in attrs:
attrs[key] = value
else:
attrs_value = attrs[key]
if isinstance(value, Dict) and isinstance(attrs_value, Dict):
attrs[key] = evaluate_default(attrs_value, value)
return attrs
def evaluate_defaults(modules):
"""Add default attributes to all modules if the keys do not exist."""
mods = {}
for ident, attrs in modules:
mods[attrs['name']] = (ident, attrs)
for i, (ident, attrs) in enumerate(modules):
defaults = attrs.get('defaults')
if defaults is None:
continue
for default in defaults:
attrs = evaluate_default(attrs, mods[default][1])
modules[i] = (ident, attrs)
return modules

View File

@@ -0,0 +1,186 @@
#!/usr/bin/env python3
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This script scans all Android.bp in an android source tree and check the
correctness of dependencies."""
from __future__ import print_function
import argparse
import itertools
import sys
from blueprint import RecursiveParser, evaluate_defaults
def _is_vndk(module):
"""Get the `vndk.enabled` module property."""
try:
return bool(module['vndk']['enabled'])
except KeyError:
return False
def _is_vndk_sp(module):
"""Get the `vndk.support_system_process` module property."""
try:
return bool(module['vndk']['support_system_process'])
except KeyError:
return False
def _is_vendor(module):
"""Get the `vendor` module property."""
try:
return bool(module['vendor'])
except KeyError:
return False
def _is_vendor_available(module):
"""Get the `vendor_available` module property."""
try:
return bool(module['vendor_available'])
except KeyError:
return False
def _has_vendor_variant(module):
"""Check whether the module is VNDK or vendor available."""
return _is_vndk(module) or _is_vendor_available(module)
def _get_dependencies(module):
"""Get module dependencies."""
shared_libs = set(module.get('shared_libs', []))
static_libs = set(module.get('static_libs', []))
header_libs = set(module.get('header_libs', []))
try:
target_vendor = module['target']['vendor']
shared_libs -= set(target_vendor.get('exclude_shared_libs', []))
static_libs -= set(target_vendor.get('exclude_static_libs', []))
header_libs -= set(target_vendor.get('exclude_header_libs', []))
except KeyError:
pass
return (sorted(shared_libs), sorted(static_libs), sorted(header_libs))
def _build_module_dict(modules):
"""Build module dictionaries that map module names to modules."""
all_libs = {}
llndk_libs = {}
for rule, module in modules:
name = module['name']
if rule == 'llndk_library':
llndk_libs[name] = (rule, module)
if rule in {'llndk_library', 'ndk_library'}:
continue
if rule.endswith('_library') or \
rule.endswith('_library_shared') or \
rule.endswith('_library_static') or \
rule.endswith('_headers'):
all_libs[name] = (rule, module)
return (all_libs, llndk_libs)
def _check_module_deps(all_libs, llndk_libs, module):
"""Check the dependencies of a module."""
bad_deps = set()
shared_deps, static_deps, header_deps = _get_dependencies(module)
# Check vendor module dependencies requirements.
for dep_name in itertools.chain(shared_deps, static_deps, header_deps):
if dep_name in llndk_libs:
continue
dep_module = all_libs[dep_name][1]
if _is_vendor(dep_module):
continue
if _is_vendor_available(dep_module):
continue
if _is_vndk(dep_module) and not _is_vendor(module):
continue
bad_deps.add(dep_name)
# Check VNDK dependencies requirements.
if _is_vndk(module) and not _is_vendor(module):
is_vndk_sp = _is_vndk_sp(module)
for dep_name in shared_deps:
if dep_name in llndk_libs:
continue
dep_module = all_libs[dep_name][1]
if not _is_vndk(dep_module):
# VNDK must be self-contained.
bad_deps.add(dep_name)
break
if is_vndk_sp and not _is_vndk_sp(dep_module):
# VNDK-SP must be self-contained.
bad_deps.add(dep_name)
break
return bad_deps
def _check_modules_deps(modules):
"""Check the dependencies of modules."""
all_libs, llndk_libs = _build_module_dict(modules)
# Check the dependencies of modules
all_bad_deps = []
for name, (_, module) in all_libs.items():
if not _has_vendor_variant(module) and not _is_vendor(module):
continue
bad_deps = _check_module_deps(all_libs, llndk_libs, module)
if bad_deps:
all_bad_deps.append((name, sorted(bad_deps)))
return sorted(all_bad_deps)
def _parse_args():
"""Parse command line options."""
parser = argparse.ArgumentParser()
parser.add_argument('root_bp', help='android source tree root')
return parser.parse_args()
def main():
"""Main function."""
args = _parse_args()
parser = RecursiveParser()
parser.parse_file(args.root_bp)
all_bad_deps = _check_modules_deps(evaluate_defaults(parser.modules))
for name, bad_deps in all_bad_deps:
print('ERROR: {!r} must not depend on {}'.format(name, bad_deps),
file=sys.stderr)
if all_bad_deps:
sys.exit(1)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,20 @@
#!/bin/bash -ex
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
python3 -m unittest discover "$@"
python -m unittest discover "$@"

View File

@@ -0,0 +1 @@
#!/usr/bin/env python3

View File

@@ -0,0 +1,251 @@
#!/usr/bin/env python3
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains the unit tests to check the AST classes."""
import unittest
from blueprint import Bool, Concat, Dict, Expr, List, String, VarRef
#------------------------------------------------------------------------------
# Expr
#------------------------------------------------------------------------------
class ExprTest(unittest.TestCase):
"""Unit tests for the Expr class."""
def test_eval(self):
"""Test whether Expr.eval() raises NotImplementedError."""
with self.assertRaises(NotImplementedError):
Expr().eval({})
#------------------------------------------------------------------------------
# Bool
#------------------------------------------------------------------------------
class BoolTest(unittest.TestCase):
"""Unit tests for the Bool class."""
def test_bool(self):
"""Test Bool.__init__(), Bool.__bool__(), and Bool.eval() methods."""
false_expr = Bool(False)
self.assertFalse(bool(false_expr))
self.assertFalse(false_expr.eval({}))
true_expr = Bool(True)
self.assertTrue(bool(true_expr))
self.assertTrue(true_expr.eval({}))
self.assertEqual(Bool(False), false_expr)
self.assertEqual(Bool(True), true_expr)
def test_equal(self):
"""Test Bool.__eq__() method."""
false_expr1 = Bool(False)
false_expr2 = Bool(False)
true_expr1 = Bool(True)
true_expr2 = Bool(True)
self.assertIsNot(false_expr1, false_expr2)
self.assertEqual(false_expr1, false_expr2)
self.assertIsNot(true_expr1, true_expr2)
self.assertEqual(true_expr1, true_expr2)
def test_hash(self):
"""Test Bool.__hash__() method."""
false_expr = Bool(False)
true_expr = Bool(True)
self.assertEqual(hash(Bool(False)), hash(false_expr))
self.assertEqual(hash(Bool(True)), hash(true_expr))
def test_repr(self):
"""Test Bool.__repr__() method."""
self.assertEqual('False', repr(Bool(False)))
self.assertEqual('True', repr(Bool(True)))
#------------------------------------------------------------------------------
# String
#------------------------------------------------------------------------------
class StringTest(unittest.TestCase):
"""Unit tests for the String class."""
def test_string(self):
"""Test String.__init__() and String.eval() methods."""
expr = String('test')
self.assertEqual('test', expr.eval({}))
#------------------------------------------------------------------------------
# VarRef
#------------------------------------------------------------------------------
class VarRefTest(unittest.TestCase):
"""Unit tests for the VarRef class."""
def test_var_ref(self):
"""Test VarRef.__init__() and VarRef.eval() methods."""
expr = VarRef('a', String('b'))
self.assertEqual('a', expr.name)
self.assertEqual(String('b'), expr.value)
self.assertEqual('b', expr.eval({}))
def test_eval_with_value(self):
"""Test evaluation of local variables."""
expr = VarRef('a', String('1'))
self.assertEqual('1', expr.eval({'a': String('2')}))
def test_eval_without_value(self):
"""Test evaluation of external variables."""
expr = VarRef('a', None)
self.assertEqual('2', expr.eval({'a': String('2')}))
def test_eval_recursive(self):
"""Test recursive evaluation."""
expr = VarRef('a', List([VarRef('x', None), VarRef('y', None)]))
expr_eval = expr.eval({'x': String('1'), 'y': String('2')})
self.assertIsInstance(expr_eval, List)
self.assertEqual('1', expr_eval[0])
self.assertEqual('2', expr_eval[1])
#------------------------------------------------------------------------------
# List
#------------------------------------------------------------------------------
class ListTest(unittest.TestCase):
"""Unit tests for the List class."""
def test_list(self):
"""Test List.__init__() and List.eval() methods."""
expr = List([String('a'), String('b')])
self.assertEqual(String('a'), expr[0])
self.assertEqual(String('b'), expr[1])
expr = List([VarRef('a', None), VarRef('b', None)])
expr_eval = expr.eval({'a': String('1'), 'b': String('2')})
self.assertEqual('1', expr_eval[0])
self.assertEqual('2', expr_eval[1])
#------------------------------------------------------------------------------
# Concatenation
#------------------------------------------------------------------------------
class ConcatTest(unittest.TestCase):
"""Unit tests for the Concat class."""
def test_concat_list(self):
"""Test Concat.__init__() and Concat.eval() methods for List."""
lhs = List([String('a'), String('b')])
rhs = List([String('c'), String('d')])
expr = Concat(lhs, rhs)
self.assertIs(expr.lhs, lhs)
self.assertIs(expr.rhs, rhs)
expr_eval = expr.eval({})
self.assertIsInstance(expr_eval, List)
self.assertEqual('a', expr_eval[0])
self.assertEqual('b', expr_eval[1])
self.assertEqual('c', expr_eval[2])
self.assertEqual('d', expr_eval[3])
def test_concat_string(self):
"""Test Concat.__init__() and Concat.eval() methods for String."""
lhs = String('a')
rhs = String('b')
expr = Concat(lhs, rhs)
self.assertIs(expr.lhs, lhs)
self.assertIs(expr.rhs, rhs)
expr_eval = expr.eval({})
self.assertIsInstance(expr_eval, String)
self.assertEqual('ab', expr_eval)
def test_type_error(self):
"""Test the type check in eval()."""
str_obj = String('a')
list_obj = List()
with self.assertRaises(TypeError):
Concat(str_obj, list_obj).eval({})
with self.assertRaises(TypeError):
Concat(list_obj, str_obj).eval({})
#------------------------------------------------------------------------------
# Dictionary
#------------------------------------------------------------------------------
class DictTest(unittest.TestCase):
"""Unit tests for the Dict class."""
def test_dict(self):
"""Test Dict.__init__() method."""
expr = Dict([('a', String('1')), ('b', Bool(True))])
self.assertIn('a', expr)
self.assertEqual(String('1'), expr['a'])
self.assertIn('b', expr)
self.assertEqual(Bool(True), expr['b'])
def test_eval(self):
"""Test Dict.eval() method."""
expr = Dict([('a', VarRef('a', None)), ('b', VarRef('b', None))])
expr_eval = expr.eval({'a': String('1'), 'b': String('2')})
self.assertIn('a', expr_eval)
self.assertEqual('1', expr_eval['a'])
self.assertIn('b', expr_eval)
self.assertEqual('2', expr_eval['b'])
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,600 @@
#!/usr/bin/env python3
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains the unit tests to check the Lexer class."""
import sys
import unittest
from blueprint import Lexer, LexerError, Token
#------------------------------------------------------------------------------
# Python 2 compatibility
#------------------------------------------------------------------------------
if sys.version_info >= (3,):
py3_str = str # pylint: disable=invalid-name
else:
def py3_str(string):
"""Convert a string into a utf-8 encoded string."""
return unicode(string).encode('utf-8')
#------------------------------------------------------------------------------
# LexerError
#------------------------------------------------------------------------------
class LexerErrorTest(unittest.TestCase):
"""Unit tests for LexerError class."""
def test_lexer_error(self):
"""Test LexerError __init__(), __str__(), line, column, and message."""
exc = LexerError('a %', 2, 'unexpected character')
self.assertEqual(exc.line, 1)
self.assertEqual(exc.column, 3)
self.assertEqual(exc.message, 'unexpected character')
self.assertEqual(str(exc), 'LexerError: 1:3: unexpected character')
exc = LexerError('a\nb\ncde %', 8, 'unexpected character')
self.assertEqual(exc.line, 3)
self.assertEqual(exc.column, 5)
self.assertEqual(exc.message, 'unexpected character')
self.assertEqual(str(exc), 'LexerError: 3:5: unexpected character')
def test_hierarchy(self):
"""Test the hierarchy of LexerError."""
with self.assertRaises(ValueError):
raise LexerError('a', 0, 'error')
class LexComputeLineColumn(unittest.TestCase):
"""Unit tests for Lexer.compute_line_column() method."""
def test_compute_line_column(self):
"""Test the line and column computation."""
# Line 1
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 0)
self.assertEqual(line, 1)
self.assertEqual(column, 1)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 1)
self.assertEqual(line, 1)
self.assertEqual(column, 2)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 2)
self.assertEqual(line, 1)
self.assertEqual(column, 3)
# Line 2
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 3)
self.assertEqual(line, 2)
self.assertEqual(column, 1)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 4)
self.assertEqual(line, 2)
self.assertEqual(column, 2)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 5)
self.assertEqual(line, 2)
self.assertEqual(column, 3)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 6)
self.assertEqual(line, 2)
self.assertEqual(column, 4)
# Line 3
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 7)
self.assertEqual(line, 3)
self.assertEqual(column, 1)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 8)
self.assertEqual(line, 3)
self.assertEqual(column, 2)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 9)
self.assertEqual(line, 3)
self.assertEqual(column, 3)
# Line 4 (empty line)
line, column = Lexer.compute_line_column('ab\ncde\nfg\n', 10)
self.assertEqual(line, 4)
self.assertEqual(column, 1)
#------------------------------------------------------------------------------
# Lex.lex_string()
#------------------------------------------------------------------------------
class LexStringTest(unittest.TestCase):
"""Unit tests for the Lexer.lex_string() method."""
def test_raw_string_lit(self):
"""Test whether Lexer.lex_string() can tokenize raw string literal."""
end, lit = Lexer.lex_string('`a`', 0)
self.assertEqual(end, 3)
self.assertEqual(lit, 'a')
end, lit = Lexer.lex_string('`a\nb`', 0)
self.assertEqual(end, 5)
self.assertEqual(lit, 'a\nb')
end, lit = Lexer.lex_string('"a""b"', 3)
self.assertEqual(end, 6)
self.assertEqual(lit, 'b')
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string('`a', 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 3)
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string('"a\nb"', 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 3)
def test_interpreted_string_literal(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal."""
end, lit = Lexer.lex_string('"a"', 0)
self.assertEqual(end, 3)
self.assertEqual(lit, 'a')
end, lit = Lexer.lex_string('"n"', 0)
self.assertEqual(end, 3)
self.assertEqual(lit, 'n')
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string('"\\', 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
def test_literal_escape_char(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal with a escaped character."""
end, lit = Lexer.lex_string('"\\a"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\a')
end, lit = Lexer.lex_string('"\\b"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\b')
end, lit = Lexer.lex_string('"\\f"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\f')
end, lit = Lexer.lex_string('"\\n"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\n')
end, lit = Lexer.lex_string('"\\r"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\r')
end, lit = Lexer.lex_string('"\\t"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\t')
end, lit = Lexer.lex_string('"\\v"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\v')
end, lit = Lexer.lex_string('"\\\\"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\\')
end, lit = Lexer.lex_string('"\\\'"', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\'')
end, lit = Lexer.lex_string('"\\\""', 0)
self.assertEqual(end, 4)
self.assertEqual(lit, '\"')
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string('"\\?"', 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
def test_literal_escape_octal(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal with an octal escape sequence."""
end, lit = Lexer.lex_string('"\\000"', 0)
self.assertEqual(end, 6)
self.assertEqual(lit, '\0')
end, lit = Lexer.lex_string('"\\377"', 0)
self.assertEqual(end, 6)
self.assertEqual(lit, '\377')
tests = [
'"\\0',
'"\\0" ',
'"\\09" ',
'"\\009"',
]
for test in tests:
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string(test, 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
def test_literal_escape_hex(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal with a hexadecimal escape sequence."""
end, lit = Lexer.lex_string('"\\x00"', 0)
self.assertEqual(end, 6)
self.assertEqual(lit, '\0')
end, lit = Lexer.lex_string('"\\xff"', 0)
self.assertEqual(end, 6)
self.assertEqual(lit, '\xff')
tests = [
'"\\x',
'"\\x" ',
'"\\x0" ',
'"\\xg" ',
'"\\x0g"',
]
for test in tests:
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string(test, 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
def test_literal_escape_little_u(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal with a little u escape sequence."""
end, lit = Lexer.lex_string('"\\u0000"', 0)
self.assertEqual(end, 8)
self.assertEqual(lit, '\0')
end, lit = Lexer.lex_string('"\\uffff"', 0)
self.assertEqual(end, 8)
self.assertEqual(lit, py3_str(u'\uffff'))
tests = [
'"\\u',
'"\\u" ',
'"\\u0" ',
'"\\ug" ',
'"\\u0g" ',
'"\\u00g" ',
'"\\u000g"',
]
for test in tests:
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string(test, 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
def test_literal_escape_big_u(self):
"""Test whether Lexer.lex_string() can tokenize interpreted string
literal with a big u escape sequence."""
end, lit = Lexer.lex_string('"\\U00000000"', 0)
self.assertEqual(end, 12)
self.assertEqual(lit, '\0')
end, lit = Lexer.lex_string('"\\U0001ffff"', 0)
self.assertEqual(end, 12)
self.assertEqual(lit, py3_str(u'\U0001ffff'))
tests = [
'"\\U',
'"\\U" ',
'"\\U0" ',
'"\\Ug" ',
'"\\U0g" ',
'"\\U00g" ',
'"\\U000g" ',
'"\\U000g" ',
'"\\U0000g" ',
'"\\U00000g" ',
'"\\U000000g" ',
'"\\U0000000g"',
]
for test in tests:
with self.assertRaises(LexerError) as ctx:
Lexer.lex_string(test, 0)
self.assertEqual(ctx.exception.line, 1)
self.assertEqual(ctx.exception.column, 2)
#------------------------------------------------------------------------------
# Lexer.lex()
#------------------------------------------------------------------------------
class LexTest(unittest.TestCase):
"""Unit tests for the Lexer.lex() method."""
def test_lex_char(self):
"""Test whether Lexer.lex() can lex a character."""
token, end, lit = Lexer.lex('(', 0)
self.assertEqual(token, Token.LPAREN)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex(')', 0)
self.assertEqual(token, Token.RPAREN)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('[', 0)
self.assertEqual(token, Token.LBRACKET)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex(']', 0)
self.assertEqual(token, Token.RBRACKET)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('{', 0)
self.assertEqual(token, Token.LBRACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('}', 0)
self.assertEqual(token, Token.RBRACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex(':', 0)
self.assertEqual(token, Token.COLON)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('=', 0)
self.assertEqual(token, Token.ASSIGN)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('+', 0)
self.assertEqual(token, Token.PLUS)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex(',', 0)
self.assertEqual(token, Token.COMMA)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
def test_lex_assign_plus(self):
"""Test whether Lexer.lex() can lex `+=` without problems."""
token, end, lit = Lexer.lex('+=', 0)
self.assertEqual(token, Token.ASSIGNPLUS)
self.assertEqual(end, 2)
self.assertEqual(lit, None)
def test_lex_space(self):
"""Test whether Lexer.lex() can lex whitespaces."""
token, end, lit = Lexer.lex(' ', 0)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('\t', 0)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('\r', 0)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('\n', 0)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 1)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('\n \r\t\n', 0)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 5)
self.assertEqual(lit, None)
def test_lex_comment(self):
"""Test whether Lexer.lex() can lex comments."""
token, end, lit = Lexer.lex('// abcd', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 7)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('// abcd\nnext', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 7)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a\nb*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 7)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a\n *b*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 9)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a**b*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 8)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a***b*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 9)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/**/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 4)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/***/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 5)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/**a*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 6)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a**/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 6)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/***a*/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 7)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('/*a***/', 0)
self.assertEqual(token, Token.COMMENT)
self.assertEqual(end, 7)
self.assertEqual(lit, None)
def test_lex_string(self):
"""Test whether Lexer.lex() can lex a string."""
token, end, lit = Lexer.lex('"a"', 0)
self.assertEqual(token, Token.STRING)
self.assertEqual(end, 3)
self.assertEqual(lit, 'a')
token, end, lit = Lexer.lex('`a\nb`', 0)
self.assertEqual(token, Token.STRING)
self.assertEqual(end, 5)
self.assertEqual(lit, 'a\nb')
def test_lex_ident(self):
"""Test whether Lexer.lex() can lex an identifier."""
token, end, lit = Lexer.lex('ident', 0)
self.assertEqual(token, Token.IDENT)
self.assertEqual(end, 5)
self.assertEqual(lit, 'ident')
def test_lex_offset(self):
"""Test the offset argument of Lexer.lex()."""
token, end, lit = Lexer.lex('a "b"', 0)
self.assertEqual(token, Token.IDENT)
self.assertEqual(end, 1)
self.assertEqual(lit, 'a')
token, end, lit = Lexer.lex('a "b"', end)
self.assertEqual(token, Token.SPACE)
self.assertEqual(end, 2)
self.assertEqual(lit, None)
token, end, lit = Lexer.lex('a "b"', end)
self.assertEqual(token, Token.STRING)
self.assertEqual(end, 5)
self.assertEqual(lit, 'b')
#------------------------------------------------------------------------------
# Lexer class test
#------------------------------------------------------------------------------
class LexerTest(unittest.TestCase):
"""Unit tests for the Lexer class."""
def test_lexer(self):
"""Test token, start, end, literal, and consume()."""
lexer = Lexer('a b //a\n "c"', 0)
self.assertEqual(lexer.start, 0)
self.assertEqual(lexer.end, 1)
self.assertEqual(lexer.token, Token.IDENT)
self.assertEqual(lexer.literal, 'a')
lexer.consume(Token.IDENT)
self.assertEqual(lexer.start, 2)
self.assertEqual(lexer.end, 3)
self.assertEqual(lexer.token, Token.IDENT)
self.assertEqual(lexer.literal, 'b')
lexer.consume(Token.IDENT)
self.assertEqual(lexer.start, 9)
self.assertEqual(lexer.end, 12)
self.assertEqual(lexer.token, Token.STRING)
self.assertEqual(lexer.literal, 'c')
lexer.consume(Token.STRING)
self.assertEqual(lexer.start, 12)
self.assertEqual(lexer.end, 12)
self.assertEqual(lexer.token, Token.EOF)
self.assertEqual(lexer.literal, None)
def test_lexer_offset(self):
"""Test the offset argument of Lexer.__init__()."""
lexer = Lexer('a b', 2)
self.assertEqual(lexer.start, 2)
self.assertEqual(lexer.end, 3)
self.assertEqual(lexer.token, Token.IDENT)
self.assertEqual(lexer.literal, 'b')
lexer.consume(Token.IDENT)
self.assertEqual(lexer.start, 3)
self.assertEqual(lexer.end, 3)
self.assertEqual(lexer.token, Token.EOF)
self.assertEqual(lexer.literal, None)
lexer.consume(Token.EOF)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,171 @@
#!/usr/bin/env python3
#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module contains the unit tests to check the Parser class."""
import unittest
from blueprint import Lexer, Parser, String, VarRef
#------------------------------------------------------------------------------
# Variable Definition
#------------------------------------------------------------------------------
class DefineVarTest(unittest.TestCase):
def test_define_var(self):
parser = Parser(None)
str1 = String(1)
parser.define_var('a', str1)
self.assertEqual(len(parser.var_defs), 1)
self.assertEqual(len(parser.vars), 1)
self.assertIn('a', parser.vars)
self.assertIs(parser.vars['a'], str1)
str2 = String(2)
parser.define_var('a', str2)
self.assertEqual(len(parser.var_defs), 2)
self.assertEqual(len(parser.vars), 1)
self.assertIn('a', parser.vars)
self.assertIs(parser.vars['a'], str2)
def test_create_var_ref(self):
parser = Parser(None)
str1 = String(1)
parser.define_var('a', str1)
var1 = parser.create_var_ref('a')
self.assertIsInstance(var1, VarRef)
self.assertEqual(var1.name, 'a')
self.assertIs(var1.value, str1)
var2 = parser.create_var_ref('b')
self.assertIsInstance(var2, VarRef)
self.assertEqual(var2.name, 'b')
self.assertIs(var2.value, None)
#------------------------------------------------------------------------------
# Parser
#------------------------------------------------------------------------------
class ParserTest(unittest.TestCase):
def test_assign_string(self):
lexer = Lexer('a = "example"')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr('example'))
def test_list_empty(self):
lexer = Lexer('a = []')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr([]))
def test_list_one_element(self):
lexer = Lexer('a = ["x"]')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr(['x']))
def test_list_one_element_comma(self):
lexer = Lexer('a = ["x",]')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr(['x']))
def test_list_two_elements(self):
lexer = Lexer('a = ["x", "y"]')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr(['x', 'y']))
def test_list_two_elements_comma(self):
lexer = Lexer('a = ["x", "y",]')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr(['x', 'y']))
def test_dict_empty(self):
lexer = Lexer('a = {}')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), repr({}))
def test_dict_one_element(self):
lexer = Lexer('a = {x: "1"}')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), '{x: \'1\'}')
def test_dict_one_element_comma(self):
lexer = Lexer('a = {x: "1",}')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), '{x: \'1\'}')
def test_dict_two_elements(self):
lexer = Lexer('a = {x: "1", y: "2"}')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), '{x: \'1\', y: \'2\'}')
def test_dict_two_elements_comma(self):
lexer = Lexer('a = {x: "1", y: "2",}')
parser = Parser(lexer)
parser.parse()
self.assertEqual(parser.var_defs[0][0], 'a')
self.assertEqual(repr(parser.var_defs[0][1]), '{x: \'1\', y: \'2\'}')
if __name__ == '__main__':
unittest.main()