diff --git a/.github/workflows/smoke_test.yml b/.github/workflows/smoke_test.yml index b97df00b..38609308 100644 --- a/.github/workflows/smoke_test.yml +++ b/.github/workflows/smoke_test.yml @@ -97,11 +97,11 @@ jobs: echo "a4db_RESULT=$(quark -a apk-samples/malware-samples/13667fe3b0ad496a0cd157f34b7e0c991d72a4db.apk -s -t 100 | grep 100% | wc -l | awk '{print $1}')" >> $GITHUB_ENV echo "e273e_RESULT=$(quark -a apk-samples/malware-samples/14d9f1a92dd984d6040cc41ed06e273e.apk -s -t 100 | grep 100% | wc -l | awk '{print $1}')" >> $GITHUB_ENV - - name: Check Ahmyt Result + - name: Check Ahmyth Result shell: bash - # This sample should have 16 behaviors with 100% confidence + # This sample should have 39 behaviors with 100% confidence run: | - if [ "${{ env.Ahmyth_RESULT }}" == "40" ]; then + if [ "${{ env.Ahmyth_RESULT }}" == "39" ]; then exit 0 else exit 1 @@ -109,9 +109,9 @@ jobs: - name: Check 13667fe3b0ad496a0cd157f34b7e0c991d72a4db.apk Result shell: bash - # This sample should have 11 behaviors with 100% confidence + # This sample should have 19 behaviors with 100% confidence run: | - if [ "${{ env.a4db_RESULT }}" == "20" ]; then + if [ "${{ env.a4db_RESULT }}" == "19" ]; then exit 0 else exit 1 @@ -119,9 +119,9 @@ jobs: - name: Check 14d9f1a92dd984d6040cc41ed06e273e.apk Result shell: bash - # This sample should have 15 behaviors with 100% confidence + # This sample should have 41 behaviors with 100% confidence run: | - if [ "${{ env.e273e_RESULT }}" == "43" ]; then + if [ "${{ env.e273e_RESULT }}" == "41" ]; then exit 0 else exit 1 diff --git a/quark/core/apkinfo.py b/quark/core/apkinfo.py index 5a619683..a764836b 100644 --- a/quark/core/apkinfo.py +++ b/quark/core/apkinfo.py @@ -22,8 +22,6 @@ class AndroguardImp(BaseApkinfo): """Information about apk based on androguard analysis""" - __slots__ = ("apk", "dalvikvmformat", "analysis", "_manifest") - def __init__(self, apk_filepath: Union[str, PathLike]): super().__init__(apk_filepath, "androguard") @@ -59,7 +57,7 @@ def custom_methods(self) -> Set[MethodObject]: if not meth_analysis.is_external() } - @property + @functools.cached_property def all_methods(self) -> Set[MethodObject]: return { self._convert_to_method_object(meth_analysis) @@ -261,7 +259,7 @@ def get_wrapper_smali( return result - @property + @functools.cached_property def superclass_relationships(self) -> Dict[str, Set[str]]: hierarchy_dict = defaultdict(set) diff --git a/quark/core/quark.py b/quark/core/quark.py index 14b260ce..3b222c94 100644 --- a/quark/core/quark.py +++ b/quark/core/quark.py @@ -2,11 +2,11 @@ # This file is part of Quark-Engine - https://github.com/quark-engine/quark-engine # See the file 'LICENSE' for copying permission. -import collections +from itertools import product import operator import os import re -from typing import Generator, List, Tuple +from typing import Any, Generator, List, Mapping, Sequence, Tuple import numpy as np import csv @@ -16,7 +16,10 @@ from quark.core.apkinfo import AndroguardImp from quark.core.rzapkinfo import RizinImp from quark.core.r2apkinfo import R2Imp +from quark.core.struct.methodobject import MethodObject +from quark.core.struct.registerobject import RegisterObject from quark.evaluator.pyeval import PyEval +from quark.core.struct.valuenode import MethodCall, iteratePriorCalls, iteratePriorPrimitives from quark.utils import tools from quark.utils.colors import ( colorful_report, @@ -224,189 +227,189 @@ def _evaluate_method(self, method) -> List[List[str]]: return pyeval.show_table() - def check_parameter_on_single_method( + def checkParameterOnSingleMethod( self, - usage_table, - first_method, - second_method, - keyword_item_list=None, - regex=False, - ) -> Generator[Tuple[str, List[str]], None, None]: - """Check the usage of the same parameter between two method. - - :param usage_table: the usage of the involved registers - :param first_method: the first API or the method calling the first APIs - :param second_method: the second API or the method calling the second - APIs - :param keyword_item_list: keywords required to be present in the usage - , defaults to None - :param regex: treat the keywords as regular expressions, defaults to - False - :yield: _description_ + usageTable: Mapping[int, Sequence[RegisterObject]], + firstMethod: MethodObject, + secondMethod: MethodObject, + firstMethodKeywords: Sequence | None = None, + secondMethodKeywords: Sequence | None = None, + regex: bool = False, + ) -> Generator[Tuple[Tuple[MethodCall, MethodCall], List[str]], None, None]: + """Check the usage of the same parameter between two methods. + + :param usageTable: a table that records the usage of each register + :param firstMethod: the first API or a method calling the first API + :param secondMethod: the second API or a method calling the second API + :param firstMethodKeywords: keywords to match for the first method, + defaults to None + :param secondMethodKeywords: keywords to match for the second method, + defaults to None + :param regex: treat keywords as regular expressions, defaults to False + :yield: a tuple of matched method call pairs and matched keywords """ - first_method_pattern = PyEval.get_method_pattern( - first_method.class_name, first_method.name, first_method.descriptor - ) - - second_method_pattern = PyEval.get_method_pattern( - second_method.class_name, - second_method.name, - second_method.descriptor, - ) - - register_usage_records = ( - c_func - for table in usage_table - for val_obj in table - for c_func in val_obj.called_by_func - ) - - register_usage_records = ( - c_func - for table in usage_table - for val_obj in table - for c_func in val_obj.called_by_func + # Find the first and second method call that share same register + matchedCallPairs = self.findMethodCallPairs( + usageTable, + ( + firstMethod.class_name, + firstMethod.name, + firstMethod.descriptor, + ), + ( + secondMethod.class_name, + secondMethod.name, + secondMethod.descriptor, + ), ) - matched_records = filter( - lambda r: first_method_pattern in r and second_method_pattern in r, - register_usage_records, - ) + # Skip if no keywords provided for both methods. + if not firstMethodKeywords and not secondMethodKeywords: + for matchedCallPair in matchedCallPairs: + yield (matchedCallPair, []) + return - for record in matched_records: - if keyword_item_list and list(keyword_item_list): - matched_keyword_list = self.check_parameter_values( - record, - (first_method_pattern, second_method_pattern), - keyword_item_list, - regex, + # Do keyword matching + for firstCall, secondCall in matchedCallPairs: + matchedKeywords = [] + if firstMethodKeywords is not None: + # Check if arguments of the first call match any keywords. + first_matched = self.getMatchedKeywords( + firstCall, firstMethodKeywords, regex=regex ) + matchedKeywords.extend(first_matched) - if matched_keyword_list: - yield (record, matched_keyword_list) + if secondMethodKeywords is not None: + # Check if arguments of the second call match any keywords. + second_matched = self.getMatchedKeywords( + secondCall, secondMethodKeywords, regex=regex + ) + matchedKeywords.extend(second_matched) - else: - yield (record, None) + # Pass only if at least one keyword was matched. + if len(matchedKeywords) > 0: + yield ((firstCall, secondCall), matchedKeywords) def check_parameter( self, - parent_function, - first_method_list, - second_method_list, - keyword_item_list=None, - regex=False, - ): - """ - Check the usage of the same parameter between two method. + parent_method: MethodObject, + first_method_list: Sequence[MethodObject], + second_method_list: Sequence[MethodObject], + first_method_keywords: Sequence[Any] | None = None, + second_method_keywords: Sequence[Any] | None = None, + regex: bool = False, + ) -> bool: + """Check the usage of the same parameter between two method. - :param parent_function: function that call the first function and - second functions at the same time. - :param first_method_list: function which calls before the second - method. - :param second_method_list: function which calls after the first method. + :param parent_method: the method to do the check + :param first_method_list: a list of first API and methods that calls + the first API + :param second_method_list: a list of second API and methods that calls + the second API + :param first_method_keywords: keywords to match for the first method, + defaults to None + :param second_method_keywords: keywords to match for the second method, + defaults to None + :param regex: treat keywords as regular expressions, defaults to False :return: True or False """ - if parent_function is None: + if parent_method is None: raise TypeError("Parent function is None.") if first_method_list is None or second_method_list is None: raise TypeError("First or second method list is None.") - if keyword_item_list: - keyword_item_list = list(keyword_item_list) - if not any(keyword_item_list): - keyword_item_list = None - - state = False - # Evaluate the opcode in the parent function - usage_table = self._evaluate_method(parent_function) + usage_table = self._evaluate_method(parent_method) # Check if any of the target methods (the first and second methods) # used the same registers. state = False - for first_call_method in first_method_list: - for second_call_method in second_method_list: - result_generator = self.check_parameter_on_single_method( - usage_table, - first_call_method, - second_call_method, - keyword_item_list, - regex, - ) + for first_method, second_method in product( + first_method_list, second_method_list + ): + results = self.checkParameterOnSingleMethod( + usage_table, + first_method, + second_method, + first_method_keywords, + second_method_keywords, + regex, + ) - found = next(result_generator, None) is not None - - # Build for the call graph - if found: - call_graph_analysis = { - "parent": parent_function, - "first_call": first_call_method, - "second_call": second_call_method, - "apkinfo": self.apkinfo, - "first_api": self.quark_analysis.first_api, - "second_api": self.quark_analysis.second_api, - "crime": self.quark_analysis.crime_description, - } - self.quark_analysis.call_graph_analysis_list.append( - call_graph_analysis - ) + found = next(results, None) is not None - # Record the mapping between the parent function and the - # wrapper method - self.quark_analysis.parent_wrapper_mapping[ - parent_function.full_name - ] = self.apkinfo.get_wrapper_smali( - parent_function, - first_call_method, - second_call_method, - ) + if not found: + continue - state = True + # Build for the call graph + call_graph_analysis = { + "parent": parent_method, + "first_call": first_method, + "second_call": second_method, + "apkinfo": self.apkinfo, + "first_api": self.quark_analysis.first_api, + "second_api": self.quark_analysis.second_api, + "crime": self.quark_analysis.crime_description, + } + self.quark_analysis.call_graph_analysis_list.append( + call_graph_analysis + ) + + # Record the mapping between the parent function and the + # wrapper method + self.quark_analysis.parent_wrapper_mapping[ + parent_method.full_name + ] = self.apkinfo.get_wrapper_smali( + parent_method, + first_method, + second_method, + ) + + state = True return state @staticmethod - def check_parameter_values( - source_str, pattern_list, keyword_item_list, regex=False + def getMatchedKeywords( + methodCall: MethodCall, keywords: Sequence, regex: bool ) -> List[str]: - matched_string_set = set() + """Get matched keywords from the parameters of a method call. - parameter_strs = [ - tools.get_parenthetic_contents( - source_str, source_str.index(pattern) + len(pattern) - ) - for pattern in pattern_list - ] - - for parameter_str, keyword_item in zip( - parameter_strs, keyword_item_list - ): - if keyword_item is None: - continue + :param method_call: the method call to be checked + :param keywords: keywords to be matched + :param regex: whether to treat keywords as regular expressions + :yield: a list of matched keywords + """ + matchedStrSet = set() + primitiveValues = { + str(primitive.value) + for primitive in iteratePriorPrimitives(methodCall) + } - for keyword in keyword_item: - if regex: - matched_strings = re.findall(keyword, parameter_str) - if any(matched_strings): - matched_strings = filter(bool, matched_strings) - matched_strings = list(matched_strings) + if not regex: + return [ + value + for value in primitiveValues + if any(kw in value for kw in keywords) + ] - element = matched_strings[0] - if isinstance( - element, collections.abc.Sequence - ) and not isinstance(element, str): - for str_list in matched_strings: - matched_string_set.update(str_list) + for keyword in keywords: + regexPattern = re.compile(keyword) - else: - matched_string_set.update(matched_strings) + for value in primitiveValues: + matchedStrings = regexPattern.findall(value) + if not any(matchedStrings): + continue + # Filter out empty strings from tuples in the result + if isinstance(matchedStrings[0], tuple): + for matchTuple in matchedStrings: + matchedStrSet.update(filter(bool, matchTuple)) else: - if str(keyword) in parameter_str: - matched_string_set.add(keyword) + matchedStrSet.update(filter(bool, matchedStrings)) - return [e for e in list(matched_string_set) if bool(e)] + return [e for e in list(matchedStrSet) if bool(e)] def find_api_usage(self, class_name, method_name, descriptor_name): method_list = [] @@ -561,17 +564,13 @@ def run(self, rule_obj): parent_function ) - keyword_item_list = ( - rule_obj.api[i].get("match_keywords", None) - for i in range(2) - ) - # Level 5: Handling The Same Register Check if self.check_parameter( parent_function, first_wrapper, second_wrapper, - keyword_item_list=keyword_item_list, + rule_obj.firstApiKeywords, + rule_obj.secondApiKeywords, ): rule_obj.check_item[4] = True self.quark_analysis.level_5_result.append( @@ -857,6 +856,39 @@ def show_rule_classification(self): output_parent_function_json(data_bundle) output_parent_function_graph(data_bundle) + @staticmethod + def findMethodCallPairs( + usageTable: Mapping[int, Sequence[RegisterObject]], + firstMethodInfo: tuple[str, str, str], + secondMethodInfo: tuple[str, str, str], + ) -> Generator[tuple[MethodCall, MethodCall], None, None]: + + allRegisterValues = ( + registerValue + for register in usageTable.values() + for registerValue in register + ) + + secondMethodPattern = PyEval.get_method_pattern(*secondMethodInfo) + + secondAPICalls = ( + call + for registerValue in allRegisterValues + for call in registerValue.iterateInvolvedCalls() + if call.method == secondMethodPattern + ) + + firstMethodPattern = PyEval.get_method_pattern(*firstMethodInfo) + + matchedCallPairs = ( + (firstCall, secondCall) + for secondCall in secondAPICalls + for firstCall in iteratePriorCalls(secondCall) + if firstCall.method == firstMethodPattern + ) + + yield from matchedCallPairs + if __name__ == "__main__": pass diff --git a/quark/core/struct/registerobject.py b/quark/core/struct/registerobject.py index e6863254..f979ddb6 100644 --- a/quark/core/struct/registerobject.py +++ b/quark/core/struct/registerobject.py @@ -3,17 +3,30 @@ # See the file 'LICENSE' for copying permission. +from typing import Generator + +from quark.core.struct.valuenode import ( + MethodCall, + ValueNode, + iteratePriorCalls, +) + + class RegisterObject: """The RegisterObject is used to record the state of each register""" __slots__ = [ "_value", "_called_by_func", - "_current_type", - "_type_history", + "_current_type" ] - def __init__(self, value, called_by_func=None, value_type=None): + def __init__( + self, + value: ValueNode, + called_by_func: ValueNode | None = None, + value_type=None, + ): """ A data structure for creating the bytecode variable object, which used to record the state of each register. @@ -27,13 +40,16 @@ def __init__(self, value, called_by_func=None, value_type=None): """ self._value = value self._current_type = value_type - self._type_history = [] self._called_by_func = [] if called_by_func is not None: self._called_by_func.append(called_by_func) def __repr__(self): - return f"" + return ( + f"" + ) def __eq__(self, obj): return ( @@ -61,7 +77,6 @@ def called_by_func(self, called_by_func): :return: None """ self._called_by_func.append(called_by_func) - self._type_history.append(self._current_type) @property def value(self): @@ -96,18 +111,23 @@ def current_type(self): def current_type(self, value): self._current_type = value - @property - def type_histroy(self): - return self._type_history - def bears_object(self) -> bool: """ - Check whether the register bears an object. + Check whether the register bears an object or has an unknown type. - :return: True if the register bears an object, False otherwise + :return: True if the register holds an object or its type is unknown; + False otherwise. :rtype: bool """ - return self.current_type is not None and self.current_type.startswith("L") + return self.current_type is None or self.current_type.startswith("L") + + def iterateInvolvedCalls(self) -> Generator[MethodCall, None, None]: + """ + Yield all method calls involved by this register. + """ + for call in self._called_by_func: + yield from iteratePriorCalls(call) + if __name__ == "__main__": pass diff --git a/quark/core/struct/ruleobject.py b/quark/core/struct/ruleobject.py index 863489ea..ef65e01d 100644 --- a/quark/core/struct/ruleobject.py +++ b/quark/core/struct/ruleobject.py @@ -100,6 +100,26 @@ def score(self): :return: integer """ return self._score + + @property + def firstApiKeywords(self) -> list | None: + """ + The keywords to match for the first API in the rule. + + :return: list or None + """ + keywords = self._api[0].get("match_keywords", None) + return keywords if isinstance(keywords, list) and len(keywords) > 0 else None + + @property + def secondApiKeywords(self) -> list | None: + """ + The keywords to match for the second API in the rule. + + :return: list or None + """ + keywords = self._api[1].get("match_keywords", None) + return keywords if isinstance(keywords, list) and len(keywords) > 0 else None def get_score(self, confidence): """ diff --git a/quark/core/struct/tableobject.py b/quark/core/struct/tableobject.py index b51e360e..f179e23b 100644 --- a/quark/core/struct/tableobject.py +++ b/quark/core/struct/tableobject.py @@ -2,50 +2,41 @@ # This file is part of Quark-Engine - https://github.com/quark-engine/quark-engine # See the file 'LICENSE' for copying permission. +from collections import defaultdict +from quark.core.struct.registerobject import RegisterObject + class TableObject: """This table is used to track the usage of variables in the register""" __slots__ = ["hash_table"] - def __init__(self, count_reg): - """ - This table used to store the variable object, which uses a hash table - with a stack-based list to generate the bytecode variable tracker table. - - :param count_reg: the maximum number of register to initialize - """ - self.hash_table = [[] for _ in range(count_reg)] + def __init__(self): + self.hash_table = defaultdict(list) def __repr__(self): return f"" - def insert(self, index, var_obj): + def insert(self, index: int, registerValue: RegisterObject) -> None: """ - Insert VariableObject into the nested list in the hashtable. + Insert RegisterObject into the nested list in the hash table. :param index: the index to insert to the table - :param var_obj: instance of VariableObject + :param var_obj: instance of RegisterObject :return: None """ - try: - self.hash_table[index].append(var_obj) - except IndexError: - pass + self.hash_table[index].append(registerValue) - def get_obj_list(self, index): + def getRegValues(self, index: int) -> list[RegisterObject]: """ - Return the list which contains the VariableObject. + Return the list which contains the RegisterObject. - :param index: the index to get the corresponding VariableObject - :return: a list containing VariableObject + :param index: the index to get the corresponding RegisterObject + :return: a list containing RegisterObject """ - try: - return self.hash_table[index] - except IndexError: - return None + return self.hash_table[index] - def get_table(self): + def getTable(self) -> dict[int, list[RegisterObject]]: """ Get the entire hash table. @@ -53,13 +44,12 @@ def get_table(self): """ return self.hash_table - def pop(self, index): + def getLatestRegValue(self, index: int) -> RegisterObject: """ - Override the built-in pop function, to get the top element, which - is VariableObject on the stack while not delete it. + Get the latest RegisterObject for the given index. - :param index: the index to get the corresponding VariableObject - :return: VariableObject + :param index: the index to get the corresponding RegisterObject + :return: RegisterObject """ return self.hash_table[index][-1] diff --git a/quark/core/struct/valuenode.py b/quark/core/struct/valuenode.py new file mode 100644 index 00000000..550b16cb --- /dev/null +++ b/quark/core/struct/valuenode.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# This file is part of Quark-Engine - https://github.com/quark-engine/quark-engine +# See the file 'LICENSE' for copying permission. + +from abc import ABC, abstractmethod +import collections +from dataclasses import dataclass +from typing import Any, Generator, Type, TypeVar +from weakref import WeakValueDictionary + + +@dataclass() +class ValueNode(ABC): + """Abstract base class for value node.""" + + def resolve(self, evaluateArgs: bool = True) -> str: + """Resolve the value into a string representation. + + :param evaluateArgs: True to evaluate argument base on its type, + default to True + :return: a string representation of the value + """ + return iterativeResolve(self, evaluateArgs=evaluateArgs) + + @abstractmethod + def _getChildren(self) -> tuple["ValueNode", ...]: + """Get the child ValueNodes of this node. + + :return: a tuple of child ValueNodes + """ + pass + + @abstractmethod + def _assembleResolvedString( + self, childStrs: tuple[str, ...], evaluateArgs: bool + ) -> str: + """Assemble the resolved string from child strings. + + :param childStrs: a tuple of resolved child strings + :param evaluateArgs: True to evaluate argument base on its type, + default to True + :return: the assembled resolved string + """ + pass + + def __eq__(self, value: object) -> bool: + return self is value + + def __hash__(self): + return id(self) + + +@dataclass(slots=True, eq=False) +class Primitive(ValueNode): + """A ValueNode that wraps a primitive type (str, int, etc.).""" + + value: Any + value_type: str | None + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"Primitive({self.value!r})" + + def _getChildren(self) -> tuple["ValueNode", ...]: + return () + + def _assembleResolvedString(self, _, evaluateArgs: bool) -> str: + return ( + str(evaluateArgument(self.value, self.value_type)) + if evaluateArgs + else str(self.value) + ) + + +@dataclass(slots=True, eq=False) +class MethodCall(ValueNode): + """A ValueNode that represents a method call.""" + + method: str + argumentNodes: tuple[ValueNode, ...] + + def __str__(self): + return f"" + + def __repr__(self): + return f"MethodCall({self.method!r}, {self.argumentNodes!r})" + + def _getChildren(self) -> tuple["ValueNode", ...]: + return self.argumentNodes + + def _assembleResolvedString(self, argStrs: tuple[str, ...], _) -> str: + return f"{self.method}({','.join(argStrs)})" + + def getArguments(self, evaluateArgs: bool = True) -> list[Any]: + return [ + ( + evaluateArgument(rawArg.value, rawArg.value_type) + if evaluateArgs and isinstance(rawArg, Primitive) + else rawArg.resolve(evaluateArgs) + ) + for rawArg in self.argumentNodes + ] + + +@dataclass(slots=True, eq=False) +class BytecodeOps(ValueNode): + """A ValueNode that represents a bytecode operation (e.g., binop, cast).""" + + strFormat: str + operands: tuple[ValueNode, ...] + data: Any + + def __str__(self): + return f"" + + def __repr__(self): + return ( + f"BytecodeOps({self.strFormat!r}, {self.operands!r}, {self.data!r})" + ) + + def _getChildren(self) -> tuple[ValueNode, ...]: + return self.operands + + def _assembleResolvedString(self, operandStrs: tuple[str, ...], _) -> str: + value_dict = { + f"src{index}": value for index, value in enumerate(operandStrs) + } + value_dict["data"] = str(self.data) + return self.strFormat.format(**value_dict) + + +T = TypeVar("T", bound=ValueNode) + +__resolvedCache: WeakValueDictionary[int, "StringWrapper"] = ( + WeakValueDictionary() +) + + +@dataclass(frozen=True) +class StringWrapper: + value: str + + +def iterativeResolve(node: ValueNode, evaluateArgs: bool) -> str: + """Resolve the value node into a string representation. + + :param node: value node to resolve + :param evaluateArgs: True to evaluate argument base on its type + :return: a string representation of the value + """ + stack = [(node, [])] + visiting = {id(node)} + + while stack: + current, childStrs = stack[-1] + children = current._getChildren() + + if len(childStrs) < len(children): + # Still has children to process + child = children[len(childStrs)] + + cachedValue = __resolvedCache.get(id(child)) + if cachedValue is not None: + # Use cached resolved value + childStrs.append(cachedValue.value) + continue + + if id(child) in visiting: + childStrs.append("") + continue + + # Update current node to continue with next child later + visiting.add(id(child)) + stack.append((child, [])) + continue + + result = current._assembleResolvedString( + tuple(childStrs), evaluateArgs + ) + __resolvedCache[id(current)] = StringWrapper(result) + + # Current node is fully processed, pop from stack + visiting.remove(id(current)) + stack.pop() + + if not stack: + # No parent, this is the root node + return result + + # Append result to parent's list + _, parentProcessedChildren = stack[-1] + parentProcessedChildren.append(result) + + raise RuntimeError("Unreachable code reached in iterativeResolve") + + +def iteratePriorNodes( + node: ValueNode, nodeType: Type[T] +) -> Generator[T, None, None]: + """Yield all prior ValueNodes that contribute to the given ValueNode, + including itself. + + :param node: root node to start + :param nodeType: node type to yield + :yield: value nodes of given node types + """ + visited = set() + queue = collections.deque([node]) + + while queue: + node = queue.popleft() + if id(node) in visited: + continue + visited.add(id(node)) + + if isinstance(node, nodeType): + yield node + + match node: + case MethodCall(): + queue.extend(node.argumentNodes) + case BytecodeOps(): + queue.extend(node.operands) + + +def iteratePriorCalls( + methodCall: MethodCall, +) -> Generator[MethodCall, None, None]: + """Yield all prior calls that supply arguments to the given method call, + including itself. + + :param methodCall: root method call to iterate + :yield: method calls that supply arguments to the given method call + """ + yield from iteratePriorNodes(methodCall, nodeType=MethodCall) + + +def iteratePriorPrimitives( + valueNode: ValueNode, +) -> Generator[Primitive, None, None]: + """Yield all prior Primitive nodes that contribute to the given ValueNode. + + :param valueNode: root node to iterate + :yield: primitives that contribute to the given node + """ + yield from iteratePriorNodes(valueNode, nodeType=Primitive) + + +def evaluateArgument( + argument: str, typeHint: str | None +) -> int | float | bool | str: + """Evaluate the argument based on the given type hint. + If the type hint is missing or None, no evaluation is performed. + + :param argument: argument to be evaluated + :param typeHint: type hint suggesting how the argument should be evaluated + :return: evaluated argument + """ + try: + if typeHint in ["I", "B", "S", "J"]: + return int(argument) + elif typeHint == "Z": + return bool(int(argument)) + elif typeHint in ["F", "D"]: + return float(argument) + except ValueError: + pass + + return argument diff --git a/quark/evaluator/pyeval.py b/quark/evaluator/pyeval.py index 6c6bc320..94655f59 100644 --- a/quark/evaluator/pyeval.py +++ b/quark/evaluator/pyeval.py @@ -10,9 +10,11 @@ from quark import config from quark.core.struct.registerobject import RegisterObject from quark.core.struct.tableobject import TableObject +from quark.core.struct.valuenode import ( + Primitive, MethodCall, BytecodeOps +) from quark.utils.logger import defaultHandler -MAX_REG_COUNT = 40 log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) log.addHandler(defaultHandler) @@ -126,7 +128,7 @@ def __init__(self, apkinfo): self.eval[ "fill-array-data" ] = lambda ins: self._move_value_and_data_to_register( - (ins[0], ins[1], ins[1], ins[2]), "Embedded-array-data()[" + (ins[0], ins[1], ins[1], ins[2]), "Embedded-array-data()[]" ) self.type_mapping = { @@ -140,21 +142,27 @@ def __init__(self, apkinfo): "double": "D", } - self.table_obj = TableObject(MAX_REG_COUNT) + self.table_obj = TableObject() self.ret_stack = [] self.ret_type = "" self.apkinfo = apkinfo def _invoke(self, instruction, look_up=False, skip_self=False): """ - Function call in Android smali code. It will check if the corresponding table field has a value, if it does, + Function call in Android smali code. It will check if the corresponding + table field has a value, if it does, inserts its own function name into called_by_func column. """ - if look_up: + opcode, *regList, targetMethod = instruction + regIdxList = [int(r[1:]) for r in regList] + + if look_up and len(regIdxList) > 0: try: - instruction[-1] = self._lookup_implement( - self.table_obj.pop(int(instruction[1][1:])).current_type, + targetMethod = self._lookup_implement( + self.table_obj.getLatestRegValue( + regIdxList[0] + ).current_type, instruction[-1], skip_self=skip_self, ) @@ -163,51 +171,70 @@ def _invoke(self, instruction, look_up=False, skip_self=False): except IndexError: pass - executed_fuc = instruction[-1] - reg_list = instruction[1 : len(instruction) - 1] - value_of_reg_list = [] - + valueOfRegList = [] # query the value from hash table based on register index. - for reg in reg_list: - index = int(reg[1:]) - obj_stack = self.table_obj.get_obj_list(index) - if obj_stack: - var_obj = self.table_obj.pop(index) - value_of_reg_list.append(var_obj.value) - - # Remove duplicate parameter values to save memory - seen = {} - for idx, val in enumerate(value_of_reg_list): - if val in seen: - value_of_reg_list[idx] = f"" - else: - seen[val] = idx + for index in regIdxList: + if not self.table_obj.getRegValues(index): + # Insert a RegisterObject if one is missing. + # Therefore, we can trace the usage of this register. + self.table_obj.insert( + index, RegisterObject(Primitive("", None)) + ) - invoked_state = f"{executed_fuc}({','.join(value_of_reg_list)})" + value = self.table_obj.getLatestRegValue(index) + valueOfRegList.append(value.value) + + # Check whether any argument is missing a value type. + argIdxWithoutType = [ + idx + for idx, arg in enumerate(valueOfRegList) + if isinstance(arg, Primitive) and arg.value_type == "" + ] + if len(argIdxWithoutType) > 0: + # Set the missing value types based on the method's descriptor. + argTypes = ( + [] + if opcode.startswith("invoke-static") + else [targetMethod[: targetMethod.find("->")]] + ) - # insert the function and the parameter into called_by_func - for reg in reg_list: - index = int(reg[1:]) + rawArgTypes = targetMethod[ + targetMethod.find("(") + 1 : targetMethod.find(")") + ].split(" ") + + for argType in rawArgTypes: + argTypes.append(argType) + if argType in ["J", "D"]: + # Put long and double twice + # because these types take up two registers. + argTypes.append(argType) + + for argIdx in argIdxWithoutType: + valueOfRegList[argIdx].value_type = argTypes[argIdx] - if not self.table_obj.get_obj_list(index): - continue + methodCall = MethodCall(targetMethod, tuple(valueOfRegList)) + # insert the function and the parameter into called_by_func + for index in regIdxList: # add the function name into each parameter table - var_obj = self.table_obj.pop(index) - var_obj.called_by_func = invoked_state + value = self.table_obj.getLatestRegValue(index) + value.called_by_func = methodCall - if var_obj.bears_object(): + if ( + value.bears_object() + and value.current_type != "Ljava/lang/String;" + ): # If the register bears an object, update its value to reflect # the method invocation since the method may modify the # internal state of the object. - var_obj.value = invoked_state + value.value = methodCall - if not executed_fuc.endswith(")V"): + if not targetMethod.endswith(")V"): # push the return value into ret_stack - self.ret_stack.append(invoked_state) + self.ret_stack.append(methodCall) # Extract the type of return value - self.ret_type = executed_fuc[executed_fuc.index(")") + 1:] + self.ret_type = targetMethod[targetMethod.index(")") + 1 :] def _move_result(self, instruction): @@ -230,7 +257,7 @@ def _move_object(self, src_reg_idx: int, dest_reg_idx: int): RegisterObject. This allow both registers to point to the same object. """ # Get the source object from the table - src_obj = self.table_obj.pop(src_reg_idx) + src_obj = self.table_obj.getLatestRegValue(src_reg_idx) # Insert the source object to the destination register. self.table_obj.insert(dest_reg_idx, src_obj) @@ -241,7 +268,9 @@ def _assign_value(self, instruction, value_type=""): value = instruction[2] index = int(reg[1:]) - variable_object = RegisterObject(value=value, value_type=value_type) + wrapped_value = Primitive(value, value_type) + + variable_object = RegisterObject(value=wrapped_value, value_type=value_type) self.table_obj.insert(index, variable_object) def _assign_value_wide(self, instruction, value_type=""): @@ -252,8 +281,10 @@ def _assign_value_wide(self, instruction, value_type=""): value = instruction[2] index = int(reg[1:]) - variable_object = RegisterObject(value=value, value_type=value_type) - variable_object2 = RegisterObject(value=value, value_type=value_type) + wrapped_value = Primitive(value, value_type) + + variable_object = RegisterObject(value=wrapped_value, value_type=value_type) + variable_object2 = RegisterObject(value=wrapped_value, value_type=value_type) self.table_obj.insert(index, variable_object) self.table_obj.insert(index + 1, variable_object2) @@ -502,7 +533,9 @@ def AGET_KIND(self, instruction): value_type = self.type_mapping[instruction[0][index:]] else: array_reg_index = int(instruction[2][1:]) - value_type = self.table_obj.pop(array_reg_index).current_type + value_type = self.table_obj.getLatestRegValue( + array_reg_index + ).current_type # If value_type is not None if value_type: value_type = value_type[1:] @@ -543,7 +576,10 @@ def AGET_WIDE_KIND(self, instruction): try: - value_type = self.table_obj.pop(array_reg_index).current_type[1:] + array_reg = self.table_obj.getLatestRegValue(array_reg_index) + value_type = ( + array_reg.current_type[1:] if array_reg.current_type else None + ) destination = int(instruction[1][1:]) source_list = [int(reg[1:]) for reg in instruction[2:]] @@ -672,8 +708,8 @@ def BINOP_KIND(self, instruction): except IndexError as e: log.exception(f"{e} in BINOP_KIND") - def show_table(self): - return self.table_obj.get_table() + def show_table(self) -> dict[int, list[RegisterObject]]: + return self.table_obj.getTable() def _move_value_to_register( self, instruction, str_format, wide=False, value_type=None @@ -695,7 +731,15 @@ def _move_value_to_register( value_type=value_type, ) - def _lookup_implement(self, instance_type, method_full_name, skip_self=False): + def _lookup_implement( + self, + instance_type: str | None, + method_full_name: str, + skip_self: bool = False, + ): + if instance_type is None or len(instance_type) == 0: + return method_full_name + class_name, signature = method_full_name.split("->") index = signature.index("(") method_name, descriptor = signature[:index], signature[index:] @@ -742,14 +786,11 @@ def _move_value_and_data_to_register( data = instruction[-1] for source in source_list: - if not self.table_obj.get_obj_list(source): - value_dict = { - f"src{0}": instruction - } - value_dict["data"] = data - + if not self.table_obj.getRegValues(source): + # A source register used by the instruction is not initialized, + # create a RegisterObject for it. new_register = RegisterObject( - value=str_format.format(**value_dict), + value=Primitive("", value_type), value_type=value_type, ) self.table_obj.insert(source, new_register) @@ -785,21 +826,26 @@ def _transfer_register( self, source_list, destination, str_format, data=None, value_type=None ): try: - source_register_list = [self.table_obj.pop(index) for index in source_list] + source_register_list = [ + self.table_obj.getLatestRegValue(index) + for index in source_list + ] except IndexError: return + # If it's a simple move, preserve the value node. + if len(source_register_list) == 1 and str_format == "{src0}": + new_value = source_register_list[0].value + else: + # For complex operations, create a value node for the operation. + source_values = [reg.value for reg in source_register_list] + new_value = BytecodeOps(str_format, tuple(source_values), data) + if not value_type: value_type = source_register_list[0].current_type - value_dict = { - f"src{index}": register.value - for index, register in enumerate(source_register_list) - } - value_dict["data"] = data - new_register = RegisterObject( - value=str_format.format(**value_dict), + value=new_value, value_type=value_type, ) diff --git a/quark/script/__init__.py b/quark/script/__init__.py index 27387824..aec0b3cf 100644 --- a/quark/script/__init__.py +++ b/quark/script/__init__.py @@ -15,13 +15,9 @@ from quark.core.quark import Quark from quark.core.struct.methodobject import MethodObject from quark.core.struct.ruleobject import RuleObject as Rule +from quark.core.struct.valuenode import iteratePriorCalls from quark.evaluator.pyeval import PyEval from quark.utils.regex import URL_REGEX -from quark.utils.tools import ( - get_arguments_from_argument_str, - get_parenthetic_contents, -) - @functools.lru_cache def _getQuark(apk: PathLike) -> Quark: @@ -76,7 +72,6 @@ def isDebuggable(self) -> bool: :return: True/False """ debuggable = self._getAttribute("debuggable") - print(debuggable) if debuggable is None: return False @@ -226,58 +221,54 @@ def getArguments(self) -> List[Any]: self.innerObj ) - register_usage_records = ( - c_func - for table in usageTable - for val_obj in table - for c_func in val_obj.called_by_func + methodCalls = ( + methodCall + for register in usageTable.values() + for register_value in register + for methodCall in register_value.iterateInvolvedCalls() ) methodPattern = PyEval.get_method_pattern( self.targetMethod.innerObj.class_name, self.targetMethod.innerObj.name, - self.targetMethod.innerObj.descriptor + self.targetMethod.innerObj.descriptor, ) - matchedRecords = list(filter( - lambda record: methodPattern in record, - register_usage_records)) - - argumentStr = max(matchedRecords, key=len, default="")[:-1] - filterStr = f"{self.targetMethod.innerObj.class_name}->" + \ - self.targetMethod.innerObj.name + \ - self.targetMethod.descriptor - - argumentStr = argumentStr.replace(filterStr, "")[1:] - - return get_arguments_from_argument_str( - argumentStr, self.targetMethod.innerObj.descriptor + matchedCall = next( + ( + priorCall + for call in methodCalls + for priorCall in iteratePriorCalls(call) + if priorCall.method == methodPattern + ), + None, ) - allResult = self.behavior.hasString(".*", True) - argumentStr = max(allResult, key=len)[1:-1] - - argumentsOfSecondAPI = get_arguments_from_argument_str( - argumentStr, self.descriptor) + return matchedCall.getArguments() if matchedCall else [] - if self == self.behavior.secondAPI: - return argumentsOfSecondAPI - else: - methodPattern = PyEval.get_method_pattern( - self.className, self.methodName, self.descriptor - ) - - argumentsOfFirstAPI = ( - get_parenthetic_contents( - argument, argument.find(methodPattern) - ) - for argument in argumentsOfSecondAPI - if methodPattern in argument - ) + firstAPI, secondAPI = self.behavior.firstAPI, self.behavior.secondAPI + firstAPICall, secondAPICall = next( + Quark.findMethodCallPairs( + self.quarkResult.quark._evaluate_method( + self.behavior.methodCaller.innerObj + ), + ( + firstAPI.className, + firstAPI.methodName, + firstAPI.descriptor), + ( + secondAPI.className, + secondAPI.methodName, + secondAPI.descriptor, + ), + ), + (None, None), + ) - return get_arguments_from_argument_str( - next(argumentsOfFirstAPI, ""), self.descriptor - ) + apiCall = ( + firstAPICall if self == self.behavior.firstAPI else secondAPICall + ) + return apiCall.getArguments() if apiCall else [] def findSuperclassHierarchy(self) -> List[str]: """Find all superclasses of this method object. @@ -359,11 +350,12 @@ def hasString(self, pattern: str, isRegex=False) -> List[str]: ) result_generator = ( - self.quarkResult.quark.check_parameter_on_single_method( - usage_table=usageTable, - first_method=self.firstAPI.innerObj, - second_method=self.secondAPI.innerObj, - keyword_item_list=[(pattern,), (pattern,)], + self.quarkResult.quark.checkParameterOnSingleMethod( + usageTable=usageTable, + firstMethod=self.firstAPI.innerObj, + secondMethod=self.secondAPI.innerObj, + firstMethodKeywords=(pattern,), + secondMethodKeywords=(pattern,), regex=isRegex, ) ) @@ -388,48 +380,18 @@ def getParamValues(self) -> List[Any]: :return: python list containing parameter values """ - def __getArgumentFromMethodCall(method_call_str: str): - - # Extract the part after the method name - # e.g. 'La/String;->init(II)V;('ab)_',3)' extracts 'V;('ab)_',3)' - method_start_idx = method_call_str.find("(") - method_with_args = method_call_str[method_start_idx + 1:] - method_end_idx = method_with_args.find(")") - method_with_args = method_with_args[method_end_idx + 1:] - - # Extract and split the arguments - # e.g. 'V;('ab)_',3)' extracts 'ab)_' and '3' - args_start_idx = method_with_args.find("(") - args_with_parentheses = method_with_args[args_start_idx + 1:] - - args_end_idx = args_with_parentheses.rfind(")") - - args_str = args_with_parentheses[:args_end_idx] - extracted_arguments = args_str.split(",") - - return extracted_arguments - - allResult = self.hasString(".*", True) - argumentStr = max(allResult, key=len)[1:-1] - - arguments = get_arguments_from_argument_str( - argumentStr, self.secondAPI.descriptor) - new_arguments = [] - - for argument in arguments: - if not isinstance(argument, str): - new_arguments.append(argument) - continue - - # Extract the arguments from method call and remove class arguments - if ";->" in argument: - method_call = argument.split(";->")[-1] - new_args = __getArgumentFromMethodCall(method_call) - new_arguments.extend(new_args) - elif not (argument.startswith("L") and argument.endswith(";")): - new_arguments.append(argument) - - return new_arguments + usageTable = self.quarkResult.quark._evaluate_method(self.methodCaller.innerObj) + + firstAPI, secondAPI = self.firstAPI, self.secondAPI + _, secondAPI = next( + Quark.findMethodCallPairs( + usageTable, + (firstAPI.className, firstAPI.methodName, firstAPI.descriptor), + (secondAPI.className, secondAPI.methodName, secondAPI.descriptor), + ) + ) + + return secondAPI.getArguments() def isArgFromMethod(self, targetMethod: List[str]) -> bool: """Check if there are any argument from the target method. @@ -438,40 +400,60 @@ def isArgFromMethod(self, targetMethod: List[str]) -> bool: descriptor of target method :return: True/False """ - className, methodName, descriptor = targetMethod - - pattern = PyEval.get_method_pattern(className, methodName, descriptor) + usageTable = self.quarkResult.quark._evaluate_method( + self.methodCaller.innerObj + ) - return bool(self.hasString(pattern)) + firstAPI, secondAPI = self.firstAPI, self.secondAPI + apiPairs = Quark.findMethodCallPairs( + usageTable, + (firstAPI.className, firstAPI.methodName, firstAPI.descriptor), + (secondAPI.className, secondAPI.methodName, secondAPI.descriptor), + ) + + methodCallsInArgs = ( + call + for _, secondApiCall in apiPairs + for call in iteratePriorCalls(secondApiCall) + ) + + targetMethodPattern = PyEval.get_method_pattern(*targetMethod) + return any( + methodCall.method == targetMethodPattern + for methodCall in methodCallsInArgs + ) def getMethodsInArgs(self) -> List[str]: """Get the methods which the arguments in API2 has passed through. :return: python list containing method instances """ - METHOD_REGEX = r"L(.*?)\;\(" methodCalled = [] + usageTable = self.quarkResult.quark._evaluate_method(self.methodCaller.innerObj) + + firstAPI, secondAPI = self.firstAPI, self.secondAPI + _, secondAPICall = next( + Quark.findMethodCallPairs( + usageTable, + (firstAPI.className, firstAPI.methodName, firstAPI.descriptor), + (secondAPI.className, secondAPI.methodName, secondAPI.descriptor), + ) + ) - allResult = self.hasString(".*", True) - argumentStr = max(allResult, key=len)[1:-1] - - arguments = get_arguments_from_argument_str( - argumentStr, self.secondAPI.descriptor) - - for param in arguments: - for result in re.findall(METHOD_REGEX, param): - className = "L" + result.split("->")[0] - methodName = re.findall(r"->(.*?)\(", result)[0] - descriptor = result.split(methodName)[-1] + ";" + for methodCall in iteratePriorCalls(secondAPICall): + result = methodCall.method + className = result.split("->")[0] + methodName = re.findall(r"->(.*?)\(", result)[0] + descriptor = result.split(methodName)[-1] - methodObj_list = self.quarkResult.quark.apkinfo.find_method( - class_name=className, - method_name=methodName, - descriptor=descriptor - ) + methodObj_list = self.quarkResult.quark.apkinfo.find_method( + class_name=className, + method_name=methodName, + descriptor=descriptor + ) - for methodObj in methodObj_list: - methodCalled.append(Method(methodObj=methodObj)) + for methodObj in methodObj_list: + methodCalled.append(Method(methodObj=methodObj)) return methodCalled @@ -497,13 +479,16 @@ def behaviorOccurList(self): Behavior( quarkResultInstance=self, methodCaller=self._wrapMethodObject( - call_graph_analysis["parent"] + methodObj=call_graph_analysis["parent"], + quark=self.quark ), firstAPI=self._wrapMethodObject( - call_graph_analysis["first_call"] + call_graph_analysis["first_call"], + quark=self.quark ), secondAPI=self._wrapMethodObject( - call_graph_analysis["second_call"] + call_graph_analysis["second_call"], + quark=self.quark ), ) for call_graph_analysis in self.innerObj.call_graph_analysis_list @@ -528,13 +513,7 @@ def getMethodXrefFrom(self, method: Method) -> List[Method]: return [self._wrapMethodObject(caller) for caller in list(caller_set)] def _wrapMethodObject(self, methodObj: MethodObject, quark: Quark = None, targetMethod: Method = None) -> Method: - if methodObj: - if targetMethod: - return Method(self, methodObj=methodObj, quark=quark, targetMethod=targetMethod) - else: - return Method(self, methodObj) - else: - return None + return Method(self, methodObj=methodObj, quark=quark, targetMethod=targetMethod) def getAllStrings(self) -> List[str]: """ diff --git a/quark/utils/tools.py b/quark/utils/tools.py index 02e7493b..0fd1fce0 100644 --- a/quark/utils/tools.py +++ b/quark/utils/tools.py @@ -107,78 +107,3 @@ def filter_api_by_usage_count(data, api_pool, percentile_rank=0.2): S_set.append(str_statistic_result[api]) return P_set, S_set - - -def get_parenthetic_contents(string: str, start_index: int) -> str: - """Get the content between a pair of parentheses. - - :param string: string to be parsed - :param start_index: index to specify the parenthesis - :return: string holding the content - """ - start_index = string.find("(", start_index) - if start_index == -1: - return string - - parenthetic_counter = 0 - for idx, char in enumerate(string[start_index:]): - if char == "(": - parenthetic_counter += 1 - elif char == ")": - parenthetic_counter -= 1 - - if parenthetic_counter == 0: - end_index = idx + start_index + 1 - return string[start_index:end_index] - - return string[start_index:] - - -def get_arguments_from_argument_str( - argument_str: str, descriptor: str -) -> List[Any]: - """Get arguments from an argument string. - - :param argument_str: string that holds multiple arguments and uses commas - as separators - :param descriptor: string that holds a descriptor for type inference - :return: python list that holds the arguments - """ - - def __valueOf(argument: str, type_hint: str) -> Union[int, float, bool]: - try: - if type_hint in ["I", "B", "S", "J"]: - return int(argument) - elif type_hint == "Z": - return bool(int(argument)) - elif type_hint in ["F", "D"]: - return float(argument) - except ValueError: - pass - - return argument - - arguments = [] - - parentheses_counter = 0 - index_of_last_separator = 0 - for index, char in enumerate(argument_str): - if char == "(": - parentheses_counter += 1 - elif char == ")": - parentheses_counter -= 1 - elif char == "," and parentheses_counter == 0: - arguments.append(argument_str[index_of_last_separator:index]) - index_of_last_separator = index + 1 - - arguments.append(argument_str[index_of_last_separator:]) - - type_hints = descriptor[1: descriptor.find(")")].split() - type_hints = reversed(type_hints) - arguments = [ - __valueOf(argument, next(type_hints, "")) - for argument in reversed(arguments) - ] - arguments.reverse() - - return arguments diff --git a/tests/core/struct/test_registerobject.py b/tests/core/struct/test_registerobject.py index 05d6f07e..2bd33a75 100644 --- a/tests/core/struct/test_registerobject.py +++ b/tests/core/struct/test_registerobject.py @@ -1,6 +1,7 @@ import pytest from quark.core.struct.registerobject import RegisterObject +from quark.core.struct.valuenode import MethodCall, Primitive @pytest.fixture() @@ -46,4 +47,40 @@ def test_bears_object(self): assert reg_with_object.bears_object() is True assert reg_with_primitive.bears_object() is False - assert reg_with_none.bears_object() is False + assert reg_with_none.bears_object() is True + + def test_iterate_involved_calls_returns_nested_calls(self): + inner_call = MethodCall( + "Lfoo/Bar;->first()V", (Primitive("alpha", "Ljava/lang/String;"),) + ) + outer_call = MethodCall( + "Lfoo/Bar;->second()V", (inner_call, Primitive("beta", "I")) + ) + register_obj = RegisterObject("value", outer_call) + + calls = list(register_obj.iterateInvolvedCalls()) + + assert calls == [outer_call, inner_call] + + def test_iterate_involved_calls_with_multiple_sources(self): + shared_inner_call = MethodCall( + "Lfoo/Bar;->shared()V", (Primitive("1", "I"),) + ) + first_call = MethodCall( + "Lfoo/Bar;->first()V", (shared_inner_call,) + ) + second_call = MethodCall( + "Lfoo/Bar;->second()V", (Primitive("value", "Ljava/lang/String;"),) + ) + + register_obj = RegisterObject("value", first_call) + register_obj.called_by_func = second_call + + calls = list(register_obj.iterateInvolvedCalls()) + + assert calls == [first_call, shared_inner_call, second_call] + + def test_iterate_involved_calls_empty_when_no_history(self): + register_obj = RegisterObject("value") + + assert list(register_obj.iterateInvolvedCalls()) == [] diff --git a/tests/core/struct/test_ruleobject.py b/tests/core/struct/test_ruleobject.py index a726b043..aab0d41e 100644 --- a/tests/core/struct/test_ruleobject.py +++ b/tests/core/struct/test_ruleobject.py @@ -5,6 +5,39 @@ from quark.core.struct.ruleobject import RuleObject +def build_rule_json(first_keywords=None, second_keywords=None): + api_entries = [ + { + "class": "Landroid/telephony/TelephonyManager", + "method": "getDeviceId", + "descriptor": "()Ljava/lang/String;", + }, + { + "class": "Landroid/telephony/SmsManager", + "method": "sendTextMessage", + "descriptor": "()V", + }, + ] + + if first_keywords is not None: + api_entries[0]["match_keywords"] = first_keywords + + if second_keywords is not None: + api_entries[1]["match_keywords"] = second_keywords + + return { + "crime": "Send Location via SMS", + "permission": [ + "android.permission.SEND_SMS", + "android.permission.ACCESS_COARSE_LOCATION", + "android.permission.ACCESS_FINE_LOCATION", + ], + "api": api_entries, + "score": 4, + "label": ["location", "collection"], + } + + @pytest.fixture() def rule_obj(scope="function"): rule_json = """ @@ -129,3 +162,37 @@ def test_androguard_format_api(rule_obj): @staticmethod def test_java_format_api(rule_obj): assert rule_obj.api[1]["descriptor"] == "(I Ljava/lang/String; [B J)V" + + @staticmethod + def test_first_api_keywords_returns_list(): + rule = RuleObject( + "dummy.json", + jsonData=build_rule_json(first_keywords=["imei", "iccid"]), + ) + + assert rule.firstApiKeywords == ["imei", "iccid"] + + @staticmethod + def test_first_api_keywords_returns_none_when_missing(): + rule = RuleObject( + "dummy.json", jsonData=build_rule_json(first_keywords=[]) + ) + + assert rule.firstApiKeywords is None + + @staticmethod + def test_second_api_keywords_returns_list(): + rule = RuleObject( + "dummy.json", + jsonData=build_rule_json(second_keywords=["sms", "send"]), + ) + + assert rule.secondApiKeywords == ["sms", "send"] + + @staticmethod + def test_second_api_keywords_returns_none_when_not_a_list(): + rule = RuleObject( + "dummy.json", jsonData=build_rule_json(second_keywords="sms") + ) + + assert rule.secondApiKeywords is None diff --git a/tests/core/struct/test_tableobject.py b/tests/core/struct/test_tableobject.py index 4f183cd1..c49b249c 100644 --- a/tests/core/struct/test_tableobject.py +++ b/tests/core/struct/test_tableobject.py @@ -1,11 +1,12 @@ +from typing import Generator import pytest from quark.core.struct.tableobject import TableObject @pytest.fixture() -def table_obj(): - table_obj = TableObject(5) +def table_obj() -> Generator[TableObject, None, None]: + table_obj = TableObject() yield table_obj @@ -13,65 +14,38 @@ def table_obj(): class TestTableObject: - def test_init_with_no_arg(self): - with pytest.raises(TypeError): - _ = TableObject() - - def test_init_with_non_numeric(self): - with pytest.raises(TypeError): - _ = TableObject(None) - - def test_init_with_valid_arg(self): - table_obj = TableObject(5) - - assert isinstance(table_obj, TableObject) - assert len(table_obj.hash_table) == 5 - assert table_obj.hash_table == [[], [], [], [], []] - - def test_insert_with_non_numeric(self, table_obj): - with pytest.raises(TypeError): - table_obj.insert(None) - def test_insert_with_number_once(self, table_obj): index, data = 1, "Value" table_obj.insert(index, data) - assert table_obj.hash_table[index] == [data] + assert table_obj.getRegValues(index) == [data] def test_insert_with_number_twice(self, table_obj): table_obj.insert(0, "first") table_obj.insert(0, "second") - assert table_obj.hash_table[0] == ["first", "second"] - - def test_insert_with_num_beyond_max(self, table_obj): - index, data = 6, "Max value" - - try: - table_obj.insert(index, data) - except Exception: - pytest.fail("Should not raise exceptions.") + assert table_obj.getRegValues(0) == ["first", "second"] - def test_get_obj_list_before_insertion(self, table_obj): - assert table_obj.get_obj_list(3) == [] + def test_getRegValues_before_insertion(self, table_obj): + assert table_obj.getRegValues(3) == [] - def test_get_obj_list_after_insertion(self, table_obj): + def test_getRegValues_after_insertion(self, table_obj): table_obj.insert(3, "test_value") - assert table_obj.get_obj_list(3) == ["test_value"] + assert table_obj.getRegValues(3) == ["test_value"] - def test_get_table(self, table_obj): - assert table_obj.hash_table == table_obj.get_table() + def test_getTable(self, table_obj): + assert table_obj.hash_table == table_obj.getTable() - def test_pop_none(self, table_obj): + def test_getLatestRegValue_none(self, table_obj): with pytest.raises(IndexError): - _ = table_obj.pop(1) + _ = table_obj.getLatestRegValue(1) - def test_pop_value(self, table_obj): + def test_getLatestRegValue_value(self, table_obj): table_obj.insert(4, "one") table_obj.insert(4, "two") table_obj.insert(4, "three") - assert table_obj.pop(4) == "three" - assert table_obj.get_obj_list(4) == ["one", "two", "three"] + assert table_obj.getLatestRegValue(4) == "three" + assert table_obj.getRegValues(4) == ["one", "two", "three"] diff --git a/tests/core/struct/test_valuenode.py b/tests/core/struct/test_valuenode.py new file mode 100644 index 00000000..ea85cbc2 --- /dev/null +++ b/tests/core/struct/test_valuenode.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# This file is part of Quark-Engine - https://github.com/quark-engine/quark-engine +# See the file 'LICENSE' for copying permission. +import pytest +from quark.core.struct.valuenode import ( + Primitive, + MethodCall, + BytecodeOps, + iteratePriorPrimitives, + iteratePriorCalls, + evaluateArgument, +) + + +class TestPrimitive: + def test_resolve_number_and_string(self): + assert Primitive(123, "I").resolve() == "123" + assert Primitive("test", None).resolve() == "test" + + def test_eq_is_identity_based(self): + value = Primitive(True, "Z") + assert value == value + assert Primitive(True, "Z") != Primitive(True, "Z") + + +class TestMethodCall: + def test_resolve_simple(self): + mc = MethodCall( + "do_something", (Primitive("first", None), Primitive(2, "I")) + ) + assert mc.resolve() == "do_something(first,2)" + + def test_resolve_nested(self): + inner_call = MethodCall("inner", (Primitive(True, "Z"),)) + outer_call = MethodCall("outer", (Primitive(1, "I"), inner_call)) + assert outer_call.resolve() == "outer(1,inner(True))" + + def test_getArguments_resolves_primitives_and_calls(self): + nested = MethodCall("inner", (Primitive("text", None),)) + method_call = MethodCall("outer", (Primitive("10", "I"), nested)) + assert method_call.getArguments() == [10, "inner(text)"] + + +class TestBytecodeOps: + def test_resolve_simple(self): + op = BytecodeOps("const-string {data}", (), "Hello") + assert op.resolve() == "const-string Hello" + op_none = BytecodeOps("const {data}", (), None) + assert op_none.resolve() == "const None" + op_add = BytecodeOps( + "add-int({src0}, {src1})", + (Primitive(5, "I"), Primitive(10, "I")), + None, + ) + assert op_add.resolve() == "add-int(5, 10)" + + def test_resolve_nested(self): + inner_op = BytecodeOps("cast({src0})", (Primitive(1.0, "F"),), "int") + outer_call = MethodCall("use_val", (inner_op,)) + assert outer_call.resolve() == "use_val(cast(1.0))" + + +class TestIterators: + @pytest.fixture + def complex_structure(self): + """ + Creates a complex nested structure for testing iterators. + Structure: + call1 -> [prim1, call2] + call2 -> [prim2, op1] + op1 -> [prim3] + """ + prim1 = Primitive("p1", None) + prim2 = Primitive(2, "I") + prim3 = Primitive(True, "Z") + op1 = BytecodeOps("op({src0})", (prim3,), None) + call2 = MethodCall("func2", (prim2, op1)) + call1 = MethodCall("func1", (prim1, call2)) + return call1, call2, op1, prim1, prim2, prim3 + + def test_iteratePriorCalls(self, complex_structure): + call1, call2, op1, _, _, _ = complex_structure + calls = list(iteratePriorCalls(call1)) + assert len(calls) == 2 + assert call1 in calls + assert call2 in calls + assert op1 not in calls + + def test_iteratePriorPrimitives(self, complex_structure): + call1, _, _, prim1, prim2, prim3 = complex_structure + primitives = list(iteratePriorPrimitives(call1)) + assert primitives.count(prim1) == 1 + assert primitives.count(prim2) == 1 + assert primitives.count(prim3) == 1 + + def test_iteratePriorCalls_deduplicates_reused_nodes(self): + shared_call = MethodCall("shared", (Primitive("x", None),)) + outer = MethodCall("outer", (shared_call, shared_call)) + calls = list(iteratePriorCalls(outer)) + assert calls.count(shared_call) == 1 + + def test_iterativeResolve_direct_recursion(self): + mc = MethodCall("recursive", ()) + # Manually create a cycle + mc.argumentNodes = (mc,) + assert mc.resolve() == "recursive()" + + def test_iterativeResolve_indirect_recursion(self): + inner = MethodCall("inner", ()) + outer = MethodCall("outer", (inner,)) + # Create cycle: outer -> inner -> outer + inner.argumentNodes = (outer,) + assert outer.resolve() == "outer(inner())" + + def test_iterativeResolve_diamond_dependency(self): + # Ensure shared nodes in a DAG are not treated as recursion + # root -> b1 -> leaf + # root -> b2 -> leaf + leaf = Primitive("leaf", None) + branch1 = MethodCall("b1", (leaf,)) + branch2 = MethodCall("b2", (leaf,)) + root = MethodCall("root", (branch1, branch2)) + assert root.resolve() == "root(b1(leaf),b2(leaf))" + + +@pytest.mark.parametrize( + "value,type_hint,expected", + [ + ("42", "I", 42), + ("1", "Z", True), + ("1.5", "F", 1.5), + ("not-a-number", "I", "not-a-number"), + ("plain", None, "plain"), + ], +) +def test_evaluateArgument_converts_values(value, type_hint, expected): + assert evaluateArgument(value, type_hint) == expected diff --git a/tests/core/test_quark.py b/tests/core/test_quark.py index 0abedb88..a1f03289 100644 --- a/tests/core/test_quark.py +++ b/tests/core/test_quark.py @@ -8,6 +8,9 @@ from quark.core.quark import Quark from quark.core.struct.ruleobject import RuleObject +from quark.core.struct.registerobject import RegisterObject +from quark.core.struct.valuenode import MethodCall, Primitive +from quark.evaluator.pyeval import PyEval APK_SOURCE = ( "https://github.com/quark-engine/apk-samples" @@ -32,7 +35,7 @@ @pytest.fixture(scope="function") -def simple_quark_obj(): +def simple_quark_obj() -> Quark: r = requests.get(APK_SOURCE, allow_redirects=True) open(APK_FILENAME, "wb").write(r.content) @@ -386,75 +389,61 @@ def test_check_parameter_is_False(self, quark_obj): assert result is False - def test_check_parameter_values_with_no_keyword_rule( + def test_getMatchedKeywords_with_no_keyword_rule( self, simple_quark_obj, rule_without_keyword ): rule_object = RuleObject(rule_without_keyword) - with patch("quark.core.quark.Quark.check_parameter_values") as mock: + with patch("quark.core.quark.Quark.getMatchedKeywords") as mock: simple_quark_obj.run(rule_object) mock.assert_not_called() - def test_check_parameter_values_with_one_keyword_rule( + def test_getMatchedKeywords_with_one_keyword_rule( self, simple_quark_obj_2, rule_with_one_keyword ): rule_object = RuleObject(rule_with_one_keyword) - with patch("quark.core.quark.Quark.check_parameter_values") as mock: + with patch("quark.core.quark.Quark.getMatchedKeywords") as mock: simple_quark_obj_2.run(rule_object) mock.assert_called() - def test_check_parameter_values_without_matched_str(self, simple_quark_obj): - source_str = ( - "Landroid/content/ContentResolver;->query(Landroid/net/Uri;" - " [Ljava/lang/String; Ljava/lang/String; [Ljava/lang/String;" - " Ljava/lang/String;)Landroid/database/Cursor;" - "(Landroid/content/Context;" - "->getContentResolver()Landroid/content/ContentResolver;" - "(Lahmyth/mine/king/ahmyth/MainService;->getContextOfApplication()" - "Landroid/content/Context;()),Landroid/net/Uri;" - "->parse(Ljava/lang/String;)" - "Landroid/net/Uri;(file://usr/bin/su),v0,v0,v0,v0)" - ) - pattern_list = ( + def test_getMatchedKeywords_without_matched_str(self, simple_quark_obj: Quark): + method_call = MethodCall( + "Landroid/content/ContentResolver;->query" + "(Landroid/net/Uri;)Landroid/database/Cursor;", ( - "Landroid/content/ContentResolver;->query(Landroid/net/Uri;" - " [Ljava/lang/String; Ljava/lang/String; [Ljava/lang/String;" - " Ljava/lang/String;)Landroid/database/Cursor;" - ), + MethodCall( + "Landroid/net/Uri;->parse(Ljava/lang/String;)", + (Primitive("URL_FOR_TEST", "Ljava/lang/String;"),) + ), + Primitive(1, "I") + ) ) - keyword_item_list = [("content://call_log/calls",)] - result = simple_quark_obj.check_parameter_values( - source_str, pattern_list, keyword_item_list + result = simple_quark_obj.getMatchedKeywords( + method_call, ("KW_NOT_MATCHING",), False ) - assert bool(result) is False - - def test_check_parameter_values_with_matched_str(self, simple_quark_obj): - source_str = ( - "Landroid/database/Cursor;->getColumnIndex" - "(Ljava/lang/String;)I(Landroid/content/ContentResolver;" - "->query(Landroid/net/Uri; [Ljava/lang/String; Ljava/lang/String;" - " [Ljava/lang/String; Ljava/lang/String;)Landroid/database/Cursor;" - "(Landroid/content/Context;->getContentResolver()" - "Landroid/content/ContentResolver;(Lahmyth/mine/king/ahmyth" - "/MainService;->getContextOfApplication()Landroid/content/" - "Context;()),Landroid/net/Uri;->parse(Ljava/lang/String;)" - "Landroid/net/Uri;(content://call_log/calls),v0,v0,v0,v0),number)" - ) - pattern_list = ( - "Landroid/content/ContentResolver;->query(Landroid/net/Uri;" - " [Ljava/lang/String; Ljava/lang/String; [Ljava/lang/String;" - " Ljava/lang/String;)Landroid/database/Cursor;", + assert result == [] + + def test_getMatchedKeywords_with_matched_str(self, simple_quark_obj): + method_call = MethodCall( + "Landroid/content/ContentResolver;->query" + "(Landroid/net/Uri;)Landroid/database/Cursor;", + ( + MethodCall( + "Landroid/net/Uri;->parse(Ljava/lang/String;)", + (Primitive("content://call_log/calls", "Ljava/lang/String;"),) + ), + Primitive(1, "I") + ) ) - keyword_item_list = [("content://call_log/calls",)] - result = simple_quark_obj.check_parameter_values( - source_str, pattern_list, keyword_item_list + result = simple_quark_obj.getMatchedKeywords( + method_call, ("content://call_log/calls",), False ) - assert bool(result) is True + assert result == ["content://call_log/calls"] def test_get_json_report(self, quark_obj): json_report = quark_obj.get_json_report() @@ -543,3 +532,83 @@ def testLabelReportWithDetailedTable( assert (quark_obj.quark_analysis .label_report_table.rows[0]) == correctTableRow + + @staticmethod + def _registerWithCall(call: MethodCall) -> RegisterObject: + register = RegisterObject(value=Primitive("0", "I")) + register.called_by_func = call + return register + + def testReturnsSinglePairWhenFirstCallFeedsSecond(self): + firstMethodInfo = ("Lfoo;", "first", "()V") + secondMethodInfo = ("Lfoo;", "second", "()V") + + firstCall = MethodCall( + method=PyEval.get_method_pattern(*firstMethodInfo), + argumentNodes=(Primitive("input", "Ljava/lang/String;"),), + ) + + secondCall = MethodCall( + method=PyEval.get_method_pattern(*secondMethodInfo), + argumentNodes=(firstCall,), + ) + + usageTable = {0: [self._registerWithCall(secondCall)]} + + result = list( + Quark.findMethodCallPairs( + usageTable, firstMethodInfo, secondMethodInfo + ) + ) + + assert result == [(firstCall, secondCall)] + + def testReturnsMultiplePairsWhenSecondUsesMultipleFirstCalls(self): + firstMethodInfo = ("Lfoo;", "first", "()V") + secondMethodInfo = ("Lfoo;", "second", "()V") + + firstCallOne = MethodCall( + method=PyEval.get_method_pattern(*firstMethodInfo), + argumentNodes=(Primitive("alpha", "Ljava/lang/String;"),), + ) + firstCallTwo = MethodCall( + method=PyEval.get_method_pattern(*firstMethodInfo), + argumentNodes=(Primitive("beta", "Ljava/lang/String;"),), + ) + + secondCall = MethodCall( + method=PyEval.get_method_pattern(*secondMethodInfo), + argumentNodes=(firstCallOne, firstCallTwo), + ) + + usageTable = {0: [self._registerWithCall(secondCall)]} + + result = list( + Quark.findMethodCallPairs( + usageTable, firstMethodInfo, secondMethodInfo + ) + ) + + assert result == [ + (firstCallOne, secondCall), + (firstCallTwo, secondCall), + ] + + def testReturnsEmptyWhenSecondCallLacksMatchingPriorCall(self): + firstMethodInfo = ("Lfoo;", "first", "()V") + secondMethodInfo = ("Lfoo;", "second", "()V") + + second_call = MethodCall( + method=PyEval.get_method_pattern(*secondMethodInfo), + argumentNodes=(Primitive("gamma", "Ljava/lang/String;"),), + ) + + usageTable = {0: [self._registerWithCall(second_call)]} + + result = list( + Quark.findMethodCallPairs( + usageTable, firstMethodInfo, secondMethodInfo + ) + ) + + assert result == [] diff --git a/tests/evaluator/test_pyeval.py b/tests/evaluator/test_pyeval.py index 27ba77fc..edfca455 100644 --- a/tests/evaluator/test_pyeval.py +++ b/tests/evaluator/test_pyeval.py @@ -2,12 +2,12 @@ from unittest.mock import patch import pytest -import requests from quark.core.apkinfo import AndroguardImp from quark.core.struct.registerobject import RegisterObject from quark.core.struct.tableobject import TableObject -from quark.evaluator.pyeval import MAX_REG_COUNT, PyEval +from quark.evaluator.pyeval import PyEval +from quark.core.struct.valuenode import Primitive, MethodCall @pytest.fixture() @@ -37,28 +37,41 @@ def apkinfo(SAMPLE_PATH_13667): def pyeval(apkinfo): pyeval = PyEval(apkinfo) - # mock_hash_table = [...[], [v4_mock_variable_obj], [], [], - # [v9_mock_variable_obj]....] v4_mock_variable_obj = RegisterObject( - "Lcom/google/progress/SMSHelper;", + value=Primitive( + "Lcom/google/progress/SMSHelper;", "Lcom/google/progress/SMSHelper;" + ), value_type="Lcom/google/progress/SMSHelper;", ) v5_mock_variable_obj = RegisterObject( - "some_number", "java.lang.String.toString()", value_type="I" + value=MethodCall( + "java.lang.String.toString", (Primitive("some_number", ""),) + ), + value_type="I", ) v6_mock_variable_obj = RegisterObject( - "an_array", "java.lang.Collection.toArray()", value_type="[I" + value=MethodCall( + "java.lang.Collection.toArray", (Primitive("an_array", ""),) + ), + value_type="[I", + ) + v7_mock_variable_obj = RegisterObject( + value=Primitive("a_float", "F"), value_type="F" ) - v7_mock_variable_obj = RegisterObject("a_float", value_type="F") v8_mock_variable_obj = RegisterObject( - "ArrayMap object", + value=Primitive( + "ArrayMap object", "Landroid/support/v4/util/ArrayMap;" + ), value_type="Landroid/support/v4/util/ArrayMap;", ) v9_mock_variable_obj = RegisterObject( - "some_string", - "java.io.file.close()", + value=Primitive("some_string", "Ljava/lang/String;"), value_type="Ljava/lang/String;", ) + v9_mock_variable_obj.called_by_func.append( + MethodCall("java.io.file.close", tuple()) + ) + pyeval.table_obj.insert(4, v4_mock_variable_obj) pyeval.table_obj.insert(5, v5_mock_variable_obj) pyeval.table_obj.insert(6, v6_mock_variable_obj) @@ -257,7 +270,6 @@ class TestPyEval: def test_init(self, apkinfo): pyeval = PyEval(apkinfo) - assert len(pyeval.table_obj.hash_table) == MAX_REG_COUNT assert isinstance(pyeval.table_obj, TableObject) assert pyeval.ret_stack == [] @@ -271,7 +283,7 @@ def test_invoke_with_non_list_object(self, pyeval): def test_invoke_with_empty_list(self, pyeval): instruction = [] - with pytest.raises(IndexError): + with pytest.raises(ValueError): pyeval._invoke(instruction) def test_invoke_with_wrong_types(self, pyeval): @@ -286,36 +298,93 @@ def test_invoke_with_invalid_value(self, pyeval): with pytest.raises(ValueError): pyeval._invoke(instruction) + def test_invoke_fills_missing_types_for_instance_calls(self, pyeval): + instance_idx = 10 + arg_idx = 11 + + instance_value = Primitive("instance", "") + arg_value = Primitive("number", "") + + pyeval.table_obj.insert( + instance_idx, RegisterObject(instance_value, value_type=None) + ) + pyeval.table_obj.insert( + arg_idx, RegisterObject(arg_value, value_type=None) + ) + + instruction = [ + "invoke-virtual", + f"v{instance_idx}", + f"v{arg_idx}", + "Lcom/example/Worker;->run(I)Ljava/lang/String;", + ] + + pyeval._invoke(instruction) + + assert instance_value.value_type == "Lcom/example/Worker;" + assert arg_value.value_type == "I" + + def test_invoke_fills_missing_types_for_static_calls(self, pyeval): + first_arg_idx = 12 + second_arg_idx = 13 + + first_arg_value = Primitive("threshold", "") + second_arg_value = Primitive("payload", "") + + pyeval.table_obj.insert( + first_arg_idx, RegisterObject(first_arg_value, value_type=None) + ) + pyeval.table_obj.insert( + second_arg_idx, RegisterObject(second_arg_value, value_type=None) + ) + + instruction = [ + "invoke-static", + f"v{first_arg_idx}", + f"v{second_arg_idx}", + ( + "Lcom/example/Helpers;" + "->mix(I Ljava/lang/String;)V" + ), + ] + + pyeval._invoke(instruction) + + assert first_arg_value.value_type == "I" + assert second_arg_value.value_type == "Ljava/lang/String;" + def test_invoke_with_func_returning_value(self, pyeval): instruction = ["invoke-kind", "v4", "v9", "some_function()Lclass;"] pyeval._invoke(instruction) - assert pyeval.table_obj.pop(4).called_by_func == [ - "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)" - ] - assert pyeval.table_obj.pop(9).called_by_func == [ - "java.io.file.close()", - "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)", - ] - assert pyeval.ret_stack == [ - "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)" - ] + v4 = pyeval.table_obj.getLatestRegValue(4) + v9 = pyeval.table_obj.getLatestRegValue(9) + + assert len(v4.called_by_func) == 1 + assert v4.called_by_func[0].resolve() == "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)" + + assert len(v9.called_by_func) == 2 + assert v9.called_by_func[1].resolve() == "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)" + + assert len(pyeval.ret_stack) == 1 + assert pyeval.ret_stack[0].resolve() == "some_function()Lclass;(Lcom/google/progress/SMSHelper;,some_string)" assert pyeval.ret_type == "Lclass;" - @pytest.mark.skip(reason="discussion needed.") def test_invoke_with_func_not_returning_value(self, pyeval): instruction = ["invoke-kind", "v4", "v9", "some_function()V"] pyeval._invoke(instruction) - assert pyeval.table_obj.pop(4).called_by_func == [ - "some_function()V(Lcom/google/progress/SMSHelper;,some_string)" - ] - assert pyeval.table_obj.pop(9).called_by_func == [ - "java.io.file.close()", - "some_function()V(Lcom/google/progress/SMSHelper;,some_string)", - ] + v4 = pyeval.table_obj.getLatestRegValue(4) + v9 = pyeval.table_obj.getLatestRegValue(9) + + assert len(v4.called_by_func) == 1 + assert v4.called_by_func[0].resolve() == "some_function()V(Lcom/google/progress/SMSHelper;,some_string)" + + assert len(v9.called_by_func) == 2 + assert v9.called_by_func[1].resolve() == "some_function()V(Lcom/google/progress/SMSHelper;,some_string)" + assert pyeval.ret_stack == [] def test_invoke_without_registers(self, pyeval): @@ -323,10 +392,12 @@ def test_invoke_without_registers(self, pyeval): pyeval._invoke(instruction) - assert pyeval.table_obj.pop(9).called_by_func == [ - "java.io.file.close()" - ] - assert pyeval.ret_stack == ["some-func()Lclass;()"] + v9 = pyeval.table_obj.getLatestRegValue(9) + + assert len(v9.called_by_func) == 1 + assert v9.called_by_func[0].resolve() == "java.io.file.close()" + assert len(pyeval.ret_stack) == 1 + assert pyeval.ret_stack[0].resolve() == "some-func()Lclass;()" # Tests for invoke_virtual def test_invoke_virtual_with_valid_mnemonic(self, pyeval): @@ -336,7 +407,7 @@ def test_invoke_virtual_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -353,12 +424,8 @@ def test_invoke_virtual_with_class_inheritance(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.ret_stack == [ - ( - "Landroid/support/v4/util/SimpleArrayMap;" - "->isEmpty()Z(ArrayMap object)" - ) - ] + assert len(pyeval.ret_stack) == 1 + assert pyeval.ret_stack[0].resolve() == "Landroid/support/v4/util/SimpleArrayMap;->isEmpty()Z(ArrayMap object)" assert pyeval.ret_type == "Z" def test_invoke_virtual_range_with_valid_mnemonic(self, pyeval): @@ -368,7 +435,7 @@ def test_invoke_virtual_range_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -384,7 +451,7 @@ def test_invoke_direct_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -399,7 +466,7 @@ def test_invoke_direct_range_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -415,7 +482,7 @@ def test_invoke_static_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -430,7 +497,7 @@ def test_invoke_static_range_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -446,7 +513,7 @@ def test_invoke_interface_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -463,12 +530,8 @@ def test_invoke_interface_with_class_inheritance(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.ret_stack == [ - ( - "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" - ) - ] + assert len(pyeval.ret_stack) == 1 + assert pyeval.ret_stack[0].resolve() == "Landroid/support/v4/util/ArrayMap;->entrySet()Ljava/util/Set;(ArrayMap object)" assert pyeval.ret_type == "Ljava/util/Set;" def test_invoke_interface_range_with_valid_mnemonic(self, pyeval): @@ -478,7 +541,7 @@ def test_invoke_interface_range_with_valid_mnemonic(self, pyeval): "v9", ( "Landroid/support/v4/util/ArrayMap;" - "->entrySet()Ljava/util/Set;(ArrayMap object)" + "->entrySet()Ljava/util/Set;" ), ] @@ -505,12 +568,8 @@ def test_invoke_super_with_class_inheritance(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.ret_stack == [ - ( - "Landroid/support/v4/util/SimpleArrayMap;" - "->toString()Ljava/lang/String;(ArrayMap object)" - ) - ] + assert len(pyeval.ret_stack) == 1 + assert pyeval.ret_stack[0].resolve() == "Landroid/support/v4/util/SimpleArrayMap;->toString()Ljava/lang/String;(ArrayMap object)" assert pyeval.ret_type == "Ljava/lang/String;" def test_invoke_super_range_with_valid_mnemonic(self, pyeval): @@ -583,8 +642,12 @@ def test_move_with_invalid_instrcution(self, pyeval): def test_move_with_valid_instrcution(self, pyeval): instruction = ["move-result-object", "v1"] - expected_return_value = ( - "some_function()V(used_register_1, used_register_2)" + expected_return_value = MethodCall( + "some_function()V", + ( + Primitive("used_register_1", ""), + Primitive("used_register_2", ""), + ), ) expected_return_type = "Lclass;" pyeval.ret_stack.append(expected_return_value) @@ -592,8 +655,9 @@ def test_move_with_valid_instrcution(self, pyeval): pyeval._move_result(instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - expected_return_value, None, value_type=expected_return_type + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "some_function()V(used_register_1,used_register_2)" ) # Tests for move_result @@ -607,13 +671,13 @@ def test_move_result_with_valid_mnemonic(self, pyeval): # Tests for move_result_wide def test_move_result_wide_with_valid_mnemonic(self, pyeval): instruction = ["move-result-wide", "v1"] - return_value = "Return Value" - pyeval.ret_stack.append("Return Value") + return_value = Primitive("Return Value", "") + pyeval.ret_stack.append(return_value) pyeval.MOVE_RESULT_WIDE(instruction) - assert pyeval.table_obj.pop(1).value == return_value - assert pyeval.table_obj.pop(2).value == return_value + assert pyeval.table_obj.getLatestRegValue(1).value is return_value + assert pyeval.table_obj.getLatestRegValue(2).value is return_value # Tests for move_result_object def test_move_result_object_with_valid_mnemonic(self, pyeval): @@ -635,19 +699,20 @@ def test_new_instance(self, pyeval): pyeval.NEW_INSTANCE(instruction) - assert pyeval.table_obj.pop(3) == RegisterObject( - "Lcom/google/progress/SMSHelper;", - value_type="Lcom/google/progress/SMSHelper;", + assert ( + pyeval.table_obj.getLatestRegValue(3).value.resolve() + == "Lcom/google/progress/SMSHelper;" ) - assert pyeval.table_obj.pop(4) == RegisterObject( - "Lcom/google/progress/SMSHelper;", - value_type="Lcom/google/progress/SMSHelper;", + assert ( + pyeval.table_obj.getLatestRegValue(4).value.resolve() + == "Lcom/google/progress/SMSHelper;" ) pyeval.NEW_INSTANCE(override_original_instruction) - assert pyeval.table_obj.pop(4) == RegisterObject( - "Ljava/lang/Object;", value_type="Ljava/lang/Object;" + assert ( + pyeval.table_obj.getLatestRegValue(4).value.resolve() + == "Ljava/lang/Object;" ) # Tests for const_string @@ -660,9 +725,9 @@ def test_const_string(self, pyeval): pyeval.CONST_STRING(instruction) - assert pyeval.table_obj.pop(8) == RegisterObject( - "https://github.com/quark-engine/quark-engine", - value_type="Ljava/lang/String;", + assert ( + pyeval.table_obj.getLatestRegValue(8).value.resolve() + == "https://github.com/quark-engine/quark-engine" ) def test_const_string_jumbo(self, pyeval): @@ -674,9 +739,9 @@ def test_const_string_jumbo(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(8) == RegisterObject( - "https://github.com/quark-engine/quark-engine", - value_type="Ljava/lang/String;", + assert ( + pyeval.table_obj.getLatestRegValue(8).value.resolve() + == "https://github.com/quark-engine/quark-engine" ) def test_const_class(self, pyeval): @@ -688,9 +753,9 @@ def test_const_class(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(8) == RegisterObject( - "Landroid/telephony/SmsMessage;", - value_type="Ljava/lang/Class;", + assert ( + pyeval.table_obj.getLatestRegValue(8).value.resolve() + == "Landroid/telephony/SmsMessage;" ) # Tests for const @@ -735,9 +800,9 @@ def test_move_kind(self, pyeval, move_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "Lcom/google/progress/SMSHelper;", - value_type="Lcom/google/progress/SMSHelper;", + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "Lcom/google/progress/SMSHelper;" ) def test_move_object(self, pyeval): @@ -745,19 +810,22 @@ def test_move_object(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert id(pyeval.table_obj.pop(1)) == id(pyeval.table_obj.pop(4)) + assert id(pyeval.table_obj.getLatestRegValue(1)) == id( + pyeval.table_obj.getLatestRegValue(4) + ) def test_move_wide_kind(self, pyeval, move_wide_kind): instruction = [move_wide_kind, "v1", "v4"] pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "Lcom/google/progress/SMSHelper;", - value_type="Lcom/google/progress/SMSHelper;", + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "Lcom/google/progress/SMSHelper;" ) - assert pyeval.table_obj.pop(2) == RegisterObject( - "some_number", value_type="I" + assert ( + pyeval.table_obj.getLatestRegValue(2).value.resolve() + == "java.lang.String.toString(some_number)" ) def test_new_array(self, pyeval): @@ -765,9 +833,9 @@ def test_new_array(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "new-array()[(some_number)", - value_type="[java/lang/String;", + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "new-array()[(java.lang.String.toString(some_number))" ) def test_filled_array_kind_with_class_type( @@ -777,7 +845,7 @@ def test_filled_array_kind_with_class_type( pyeval.eval[instruction[0]](instruction) - assert pyeval.ret_stack == ["new-array()[type_idx()"] + assert pyeval.ret_stack[0].resolve() == "new-array()[type_idx()" assert pyeval.ret_type == "[type_idx" def test_filled_array_kind_with_primitive_type( @@ -787,17 +855,19 @@ def test_filled_array_kind_with_primitive_type( pyeval.eval[instruction[0]](instruction) - assert pyeval.ret_stack == ["new-array()[I()"] + assert pyeval.ret_stack[0].resolve() == "new-array()[I()" assert pyeval.ret_type == "[I" # Tests for aget-kind def test_aget_kind(self, pyeval, aget_kind): v2_mock_variable_obj = RegisterObject( - "some_list_like[1,2,3,4]", - "java.io.file.close()", + value=Primitive("some_list_like[1,2,3,4]", "[Ljava/lang/Integer;"), + called_by_func=MethodCall("java.io.file.close", tuple()), value_type="[Ljava/lang/Integer;", ) - v3_mock_variable_obj = RegisterObject("2", None, value_type="I") + v3_mock_variable_obj = RegisterObject( + value=Primitive("2", "I"), value_type="I" + ) pyeval.table_obj.insert(2, v2_mock_variable_obj) pyeval.table_obj.insert(3, v3_mock_variable_obj) @@ -815,8 +885,13 @@ def test_aget_kind(self, pyeval, aget_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "some_list_like[1,2,3,4][2]", value_type=expected_value_type + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "some_list_like[1,2,3,4][2]" + ) + assert ( + pyeval.table_obj.getLatestRegValue(1).current_type + == expected_value_type ) def test_aget_wide_kind(self, pyeval, aget_wide_kind): @@ -824,8 +899,9 @@ def test_aget_wide_kind(self, pyeval, aget_wide_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "an_array[some_number]", value_type="I" + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "java.lang.Collection.toArray(an_array)[java.lang.String.toString(some_number)]" ) # Tests for aput-kind @@ -834,23 +910,22 @@ def test_aput_kind(self, pyeval, aput_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(6) == RegisterObject( - "an_array[some_number]:Lcom/google/progress/SMSHelper;", - value_type="[I", + assert ( + pyeval.table_obj.getLatestRegValue(6).value.resolve() + == "java.lang.Collection.toArray(an_array)[java.lang.String.toString(some_number)]:Lcom/google/progress/SMSHelper;" ) + assert pyeval.table_obj.getLatestRegValue(6).current_type == "[I" def test_aput_wide_kind(self, pyeval, aput_wide_kind): instruction = [aput_wide_kind, "v4", "v6", "v5"] pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(6) == RegisterObject( - ( - "an_array[some_number]:" - "(Lcom/google/progress/SMSHelper;, some_number)" - ), - value_type="[I", + assert ( + pyeval.table_obj.getLatestRegValue(6).value.resolve() + == "java.lang.Collection.toArray(an_array)[java.lang.String.toString(some_number)]:(Lcom/google/progress/SMSHelper;, java.lang.String.toString(some_number))" ) + assert pyeval.table_obj.getLatestRegValue(6).current_type == "[I" # Tests for neg-kind and not-kind def test_neg_and_not_kind(self, pyeval, neg_not_kind): @@ -858,20 +933,24 @@ def test_neg_and_not_kind(self, pyeval, neg_not_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "some_number", value_type="I" + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "java.lang.String.toString(some_number)" ) + assert pyeval.table_obj.getLatestRegValue(1).current_type == "I" def test_neg_and_not_wide_kind(self, pyeval, neg_not_wide_kind): instruction = [neg_not_wide_kind, "v1", "v5"] pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "some_number", value_type="I" + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == "java.lang.String.toString(some_number)" ) - assert pyeval.table_obj.pop(2) == RegisterObject( - "an_array", value_type="[I" + assert ( + pyeval.table_obj.getLatestRegValue(2).value.resolve() + == "java.lang.Collection.toArray(an_array)" ) # Tests for type-casting @@ -883,10 +962,9 @@ def test_type_casting_without_wide_type(self, pyeval, cast_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "casting(some_number)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + assert v1.value.resolve() == "casting(java.lang.String.toString(some_number))" + assert v1.current_type == pyeval.type_mapping[postfix] def test_type_casting_with_wide_type_to_simple_type( self, pyeval, cast_wide_to_simple_kind @@ -898,10 +976,10 @@ def test_type_casting_with_wide_type_to_simple_type( pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "casting(some_number, an_array)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + + assert v1.value.resolve() == "casting(java.lang.String.toString(some_number), java.lang.Collection.toArray(an_array))" + assert v1.current_type == pyeval.type_mapping[postfix] def test_type_casting_with_simple_type_to_wide_type( self, pyeval, cast_simple_to_wide_kind @@ -913,14 +991,14 @@ def test_type_casting_with_simple_type_to_wide_type( pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "casting(some_number)", - value_type=pyeval.type_mapping[postfix], - ) - assert pyeval.table_obj.pop(2) == RegisterObject( - "casting(some_number)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + v2 = pyeval.table_obj.getLatestRegValue(2) + + assert v1.value.resolve() == "casting(java.lang.String.toString(some_number))" + assert v1.current_type == pyeval.type_mapping[postfix] + + assert v2.value.resolve() == "casting(java.lang.String.toString(some_number))" + assert v2.current_type == pyeval.type_mapping[postfix] # Tests for binop-kind def test_simple_binop_kind(self, pyeval, simple_binop_kind): @@ -931,10 +1009,10 @@ def test_simple_binop_kind(self, pyeval, simple_binop_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "binop(some_number, an_array)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + + assert v1.value.resolve() == "binop(java.lang.String.toString(some_number), java.lang.Collection.toArray(an_array))" + assert v1.current_type == pyeval.type_mapping[postfix] def test_binop_kind_with_wide_type(self, pyeval, binop_wide_kind): instruction = [binop_wide_kind, "v1", "v4", "v6"] @@ -944,14 +1022,14 @@ def test_binop_kind_with_wide_type(self, pyeval, binop_wide_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "binop(Lcom/google/progress/SMSHelper;, an_array)", - value_type=pyeval.type_mapping[postfix], - ) - assert pyeval.table_obj.pop(2) == RegisterObject( - "binop(some_number, a_float)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + v2 = pyeval.table_obj.getLatestRegValue(2) + + assert v1.value.resolve() == "binop(Lcom/google/progress/SMSHelper;, java.lang.Collection.toArray(an_array))" + assert v1.current_type == pyeval.type_mapping[postfix] + + assert v2.value.resolve() == "binop(java.lang.String.toString(some_number), a_float)" + assert v2.current_type == pyeval.type_mapping[postfix] def test_binop_kind_in_place(self, pyeval, binop_2addr_kind): instruction = [binop_2addr_kind, "v4", "v6"] @@ -962,10 +1040,10 @@ def test_binop_kind_in_place(self, pyeval, binop_2addr_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(4) == RegisterObject( - "binop(Lcom/google/progress/SMSHelper;, an_array)", - value_type=pyeval.type_mapping[postfix], - ) + v4 = pyeval.table_obj.getLatestRegValue(4) + + assert v4.value.resolve() == "binop(Lcom/google/progress/SMSHelper;, java.lang.Collection.toArray(an_array))" + assert v4.current_type == pyeval.type_mapping[postfix] def test_binop_kind_with_literal(self, pyeval, binop_lit_kind): instruction = [binop_lit_kind, "v1", "v5", "literal_number"] @@ -976,10 +1054,10 @@ def test_binop_kind_with_literal(self, pyeval, binop_lit_kind): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "binop(some_number, literal_number)", - value_type=pyeval.type_mapping[postfix], - ) + v1 = pyeval.table_obj.getLatestRegValue(1) + + assert v1.value.resolve() == "binop(java.lang.String.toString(some_number), literal_number)" + assert v1.current_type == pyeval.type_mapping[postfix] # Tests for move-exception def test_move_exception(self, pyeval): @@ -987,8 +1065,8 @@ def test_move_exception(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - "Exception", value_type="Ljava/lang/Throwable;" + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() == "Exception" ) # Tests for fill-array-data @@ -997,8 +1075,9 @@ def test_fill_array_data(self, pyeval): pyeval.eval[instruction[0]](instruction) - assert pyeval.table_obj.pop(6) == RegisterObject( - "Embedded-array-data()[", value_type="[I" + assert ( + pyeval.table_obj.getLatestRegValue(6).value.resolve() + == "Embedded-array-data()[]" ) def test_show_table(self, pyeval): @@ -1011,13 +1090,15 @@ def test_show_table(self, pyeval): def test_invoke_and_move(self, pyeval): v6_mock_variable_obj = RegisterObject( - "some_string", None, value_type="Ljava/lang/String;" + value=Primitive("some_string", "Ljava/lang/String;"), + value_type="Ljava/lang/String;", ) pyeval.table_obj.insert(6, v6_mock_variable_obj) - assert pyeval.table_obj.pop(6) == RegisterObject( - "some_string", value_type="Ljava/lang/String;" + assert ( + pyeval.table_obj.getLatestRegValue(6).value.resolve() + == "some_string" ) first_instruction = [ @@ -1031,14 +1112,77 @@ def test_invoke_and_move(self, pyeval): pyeval.INVOKE_VIRTUAL(first_instruction) pyeval.MOVE_RESULT_OBJECT(second_instruction) - assert pyeval.table_obj.pop(1) == RegisterObject( - ( + assert ( + pyeval.table_obj.getLatestRegValue(1).value.resolve() + == ( "Lcom/google/progress/ContactsCollector;" "->getContactList()Ljava/lang/String;(some_string)" - ), - value_type="Ljava/lang/String;", + ) + ) + + @pytest.mark.parametrize("instance_type", [None, ""]) + def test_lookup_implement_returns_original_signature_when_instance_missing( + self, instance_type, pyeval + ): + method_full_name = "Lcom/example/Worker;->run()V" + + assert ( + pyeval._lookup_implement(instance_type, method_full_name) + == method_full_name ) + def test_lookup_implement_returns_method_from_instance_class( + self, pyeval, apkinfo + ): + method = next(iter(apkinfo.custom_methods)) + method_full_name = PyEval.get_method_pattern( + method.class_name, method.name, method.descriptor + ) + + resolved = pyeval._lookup_implement(method.class_name, method_full_name) + + assert resolved == method_full_name + + def test_lookup_implement_walks_superclasses_when_skip_self( + self, pyeval, apkinfo + ): + subclass = ( + "Landroid/support/v4/app/ActionBarDrawerToggle$SlideDrawable;" + ) + parent_method = apkinfo.find_method( + "Landroid/graphics/drawable/InsetDrawable;", + "", + "(Landroid/graphics/drawable/Drawable; I)V", + )[0] + + alias_signature = PyEval.get_method_pattern( + "Lquark/Interface;", parent_method.name, parent_method.descriptor + ) + expected_signature = PyEval.get_method_pattern( + parent_method.class_name, + parent_method.name, + parent_method.descriptor, + ) + + resolved = pyeval._lookup_implement( + subclass, alias_signature, skip_self=True + ) + + assert resolved == expected_signature + + def test_lookup_implement_raises_when_method_not_found(self, pyeval): + instance_type = next( + cls for cls in pyeval.apkinfo.superclass_relationships if cls + ) + missing_signature = PyEval.get_method_pattern( + instance_type, "__quark_missing__", "()V" + ) + + with pytest.raises(ValueError) as excinfo: + pyeval._lookup_implement(instance_type, missing_signature) + + assert "Instance type" in str(excinfo.value) + @staticmethod def test_get_method_pattern(): class_name = "Lcom/google/progress/ContactsCollector;" diff --git a/tests/script/test_script.py b/tests/script/test_script.py index 119ea5b3..3750764b 100644 --- a/tests/script/test_script.py +++ b/tests/script/test_script.py @@ -362,7 +362,7 @@ def testGetMethodsInArgs(QUARK_ANALYSIS_RESULT_FOR_RULE_193): QUARK_ANALYSIS_RESULT_FOR_RULE_193.behaviorOccurList ) behavior = behaviorOccurList[0] - method = behavior.getMethodsInArgs()[0].fullName + method = behavior.getMethodsInArgs()[1].fullName assert method == "Landroid/telephony/SmsManager;" + \ " getDefault ()Landroid/telephony/SmsManager;" diff --git a/tests/utils/test_tools.py b/tests/utils/test_tools.py index 9a8ad49f..a01f901d 100644 --- a/tests/utils/test_tools.py +++ b/tests/utils/test_tools.py @@ -2,8 +2,6 @@ from quark.utils.tools import ( contains, descriptor_to_androguard_format, - get_arguments_from_argument_str, - get_parenthetic_contents, remove_dup_list, ) @@ -113,26 +111,3 @@ def test_descriptor_to_androguard_format_with_combination(): result = descriptor_to_androguard_format(descriptor) assert result == "(I Ljava/lang/String; [B J)" - - -@pytest.mark.parametrize( - "source, expected", - [ - ("(a)(b)(c)", "(a)"), - ("(((a))(b))", "(((a))(b))"), - ("f1(a,f2(b))", "(a,f2(b))"), - ("()", "()"), - ("((((b)", "((((b)"), - ], -) -def test_get_parenthetic_contents(source, expected): - content = get_parenthetic_contents(source, 0) - assert expected == content - - -def test_get_arguments_from_argument_str(): - argument_str = "LClass;,10,String,new-array(),3.14,1" - descriptor = "(I Ljava/lang/String; [B F Z)" - - arguments = get_arguments_from_argument_str(argument_str, descriptor) - assert arguments == ["LClass;", 10, "String", "new-array()", 3.14, True]