diff --git a/ExperimentReplication.md b/ExperimentReplication.md
new file mode 100644
index 0000000..90db8bd
--- /dev/null
+++ b/ExperimentReplication.md
@@ -0,0 +1,212 @@
+# Experiment Replication
+If you would like to replicate any of the experiments found in the paper, you can find the instructions below.
+Make sure Docker is set up for a non-root user on the machine.
+To initially set up the project, follow these steps:
+> [!IMPORTANT]
+> Some binaries will error out, and that's okay and normal.
+> Our tool isn't perfect and can't deal with all possible corner cases, but we try to give you as much data as we can.
+
+> [!CAUTION]
+> Use the `--rda-timeout` flag to increase the timeout if you are running on an under-powered machine. This will increase total analysis time, but also lower the chance of missing important values. The default timeout is 5 minutes per call-chain.
+
+#### Machine Info
+In case you were wondering, these are the machine specs used in the paper:
+Each mango container was run on a single-core 2.30GHz Intel Xeon CPU with at least 5GB of RAM.
+This was parallelized across 4000 cores using Kubernetes.
+> [!WARNING]
+> Some binaries will consume upwards of 20GB of RAM.
+
+
+#### Install Docker (If you don't have it already)
+```bash
+curl -sSL https://get.docker.com/ | sudo sh
+sudo usermod -aG docker $USER
+sudo reboot
+```
+#### venv setup
+```bash
+sudo apt install python3.11 python3-pip
+python3.11 -m pip install virtualenv
+python3.11 -m virtualenv venv
+```
+
+#### operation mango installation
+```bash
+git clone https://github.com/sefcom/operation-mango-public.git
+source venv/bin/activate
+cd operation-mango-public
+pip install -e .
+pip install -e pipeline/
+mango-pipeline --build-docker
+```
+
+### Table 1 - Karonte Dataset Evaluation
+![Table 1 Image](assets/table_1.png)
+[karonte_dataset.tgz](https://drive.google.com/file/d/1-VOf-tEpu4LIgyDyZr7bBZCDK-K2DHaj/view?usp=sharing) - sha256 25b8204d2fbac800994fefb30c8dbf80af3f15e57a9c04397f967f12a879b501
+The karonte dataset can also be downloaded from UCSB's [karonte repository](https://github.com/ucsb-seclab/karonte).
+To reproduce the mango results, run the commands below (where 20 is the number of Docker containers your setup can handle):
+```bash
+mango-pipeline --path path/to/karonte-dataset --results your/karonte-results/dir --env --parallel 20 # Runs environment/nvram value resolution
+
+# The two commands below can be run in parallel
+
+mango-pipeline --path path/to/karonte-dataset --results your/karonte-results/dir --mango --parallel 20 # Runs mango analysis for command injections
+mango-pipeline --path path/to/karonte-dataset --results your/karonte-results/dir --mango --parallel 20 --category overflow # Runs mango analysis for buffer overflows
+```
+
+To view the runtime results for these values, run:
+```bash
+python operation-mango-public/pipeline/mango_pipeline/scripts/show_table.py your/karonte-results/dir
+```
+
+### Table 2 - Path Context Aggregation
+![Table 2 Image](assets/table_2.png)
+
+Once you've run the above analysis from [Table 1](#table-1---karonte-dataset-evaluation), you can run the path aggregation script.
+```bash
+python operation-mango-public/pipeline/mango_pipeline/scripts/path_context_aggregator.py your/karonte-results/dir
+```
+This will print the JSON aggregation of all paths toward each sink keyword, and also write it to the file `count.agg`.
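+To total the two categories shown in the table, a minimal sketch is below; it assumes `count.agg` is a flat JSON object mapping each sink name to its path count, which is an assumption about the file layout. As the note just after it explains, `strcpy` is the only buffer-overflow sink.
+```python
+import json
+import sys
+
+# Sum the per-sink path counts into the two table categories.
+# Assumes count.agg maps sink names to path counts (hypothetical layout).
+with open(sys.argv[1]) as f:  # e.g. your/karonte-results/dir/count.agg
+    counts = json.load(f)
+
+overflow = counts.get("strcpy", 0)  # strcpy is the only buffer-overflow sink
+cmdi = sum(v for k, v in counts.items() if k != "strcpy")
+print(f"command injection: {cmdi}, buffer overflow: {overflow}")
+```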
+`strcpy` is the only sink in the buffer-overflow category; every other sink falls under command injection, so sum those together.
+
+### Table 3 - SaTC 7 Handpicked Firmware
+![Table 3 Image](assets/table_3.png)
+[7_firmware.tar.gz](https://www.dropbox.com/scl/fi/v7k3c2ech20l5hh6dxhdf/7_firmware.tar.gz?rlkey=lo74bd7wcgf1jmxmxlwv6j4v3&dl=0) - sha256 55859e73c9d8b46152e899b7ad0110fe671d82d462b44cc166b4c8e00fb76ab6
+To reproduce the mango results, run the commands below (where 20 is the number of Docker containers your setup can handle):
+```bash
+mango-pipeline --path 7_firmware --results your/7-firmware-results/dir --env --parallel 20 # Runs environment/nvram value resolution
+
+# The two commands below can be run in parallel
+mango-pipeline --path 7_firmware --results your/7-firmware-results/dir --mango --parallel 20 # Runs mango analysis for command injections
+mango-pipeline --path 7_firmware --results your/7-firmware-results/dir --mango --parallel 20 --category overflow # Runs mango analysis for buffer overflows
+```
+
+To view the runtime results for these values, run:
+```bash
+python operation-mango-public/pipeline/mango_pipeline/scripts/show_table.py your/7-firmware-results/dir --show-firmware
+```
+
+### Table 4 - Manual Analysis of TruPoCs
+![Table 4 Image](assets/table_4.png)
+
+This table is a manual analysis of a subset of TruPoCs generated from [Table 1](#table-1---karonte-dataset-evaluation).
+For the purposes of the paper, TruPoCs are any analysis results with a score of 7 or higher.
+These can be found with a simple find command:
+```bash
+find your/karonte-results/dir -type f -iname '[7-9]*'
+```
+Once found, it is up to your reverse-engineering skills to verify the validity of the reports.
+
+
+### Table 5 - Ablation of Mango Analysis on R6400v2
+![Table 5 Image](assets/table_5.png)
+
+[ablation-firmware.tar.gz](https://www.dropbox.com/scl/fi/zz04neglh9d4ycsqa2c0j/ablation-firmware.tar.gz?rlkey=hwhf9tt21f7hsuzcrbrjkg08t&dl=0) - sha256 a3c0acd2b588978cdb2d1873929e9904a84cdba1f9dd603b41e0a109fecba8e6
+This is an ablation study of the two algorithmic contributions that Operation Mango introduces: assumed execution (assumed non-impact) and sink-to-source analysis.
+ +To replicate this, run the following commands: +> [!WARNING] +> You must have the exact folder/result_folder names for the final script to function +```bash +mkdir ablation +cd ablation +mango-pipeline --path path/to/ablation-firmware \ + --results ablation-default \ + --parallel 40 \ + --env # Do an env pass + +# Copy env results to other ablation dirs +cp -r ablation-default ablation-assumed # Dir for no assumed execution +cp -r ablation-default ablation-trace # Dir for no sink-to-source analysis +cp -r ablation-default ablation-all # Dir for neither + +# The following 4 mango-pipeline commands can be run in parallel + +# Current version of mango with both assumed nonimpact and sink-to-source +mango-pipeline --path path/to/ablation-firmware \ + --results ablation-default \ + --parallel 10 \ + --mango + +# mango without assumed nonimpact +mango-pipeline --path path/to/ablation-firmware \ + --results ablation-assumed \ + --parallel 10 \ + --mango \ + --extra-args full-execution + +# mango without sink-to-source +mango-pipeline --path path/to/ablation-firmware \ + --results ablation-trace \ + --parallel 10 \ + --mango \ + --extra-args forward-trace + +# mango without assumed nonimpact or sink-to-source +mango-pipeline --path path/to/ablation-firmware \ + --results ablation-all \ + --parallel 10 \ + --mango \ + --extra-args full-execution forward-trace +``` + +Once all analyses have finished running, you can print the results using the following command: +```bash +python operation-mango-public/pipeline/mango_pipeline/scripts/ablation.py ablation/ +``` +This will create a nice table with well-formatted results. + + +### Table 6 - Large Scale Evaluation +![Table 6 Image](assets/table_6.png) + +> [!CAUTION] +> This dataset is 20GB zipped!!! +> The results can take up to 400GB!!! + +[large_dataset.tar.gz](https://www.dropbox.com/scl/fi/2ndob4flx6sn3a53fln83/large_dataset.tar.gz?rlkey=frgjlhwh244mqb4ua1om9atqd&dl=0) - sha256 a3d8012ba7bcaa1f0f34b7ce6783b5d6441902644f9bedae4031b71ab3490e2e + +Beware all ye who enter... +Unless you have access to an immense amount of compute power, this table will take you a while to reproduce. + +Once you download the entire dataset, extract all of the firmware. +```bash +tar -xvzf large_dataset.tar.gz +cd large_dataset +find . -type f -exec tar -xvzf {} \; +cd .. 
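+# Optional: once everything is extracted, you can delete the inner tarballs to
+# reclaim disk space (the dataset and results are large); uncomment if needed:
+# find large_dataset -type f -name '*.tar.gz' -delete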
+```
+Once all the firmware is extracted, you can run the experiments:
+```bash
+mango-pipeline --path large_dataset --results your/large_dataset/res_dir --env --parallel 20 # Runs environment/nvram value resolution
+
+# The two commands below can be run in parallel
+mango-pipeline --path large_dataset --results your/large_dataset/res_dir --mango --parallel 20 # Runs mango analysis for command injections
+mango-pipeline --path large_dataset --results your/large_dataset/res_dir --mango --parallel 20 --category overflow # Runs mango analysis for buffer overflows
+```
+To get output similar to Table 6, run:
+```bash
+python operation-mango-public/pipeline/mango_pipeline/scripts/show_table.py your/large_dataset/res_dir
+```
+
+
+### Table 7 - Additional Experiment (Appendix)
+![Table 7 Image](assets/table_7.png)
+
+[additional_experiment.tar.gz](https://www.dropbox.com/scl/fi/girnzmyjaterigijmls87/additional_experiment.tar.gz?rlkey=4okv6ivbsyer6df0nc70tzz73&dl=0) - sha256 0374b6efade719cf744eeeacd9729df7908dfca1936d7c136cb0c18446bb8300
+
+This is just an additional experiment showing the extensibility of the project, similar to [Table 1](#table-1---karonte-dataset-evaluation) and [Table 3](#table-3---satc-7-handpicked-firmware).
+
+```bash
+mango-pipeline --path path/to/additional-dataset --results your/additional-results/dir --env --parallel 20 # Runs environment/nvram value resolution
+
+# The two commands below can be run in parallel
+
+mango-pipeline --path path/to/additional-dataset --results your/additional-results/dir --mango --parallel 20 # Runs mango analysis for command injections
+mango-pipeline --path path/to/additional-dataset --results your/additional-results/dir --mango --parallel 20 --category overflow # Runs mango analysis for buffer overflows
+```
+
+To get output similar to Table 7, run:
+```bash
+python operation-mango-public/pipeline/mango_pipeline/scripts/show_table.py your/additional-results/dir --show-firmware
+```
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..99429ae
--- /dev/null
+++ b/README.md
@@ -0,0 +1,195 @@
+# Operation Mango \[[Paper PDF](assets/operation-mango.pdf)\]
+
+## Fast taint-style static analysis based vulnerability discovery
+
+Common vulnerability discovery techniques all follow a top-down scheme: they start from the entry point of the target program, reach as deep as possible, and examine all encountered program states against a set of security violations. These vulnerability discovery techniques are all limited by the complexity of programs and are all prone to the path/state explosion problem.
+
+Alternatively, we can start from the locations where vulnerabilities might occur (vulnerability sinks), trace back, and verify whether the corresponding data flow may lead to a vulnerability. On top of this, we need “assumed execution”: when tracing back from a vulnerability sink to its sources, we do not faithfully execute or analyze every function on the path; instead, we assume a data flow based on prior knowledge or static analysis done in advance, and skip as many functions as possible during back-tracing.
+
+Check out our [experiment reproduction section](ExperimentReplication.md) to reproduce all the figures found in the paper.
+
+## Getting Started
+There are several ways to run Operation Mango, if you so choose.
+
+### Docker
+Bypass all this nonsense and just use the container.
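+The bare-bones invocation is below; in practice you'll want to mount your target binary and a writable results directory into the container. A hypothetical full invocation (both host paths are placeholders; the `mango` command itself is covered under [Using Operation Mango](#using-operation-mango)):
+```bash
+# mount the directory containing the target binary plus a results directory,
+# then run mango inside the container (host paths are placeholders)
+docker run -it \
+  -v /path/to/bin-dir:/target \
+  -v /path/to/results:/results \
+  clasm/mango-user \
+  mango /target/your-binary --results /results
+```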
+> [!TIP]
+> Don't forget to add volumes with -v for both the binary and the result directory.
+```
+docker run -it clasm/mango-user
+```
+
+### Local
+I highly recommend you create a separate Python virtualenv for this method.
+```bash
+source venv/bin/activate
+git clone https://github.com/sefcom/operation-mango-public.git
+cd operation-mango-public
+pip install .
+```
+
+To build the Docker container locally:
+```bash
+cd operation-mango-public
+docker build -f docker/Dockerfile . -t mango-user
+```
+
+## Using Operation Mango
+Once you install Operation Mango or use the Docker container, you'll have access to two commands: `mango` and `env_resolve`.
+
+### mango
+`mango` is your default command for running our basic taint analysis on binaries.
+> [!TIP]
+> Using the `--concise` flag will significantly shrink the output size and speed up analysis.
+> It will not print the entire analysis taint trace in your results, but normally you won't need that.
+```
+mango /path/to/bin --results your_res_dir
+```
+will run the basic command injection taint analysis; check out the `--help` flag for more options.
+
+### mango output structure
+The output of this tool is fairly verbose; you'll be given the following:
+> [!TIP]
+> Any values labeled TOP are unknown or unresolvable.
+
+`{category}_mango.out` - The entire stdout/stderr of the mango run.
+`{category}_results.json` - The JSON is structured as follows:
+```
+{
+    "closures": [
+        {
+            "trace": {},  // Function trace starting from input down to sink
+            "sink": {},   // Sink location
+            "depth": int,
+            "inputs": {
+                "likely": [],  // Functions that flow directly into the sink
+                "possibly": [] // Functions seen along the way that are generally used as inputs
+            },
+            "rank": int  // How confident we are that this is a TruPoC
+        },
+    ],
+    "cfg_time": float,   // time it took to generate the CFG in seconds
+    "vra_time": float,   // time it took to run the variable recovery in seconds
+    "mango_time": float, // time it took for the actual mango analysis in seconds
+    "path": str,         // path to analyzed binary
+    "name": str,         // binary name
+    "sha256": str,       // sha256 of the file
+    "error": str|None,   // If an error occurred, it is reported here
+    ...                  // Other timing info
+}
+```
+`{category}_closures/` - The folder containing the results of individual flows to the sink; all of these are unresolvable by our tool.
+`{category}_closures/0.{rank}_{entry_func@addr}_{sink_func@addr}` - The individual closures printed with extra information about likely and possible input sources.
+e.g. `0.70_main_0x403e70_system_0x40143c`:
+```
+|||system(
+||| a0: -> "",
+||| ) @ 0x40143c ->
+
+INPUT SOURCES:
+Likely:
+NONE
+Possibly:
+----------
+KEY: "accept(fd: 3)@0x403fb0_274_3"
+Keywords: None
+Binary Source - UNKNOWN
+socket(AF_INET, SOCK_DGRAM, 0)_273_32
+accept(fd: 3)@0x403fb0_274_32
+recv(accept(fd: 3)@0x403fb0_274_32)@0x404088
+----------
+RANK: 0.700
+```
+`execv.json` - This is mostly unused, but should contain info about which other processes this binary tries to execute.
+
+### env_resolve
+`env_resolve` performs a taint analysis of a given binary to find all uses of `env` and `nvram` variables.
+This is what enables our cross-binary bug finding.
+```
+env_resolve /path/to/bin --results your_res_dir
+```
+
+The output of this tool will be found at `your_res_dir/env.json`.
+To feed this info into `mango`, merge all of the env.json files together (even if there is only one) with:
+```bash
+env_resolve /path/to/bin --results your_res_dir/env.json --merge
+```
+
+This will spit out the file `your_res_dir/env.json`.
+Then feed it into `mango`:
+
+```bash
+mango /path/to/bin --env-dict your_res_dir/env.json --results your_res_dir
+```
+
+### env_resolve output structure
+The `env.json` output from `env_resolve` follows the same structure as the `results.json` that `mango` outputs, e.g.:
+```
+{
+    "results": {
+        "func_name":  // e.g. nvram_get
+        {
+            "key_name":  // key used to retrieve the value, e.g. "http_passwd"
+            {
+                "keywords": str,  // Any frontend keywords used to retrieve this value
+                "1":  // position of the argument, starting from "1" (I know...)
+                {
+                    "arg_value": [  // arg value; in the case of getter funcs it's always the key name
+                        "0xaddr",   // addr where the value is used
+
+                    ]
+                }
+            }
+
+        }
+    },
+    "cfg_time": float,      // time it took to generate the CFG in seconds
+    "vra_time": float,      // time it took to run the variable recovery in seconds
+    "analysis_time": float, // time it took for the actual mango analysis in seconds
+    "path": str,            // path to analyzed binary
+    "name": str,            // binary name
+    "sha256": str,          // sha256 of the file
+    "error": str|None,      // If an error occurred, it is reported here
+    ...                     // Other timing info
+}
+```
+
+## Firmware Cross Binary and Frontend Keyword Bug Finding
+
+If you're trying to find bugs in firmware samples as described in our paper, have a look at the [`mango_pipeline`](pipeline/README.md).
+For further examples of how to use it, check out the [Experiment Replication](ExperimentReplication.md) section.
+
+## Testing
+
+```bash
+# run all the tests for the developed features (isolated in the `package` module)
+pip install pytest-cov
+pytest
+```
+
+
+### Handcrafted binaries
+
+To ease testing, we crafted small binaries highlighting one (or several) case(s) we wanted to be able to handle properly.
+It was particularly helpful to drive the development of the [`Handlers`](package/argument_resolver/handlers/).
+
+They are located under the `package/tests/binaries/` folder.
+
+| Binary                                    | Description                                                                                                                  |
+| ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- |
+| `after_values/program`                    | Contains multiple calls to a sink in a single function.                                                                      |
+| `layered/program`                         | Nested calls running past the default 7-depth limit before reaching the sink.                                                |
+| `looper/program`                          | Runs a loop before reaching a sink.                                                                                          |
+| `off_shoot/program`                       | Calls multiple functions that alter the input in sub-functions before reaching the sink.                                     |
+| `recursive/program`                       | Contains direct and indirect recursive calls (highlights a flaw of unresolvable call depth).                                 |
+| `nested/program`                          | Nested calls and returns before reaching a sink.                                                                             |
+| `simple/program`                          | Contains a call to the external function `puts`. Runs through nested functions, leading to different sinks (`execve`, `system`). |
+| `sprintf_resolved_and_unresolved/program` | Contains two calls to `system`: one with constant data, the other one that could be influenced by the program user.          |
+
+
+To ensure reproducibility of testing, the binaries have been added to the repository.
+However, if you're looking to add a new one, a `Makefile` has been written for convenience.
+```bash +# build some homemade light binaries +cd binaries/ && make && cd - +``` diff --git a/assets/operation-mango.pdf b/assets/operation-mango.pdf new file mode 100755 index 0000000..6af55b4 Binary files /dev/null and b/assets/operation-mango.pdf differ diff --git a/assets/table_1.png b/assets/table_1.png new file mode 100644 index 0000000..bb8449d Binary files /dev/null and b/assets/table_1.png differ diff --git a/assets/table_2.png b/assets/table_2.png new file mode 100644 index 0000000..0b8a8f0 Binary files /dev/null and b/assets/table_2.png differ diff --git a/assets/table_3.png b/assets/table_3.png new file mode 100644 index 0000000..eae4b92 Binary files /dev/null and b/assets/table_3.png differ diff --git a/assets/table_4.png b/assets/table_4.png new file mode 100644 index 0000000..cabf6df Binary files /dev/null and b/assets/table_4.png differ diff --git a/assets/table_5.png b/assets/table_5.png new file mode 100644 index 0000000..776bf6d Binary files /dev/null and b/assets/table_5.png differ diff --git a/assets/table_6.png b/assets/table_6.png new file mode 100644 index 0000000..1f6b292 Binary files /dev/null and b/assets/table_6.png differ diff --git a/assets/table_7.png b/assets/table_7.png new file mode 100755 index 0000000..61a16ba Binary files /dev/null and b/assets/table_7.png differ diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..657c8ad --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,42 @@ +ARG BASE=ubuntu:22.04 +FROM ${BASE} as base +MAINTAINER Clasm + +ENV DEBIAN_FRONTEND=noninteractive + +USER root +RUN apt-get update && apt-get -y install software-properties-common dirmngr apt-transport-https lsb-release ca-certificates + +FROM base as python_build +RUN add-apt-repository ppa:deadsnakes/ppa -y +RUN apt-get update && apt-get -o APT::Immediate-Configure=0 install -y \ + git sudo virtualenvwrapper python3.11-dev python3.11-venv python3-pip build-essential libxml2-dev \ + libxslt1-dev git libffi-dev cmake libreadline-dev libtool debootstrap \ + debian-archive-keyring libglib2.0-dev libpixman-1-dev qtdeclarative5-dev \ + binutils-multiarch nasm libc6 libgcc1 libstdc++6 \ + libtinfo5 zlib1g vim openssl libssl-dev openjdk-8-jdk graphviz graphviz-dev\ + && rm -rf /var/lib/apt/lists/* + +FROM python_build as angr_repo + +#RUN . /angr/bin/activate && pip install py-spy rich docker toml kubernetes + +FROM angr_repo as mango_repo + +RUN mkdir /operation-mango +COPY . /operation-mango +RUN cd / && python3.11 -m venv angr +WORKDIR /operation-mango + +RUN touch /operation-mango/setup.cfg +RUN . /angr/bin/activate && pip install -e . +RUN . 
/angr/bin/activate && pip install -e pipeline +RUN cd /angr/lib/python3.11/site-packages && git apply /operation-mango/docker/live_def.patch + +WORKDIR /operation-mango +RUN ln -s /angr/bin/mango /usr/local/bin/mango +RUN ln -s /angr/bin/env_resolve /usr/local/bin/env_resolve +RUN ln -s /angr/bin/mango-pipeline /usr/local/bin/mango-pipeline + +COPY ./docker/entrypoint.py / +RUN chmod +x /entrypoint.py diff --git a/docker/entrypoint.py b/docker/entrypoint.py new file mode 100644 index 0000000..dbbed91 --- /dev/null +++ b/docker/entrypoint.py @@ -0,0 +1,301 @@ +#!/angr/bin/python + +import os +import sys +import json +import shutil +import subprocess +import tarfile +import hashlib +import time + +from pathlib import Path +from typing import Tuple, List, Any + + +def give_write_permissions(file_path: Path): + os.chmod(file_path, os.stat(file_path).st_mode | 0o002) + + +def get_kube_data() -> tuple[ + Path | Any, Path | Any, int, int, Any, Any, Any | None, Path +]: + experiment_list_loc = Path(sys.argv[1]) + with experiment_list_loc.open("r", encoding="ascii") as f: + experiment_data = json.load(f) + + index = os.environ["JOB_COMPLETION_INDEX"] + script = experiment_data["script"] + timeout = int(experiment_data["timeout"]) + rda_timeout = int(experiment_data["rda_timeout"]) + category = experiment_data["category"] + result_dest = Path(experiment_data["result_dest"]) + target_dir = Path(experiment_data["target_dir"]) + + if script != "bin_prep": + brand = experiment_data["targets"][index]["brand"] + firmware = experiment_data["targets"][index]["firmware"] + sha = experiment_data["targets"][index]["sha"] + if experiment_data["targets"][index]["ld_paths"]: + ld_paths = experiment_data["targets"][index]["ld_paths"] + else: + ld_paths = None + + local_res = result_dest / brand / firmware / sha + target = Path(experiment_data["targets"][index]["path"]) + else: + local_res = result_dest + target = experiment_data["targets"][index] + ld_paths = None + + return ( + local_res, + target, + timeout, + rda_timeout, + script, + category, + ld_paths, + target_dir, + ) + + +def get_local_data() -> tuple[Path, Path, int, int, str, Any, str | None | Any, None]: + script = os.environ.get("SCRIPT", None) + timeout = int(os.environ.get("TIMEOUT", 3 * 60 * 60)) + rda_timeout = int(os.environ.get("RDA_TIMEOUT", 0)) + category = json.loads(os.environ.get("CATEGORY", "[]")) + + result_dest = Path(os.environ.get("RESULT_DEST", "")) + + target = Path(os.environ.get("TARGET_PATH", "")) + ld_paths = os.environ.get("LD_PATHS", None) + if ld_paths: + ld_paths = json.loads(ld_paths) + + return result_dest, target, timeout, rda_timeout, script, category, ld_paths, None + + +def gen_error_dict(target_path, time_str, script): + sha = None + if os.path.exists(target_path): + with open(target_path, "rb") as f: + sha = hashlib.file_digest(f, "sha256").hexdigest() + data = { + "results": {}, + "sha256": sha, + "name": target_path.name, + "path": str(target_path), + "closures" if script == "mango" else "results": [], + "cfg_time": None, + "vra_time": None, + time_str: None, + "error": None, + "has_sinks": True, + "ret_code": 0, + } + return data + + +def run_local(): + if len(sys.argv) > 1 and sys.argv[1] == "mango": + script = "/angr/bin/mango" + elif len(sys.argv) > 1 and sys.argv[1] == "env_resolve": + script = "/angr/bin/env_resolve" + else: + print("Requires eiter: mango or env_resolve to be argv[1]") + return + + subprocess.run([script, *sys.argv[2:]]) + + +def main(): + get_data = ( + get_kube_data if os.environ.get("KUBE", 
None) is not None else get_local_data + ) + ( + local_result, + target_path, + timeout, + rda_timeout, + script, + category, + ld_paths, + target_dir, + ) = get_data() + + if script is None: + run_local() + return + + debug = os.environ.get("DEBUG", None) + + local_result.mkdir(exist_ok=True, parents=True) + + if "ZIP_DEST" in os.environ: + zip_dest = Path(os.environ["ZIP_DEST"]) + os.makedirs(zip_dest, exist_ok=True) + if script == "bin_prep": + print("SHOULDNT BE HERE", target_path, zip_dest) + subprocess.call(["tar", "-xvzf", target_path, "-C", zip_dest]) + else: + firm_zip = str(target_dir / local_result.parent.name) + ".tar.gz" + print("Extracting", firm_zip, "to", zip_dest) + subprocess.call(["tar", "-xvzf", firm_zip, "-C", zip_dest]) + + if script == "bin_prep": + if tarfile.is_tarfile(target_path): + target_path = os.environ["ZIP_DEST"] + command = [ + "/angr/bin/mango-pipeline", + "--path", + str(target_path), + "--results", + str(local_result), + "--bin-prep", + ] + subprocess.call(command) + + else: + command = [] + if "PYSPY" in os.environ: + command += [ + "/angr/bin/py-spy", + "record", + "--format", + "speedscope", + "-o", + f"{str(local_result)}/speedscope.json", + "--", + ] + print("RUNNING PYSPY") + else: + print("NOT RUNNING PYSPY") + command += [f"/angr/bin/{script}", target_path] + + command += ["--rda-timeout", str(rda_timeout)] + command += ["--results", str(local_result)] + # Ignoring this option atm + # command += ["--ld-paths", " ".join(ld_paths)] + result_file = None + + if "EXTRA_ARGS" in os.environ: + command += json.loads(os.environ["EXTRA_ARGS"]) + + keyword_dict = local_result.parent / "keywords.json" + if not keyword_dict.exists(): + with keyword_dict.open("w+") as f: + f.write("{}") + command += ["--keyword-dict", str(keyword_dict)] + command += ["--workers", "0"] + + ret_code = None + start_time = time.time() + if script == "mango": + result_file = f"{category}_results.json" + if category: + command += ["--category", category] + env_dict = local_result.parent / "env.json" + if not env_dict.exists(): + with env_dict.open("w+") as f: + f.write("{}") + command += ["--env-dict", str(env_dict)] + if debug is not None: + command += ["--loglevel", "DEBUG"] + command += ["--concise"] + + print("COMMAND:", command) + try: + result = subprocess.run(command, timeout=timeout) + ret_code = result.returncode + except subprocess.TimeoutExpired as e: + ret_code = 124 + + if ret_code != 0: + tmp_path = Path("/tmp/mango.out") + if tmp_path.exists(): + shutil.copy(tmp_path, local_result / "mango.out") + else: + real_path = local_result / "mango.out" + if real_path.exists(): + real_path.unlink() + + firmware_dst = local_result.parent + subprocess.call( + [ + "/angr/bin/mango", + str(firmware_dst), + "--merge-execve", + "--results", + str(firmware_dst / "execv.json"), + ] + ) + elif script == "env_resolve": + print("COMMAND:", command) + try: + ret_code = subprocess.call(command, timeout=timeout) + except subprocess.TimeoutExpired: + ret_code = 124 + result_file = "env.json" + + firmware_dst = local_result.parent + subprocess.call( + [ + "/angr/bin/env_resolve", + str(firmware_dst), + "--merge", + "--results", + str(firmware_dst / result_file), + ] + ) + + time_str = "mango_time" if script == "mango" else "analysis_time" + try: + with (local_result / result_file).open("r") as f: + data = json.load(f) + except FileNotFoundError: + data = gen_error_dict(target_path, time_str, script) + except json.decoder.JSONDecodeError: + with (local_result / result_file).open("r") as f: + 
data = f.read() + print("FAILED TO DECODE JSON") + print(data) + data = gen_error_dict(target_path, time_str, script) + + data["ret_code"] = ret_code + + if ret_code == 124 and "cfg_time" in data and data["cfg_time"]: + data["ret_code"] = ret_code + data["error"] = "timeout" + data[time_str] = time.time() - start_time + + with (local_result / result_file).open("w") as f: + json.dump(data, f, indent=4) + + elif ret_code == 124: + data["ret_code"] = ret_code + data["error"] = "potential_timeout" + data[time_str] = time.time() - start_time + + elif ret_code == -9: + data["error"] = "OOMKILLED" + data["ret_code"] = ret_code + data[time_str] = time.time() - start_time + with (local_result / result_file).open("w") as f: + json.dump(data, f, indent=4) + + elif ret_code != 0: + data["error"] = "early_termination" + data["ret_code"] = ret_code + data[time_str] = time.time() - start_time + with (local_result / result_file).open("w") as f: + json.dump(data, f, indent=4) + + if (local_result / result_file).exists(): + give_write_permissions(local_result / result_file) + + print(data) + + +if __name__ == "__main__": + main() diff --git a/docker/live_def.patch b/docker/live_def.patch new file mode 100644 index 0000000..b281602 --- /dev/null +++ b/docker/live_def.patch @@ -0,0 +1,46 @@ +diff --git a/angr/knowledge_plugins/key_definitions/live_definitions.py b/angr/knowledge_plugins/key_definitions/live_definitions.py +index 3db6532f2..2b17d28fe 100644 +--- a/angr/knowledge_plugins/key_definitions/live_definitions.py ++++ b/angr/knowledge_plugins/key_definitions/live_definitions.py +@@ -160,7 +160,7 @@ class LiveDefinitions: + MultiValuedMemory( + memory_id="mem", + top_func=self.top, +- skip_missing_values_during_merging=False, ++ skip_missing_values_during_merging=True, + page_kwargs={"mo_cmp": self._mo_cmp}, + ) + if memory is None +@@ -516,9 +516,12 @@ class LiveDefinitions: + else: + definition: Definition = Definition(atom, code_loc, dummy=dummy, tags=tags) + d = MultiValues() +- for offset, vs in data.items(): +- for v in vs: +- d.add_value(offset, self.annotate_with_def(v, definition)) ++ try: ++ for offset, vs in data.items(): ++ for v in vs: ++ d.add_value(offset, self.annotate_with_def(v, definition)) ++ except AttributeError: ++ pass + + # set_object() replaces kill (not implemented) and add (add) in one step + if isinstance(atom, Register): +diff --git a/angr/storage/memory_mixins/paged_memory/pages/mv_list_page.py b/angr/storage/memory_mixins/paged_memory/pages/mv_list_page.py +index ec83c8d89..a63968337 100644 +--- a/angr/storage/memory_mixins/paged_memory/pages/mv_list_page.py ++++ b/angr/storage/memory_mixins/paged_memory/pages/mv_list_page.py +@@ -203,7 +203,11 @@ class MVListPage( + size = min(mo_length - (page_addr + b - mo_base), len(self.content) - b) + merged_to = b + size + +- merged_val = self._merge_values(to_merge, mo_length, memory=memory) ++ if memory.state.__class__.__name__ == 'LiveDefinitions': ++ merged_val = self._merge_values(to_merge, mo_length, memory=memory, ++ is_sp=b == memory.state.arch.sp_offset) ++ else: ++ merged_val = self._merge_values(to_merge, mo_length, memory=memory) + if merged_val is None: + # merge_values() determines that we should not attempt to merge this value + continue diff --git a/package/argument_resolver.egg-info/PKG-INFO b/package/argument_resolver.egg-info/PKG-INFO new file mode 100644 index 0000000..fe770b3 --- /dev/null +++ b/package/argument_resolver.egg-info/PKG-INFO @@ -0,0 +1,112 @@ +Metadata-Version: 2.1 +Name: argument_resolver 
+Version: 0.0.1 +Summary: An RDA based static-analysis library for resolving function arguments +Author-email: Wil Gibbs , Pamplemousse , Fish +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +Requires-Dist: angr==9.2.94 +Requires-Dist: pydot==2.0.0 +Requires-Dist: networkx==3.2.1 +Requires-Dist: psutil==5.9.8 +Requires-Dist: ipdb==0.13.13 +Requires-Dist: rich==13.7.1 +Provides-Extra: dev +Requires-Dist: ipdb; extra == "dev" +Requires-Dist: pytest; extra == "dev" +Requires-Dist: pytest-cov; extra == "dev" +Requires-Dist: mypy; extra == "dev" +Requires-Dist: flake8; extra == "dev" + +# operation-mango + +## Fast vulnerability discovery by assumed execution + +Common vulnerability discovery techniques all follow a top-down scheme: They start from the entry point of the target program, reach as deep as possible, and examine all encountered program states against a set of security violations. These vulnerability discovery techniques are all limited by the complexity of programs, and are all prone to the path/state explosion problem. + +Alternatively, we can start from the location where vulnerabilities might occur (vulnerability sinks), trace back, and verify if the corresponding data flow may lead to a vulnerability. On top of this, we need “assumed execution”, which means when we are tracing back from a vulnerability sink to its sources, we do not faithfully execute or analyze every function on the path, instead we assume a data flow based on prior knowledge or some static analysis in advance and skip as many functions as possible during back tracing. + +You can find our paper here \[[PDF](TBD)\]! +Checkout our [experiment reproduction section](ExperimentReplication.md) to reproduce all the figures found in the paper. + +## Getting Started +There are several ways to run operation mango if you so choose. + +### Docker +Bypass all this non-sense and just use the container. +> [!TIP] +> Don't forget to add volumes with -v for both the binary and result directory +``` +docker run -it clasm/mango-user +``` + +### Local +I highly recommend you create a separate python virtualenv for this method. +```bash +source venv/bin/activate +git clone https://github.com/sefcom/operation-mango-public.git +cd operation-mango-public +pip install . +``` + +To build the docker container locally: +```bash +cd operation-mango-public +docker build -f docker/Dockerfile . -t mango-user +``` + +## Using Operation Mango + +### mango +`mango` is your default command for running our basic taint analysis on binaries. +> [!TIP] +> Using the `--concise` will significantly shrink the output size and speed up analysis. +> It will not print the entire analysis taint trace in your results, but normally you won't need that. +``` +mango /path/to/bin --results your_res_dir +``` +will run the basic command injection taint analysis, checkout the `--category` flag for more options + +### env_resolve + +## Parallelized Workloads + +If you have a large workload check out `mango_pipeline` [`Here`](pipeline/README.md). + +## Testing + +```bash +# run all the tests for the developed features (isolated in the `package` module) +(venv) pip install pytest-cov +(venv) pytest +``` + + +### Handcrafted binaries + +To ease testing, we crafted small binaries highlighting one (or several) case(s) we wanted to be able to handle properly. 
+It was particularly helpful to drive the development of the [`Handlers`](package/argument_resolver/handlers/). + +They are located under the `package/tests/binaries/` folder. + +| Binary | Description | +| ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- | +| `after_values/program` | Contains multiple calls to a sink in a single function. | +| `layered/program` | Nested calls running more than the default 7-depth limit before reaching the sink. | +| `looper/program` | Runs a loop before reaching a sink. | +| `off_shoot/program` | Calls multiple functions that alter the input in sub functions before reaching the sink. | +| `recursive/program` | Contains direct and in-direct recursive calls (Highlights flaw of unresolvable call-depth). | +| `nested/program` | Nested calls and returns before reaching a sink. | +| `simple/program` | Contains call to external function `puts`. Run through nested functions, leading to different sinks (`execve`, `system`). | +| `sprintf_resolved_and_unresolved/program` | Contains two calls to `system`: one with constant data, the other one that could be influenced by the program user. | + + +To ensure reproducibility of testing, the binaries have been added to the repository. +Although, if looking to add a new one, a `Makefile` has been written for convenience. +```bash +# build some homemade light binaries +cd binaries/ && make && cd - +``` diff --git a/package/argument_resolver.egg-info/SOURCES.txt b/package/argument_resolver.egg-info/SOURCES.txt new file mode 100644 index 0000000..fdce02b --- /dev/null +++ b/package/argument_resolver.egg-info/SOURCES.txt @@ -0,0 +1,64 @@ +README.md +pyproject.toml +package/argument_resolver/__init__.py +package/argument_resolver/__main__.py +package/argument_resolver.egg-info/PKG-INFO +package/argument_resolver.egg-info/SOURCES.txt +package/argument_resolver.egg-info/dependency_links.txt +package/argument_resolver.egg-info/entry_points.txt +package/argument_resolver.egg-info/requires.txt +package/argument_resolver.egg-info/top_level.txt +package/argument_resolver/analysis/__init__.py +package/argument_resolver/analysis/base.py +package/argument_resolver/analysis/env_resolve.py +package/argument_resolver/analysis/mango.py +package/argument_resolver/concretization/__init__.py +package/argument_resolver/concretization/concrete_runner.py +package/argument_resolver/external_function/__init__.py +package/argument_resolver/external_function/input_functions.py +package/argument_resolver/external_function/function_declarations/__init__.py +package/argument_resolver/external_function/function_declarations/custom.py +package/argument_resolver/external_function/function_declarations/nvram.py +package/argument_resolver/external_function/function_declarations/win32.py +package/argument_resolver/external_function/sink/__init__.py +package/argument_resolver/external_function/sink/sink_lists.py +package/argument_resolver/formatters/__init__.py +package/argument_resolver/formatters/closure_formatter.py +package/argument_resolver/formatters/log_formatter.py +package/argument_resolver/formatters/results_formatter.py +package/argument_resolver/handlers/__init__.py +package/argument_resolver/handlers/base.py +package/argument_resolver/handlers/local_handler.py +package/argument_resolver/handlers/network.py +package/argument_resolver/handlers/nvram.py +package/argument_resolver/handlers/static.py 
+package/argument_resolver/handlers/stdio.py +package/argument_resolver/handlers/stdlib.py +package/argument_resolver/handlers/string.py +package/argument_resolver/handlers/unistd.py +package/argument_resolver/handlers/url_param.py +package/argument_resolver/handlers/functions/__init__.py +package/argument_resolver/handlers/functions/constant_function.py +package/argument_resolver/utils/__init__.py +package/argument_resolver/utils/call_trace.py +package/argument_resolver/utils/call_trace_visitor.py +package/argument_resolver/utils/calling_convention.py +package/argument_resolver/utils/closure.py +package/argument_resolver/utils/format_prototype.py +package/argument_resolver/utils/graph_helper.py +package/argument_resolver/utils/rank.py +package/argument_resolver/utils/rda.py +package/argument_resolver/utils/stored_function.py +package/argument_resolver/utils/transitive_closure.py +package/argument_resolver/utils/utils.py +package/tests/test_basic_facts_for_handcrafted_binaries.py +package/tests/test_call_trace.py +package/tests/test_calling_convention.py +package/tests/test_sink.py +package/tests/test_transitive_closure.py +package/tests/test_utils.py +package/tests/test_handlers/handler_tester.py +package/tests/test_handlers/test_stdio.py +package/tests/test_handlers/test_stdlib.py +package/tests/test_handlers/test_string.py +package/tests/test_handlers/test_unistd.py \ No newline at end of file diff --git a/package/argument_resolver.egg-info/dependency_links.txt b/package/argument_resolver.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/package/argument_resolver.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/package/argument_resolver.egg-info/entry_points.txt b/package/argument_resolver.egg-info/entry_points.txt new file mode 100644 index 0000000..e688aae --- /dev/null +++ b/package/argument_resolver.egg-info/entry_points.txt @@ -0,0 +1,3 @@ +[console_scripts] +env_resolve = argument_resolver.analysis.env_resolve:main +mango = argument_resolver.analysis.mango:main diff --git a/package/argument_resolver.egg-info/not-zip-safe b/package/argument_resolver.egg-info/not-zip-safe new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/package/argument_resolver.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/package/argument_resolver.egg-info/requires.txt b/package/argument_resolver.egg-info/requires.txt new file mode 100644 index 0000000..3be6dc6 --- /dev/null +++ b/package/argument_resolver.egg-info/requires.txt @@ -0,0 +1,13 @@ +angr==9.2.94 +pydot==2.0.0 +networkx==3.2.1 +psutil==5.9.8 +ipdb==0.13.13 +rich==13.7.1 + +[dev] +ipdb +pytest +pytest-cov +mypy +flake8 diff --git a/package/argument_resolver.egg-info/top_level.txt b/package/argument_resolver.egg-info/top_level.txt new file mode 100644 index 0000000..c686f70 --- /dev/null +++ b/package/argument_resolver.egg-info/top_level.txt @@ -0,0 +1,3 @@ +argument_resolver +procedures +tests diff --git a/package/argument_resolver/__init__.py b/package/argument_resolver/__init__.py new file mode 100644 index 0000000..362d74a --- /dev/null +++ b/package/argument_resolver/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/package/argument_resolver/__main__.py b/package/argument_resolver/__main__.py new file mode 100644 index 0000000..b06eda9 --- /dev/null +++ b/package/argument_resolver/__main__.py @@ -0,0 +1,3 @@ +from argument_resolver.analysis.mango import main + +main() diff --git a/package/argument_resolver/analysis/__init__.py 
b/package/argument_resolver/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/package/argument_resolver/analysis/base.py b/package/argument_resolver/analysis/base.py new file mode 100644 index 0000000..e086599 --- /dev/null +++ b/package/argument_resolver/analysis/base.py @@ -0,0 +1,1169 @@ +import argparse +import json +import logging +import os +import signal +import subprocess +import shutil +import sys +import time +import ipdb +import inspect +from multiprocessing import cpu_count +from pathlib import Path +from collections import Counter +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, NamedTuple + +import psutil + +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) + +import angr +from angr.analyses.analysis import AnalysisFactory +from angr.analyses.reaching_definitions.call_trace import CallTrace +from angr.code_location import CodeLocation +from angr.analyses.reaching_definitions.dep_graph import DepGraph +from angr.analyses.reaching_definitions.reaching_definitions import ( + ReachingDefinitionsAnalysis, +) +from angr.knowledge_plugins.functions.function import Function +from angr.knowledge_plugins.key_definitions.atoms import Atom +from angr.knowledge_plugins.key_definitions.constants import OP_AFTER, OP_BEFORE +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions + +from argument_resolver.external_function.function_declarations import CUSTOM_DECLS +from argument_resolver.external_function.sink import Sink, VULN_TYPES +from argument_resolver.formatters.closure_formatter import ClosureFormatter +from argument_resolver.formatters.log_formatter import make_logger +from argument_resolver.handlers import ( + NVRAMHandlers, + NetworkHandlers, + StdioHandlers, + StdlibHandlers, + StringHandlers, + UnistdHandlers, + URLParamHandlers, + handler_factory, +) +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.utils.call_trace import traces_to_sink +from argument_resolver.utils.call_trace_visitor import CallTraceSubject +from argument_resolver.utils.calling_convention import ( + CallingConventionResolver, + LIBRARY_DECLS, +) +from argument_resolver.utils.rda import CustomRDA +from argument_resolver.utils.stored_function import StoredFunction +from argument_resolver.utils.utils import Utils +from argument_resolver.formatters.log_formatter import CustomTextColumn + + +class ScriptBase: + def __init__( + self, + bin_path: str, + sink: str = None, + source_function="main", + min_depth=1, + max_depth=1, + arg_pos=0, + ld_paths=None, + excluded_functions=None, + result_path: str = None, + disable_progress_bar=True, + sink_category=None, + env_dict: str = None, + workers: int = 1, + rda_timeout: int = 0, + full_exec: bool = False, + forward_trace: bool = False, + log_level=logging.INFO, + enable_breakpoint=False, + keyword_dict: str = None, + ): + signal.signal(signal.SIGALRM, self.timeout_handler) + + self.bin_path = bin_path + self.min_depth = min_depth if not forward_trace else max_depth + self.max_depth = max_depth + self.excluded_functions = {} + self.source = source_function + self.assumed_execution = not full_exec + self.forward_trace = forward_trace + self.trace_dict = {} + self.sinks_found = {} + self.timeout_proc = None + self.rda_task = None + self.trace_task = None + self.enable_breakpoint = enable_breakpoint + self.category = sink_category + if keyword_dict is None: + self.keyword_dict = dict() + else: + with 
open(keyword_dict, "r") as f: + self.keyword_dict = json.load(f) + + self.cfg_time = 0 + self.vra_time = 0 + self.analysis_start_time = 0 + self.analysis_time = 0 + self.rda_timeout = rda_timeout + self.vra_start_time = 0 + self.time_data = {} + self.sink_time = 0 + + self.Handler = handler_factory( + [ + StdioHandlers, + StdlibHandlers, + StringHandlers, + UnistdHandlers, + NVRAMHandlers, + NetworkHandlers, + URLParamHandlers, + ] + ) + + self.result_path = Path(result_path) if result_path is not None else None + self.alarm_triggered = False + + if self.result_path is not None: + self.log_path = self.result_path / f"{self.category}_mango.out" + self.log = make_logger(log_level=log_level, should_debug=self.result_path) + logging.getLogger("angr").setLevel(logging.CRITICAL) + + self.progress = Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + CustomTextColumn("{task.completed}/{task.total}"), + TimeElapsedColumn(), + TimeRemainingColumn(), + ) + + self.set_breakpoint_handler() + self.project = self.init_analysis( + workers=workers, + show_progress_bar=not disable_progress_bar, + ld_paths=ld_paths, + ) + self.env_dict = self.load_env_dict(env_dict) + if self.keyword_dict: + self.rename_form_param_parser() + self.sinks = self.load_sinks( + custom_sink=sink, arg_pos=arg_pos, category=sink_category + ) + self.excluded_functions = self.load_excluded_functions( + excluded_functions=excluded_functions + ) + + self._calling_convention_resolver = CallingConventionResolver( + self.project, + self.project.arch, + self.project.kb.functions, + ) + + self.overwrite_func_prototypes() + + self.result_formatter = ClosureFormatter( + self.project, self._calling_convention_resolver + ) + + self.RDA = AnalysisFactory(self.project, CustomRDA) + + def set_breakpoint_handler(self): + sys.breakpointhook = self.breakpoint_handler + + def breakpoint_handler(self, *args, **kwargs): + if self.progress is not None: + self.progress.stop() + self.cleanup_timeout_proc() + frame = inspect.currentframe().f_back + ipdb.set_trace(frame) + + def init_analysis(self, workers: int, show_progress_bar: bool, ld_paths: list): + if ld_paths is None or ld_paths == "None": + project = angr.Project(self.bin_path, auto_load_libs=False) + else: + project = angr.Project( + self.bin_path, + auto_load_libs=True, + load_options={"ld_path": ld_paths, "skip_libs": ["libc.so.0"]}, + ) + + start = time.time() + project.analyses.CFGFast( + normalize=True, data_references=True, show_progressbar=show_progress_bar + ) + self.cfg_time = time.time() - start + + # Run CC analysis + start = time.time() + # Allow for 20 min of VRA + self.vra_start_time = start + vra_task = self.progress.add_task("Running VRA", total=None) + self.progress.start_task(vra_task) + self.progress.start() + self.kill_parent_after_timeout(timeout=20 * 60) + try: + project.analyses.CompleteCallingConventions( + recover_variables=True, analyze_callsites=True, workers=workers + ) + except TimeoutError: + # Failed to finish vra in time + self.log.critical("VRA TIMED OUT") + self.cleanup_timeout_proc() + exit(-1) + self.vra_time = time.time() - start + self.vra_start_time = 0 + self.cleanup_timeout_proc() + self.progress.update(vra_task, completed=1, total=1, visible=False) + self.progress.remove_task(vra_task) + Utils.arch = project.arch + + return project + + def rename_form_param_parser(self): + if not self.keyword_dict: + return + + cfg = self.project.kb.cfgs.get_most_accurate() + strings = [x for x in cfg.memory_data.items() if x[1].sort == 
"string"] + strings = [ + x + for x in strings + if x[1].content.decode("latin-1") in self.keyword_dict + or x[1].content.decode("latin-1").replace("=", "") in self.keyword_dict + ] + all_addrs = Counter() + for addr, string in strings: + for xref in self.project.kb.xrefs.get_xrefs_by_dst(addr): + node = cfg.get_any_node(xref.block_addr) + if node is None: + continue + + call_addrs = { + x[1].function_address + for x in cfg.graph.out_edges(node, data=True) + if x[-1]["jumpkind"] == "Ijk_Call" + } + all_addrs.update(call_addrs) + all_funcs = [self.project.kb.functions[x] for x in all_addrs] + all_funcs = [ + x + for x in all_funcs + if x.name not in LIBRARY_DECLS and "nvram" not in x.name.lower() + ] + blacklist = ["GetIniFileValue"] + for name in blacklist: + if name in all_funcs: + all_funcs.remove(name) + + if len(all_funcs) == 0: + return + param_func = max(all_funcs, key=lambda x: all_addrs[x.addr]) + self.project.kb.functions[param_func.name].name = "custom_param_parser" + + def overwrite_func_prototypes(self): + for func_name, prototype in CUSTOM_DECLS.items(): + if func_name in self.project.kb.functions: + self.project.kb.functions[func_name].prototype = prototype + + def load_excluded_functions( + self, excluded_functions: List[str] = None + ) -> Dict[Function, Set[Tuple]]: + + default_excluded_functions = self.find_default_excluded_functions() + excluded_dict = {} + valid_sinks = [] + for f_sink in self.sinks: + if f_sink.name not in self.project.kb.functions: + continue + valid_sinks.append(f_sink.name) + sink_func = self.project.kb.functions[f_sink.name] + excluded_dict[sink_func] = default_excluded_functions + + if excluded_functions is None: + continue + + for func in excluded_functions: + if func in self.project.kb.functions: + excluded_dict[sink_func].add( + (self.project.kb.functions[func], None) + ) + if len(valid_sinks) > 0: + self.log.critical("TARGETING SINKS: %s", ", ".join(valid_sinks)) + else: + self.log.critical("NO SINKS FOUND") + return excluded_dict + + @staticmethod + def load_sinks(custom_sink=None, arg_pos=None, category=None) -> List[Sink]: + sinks = [] + if custom_sink and arg_pos: + sinks = [Sink(custom_sink, [arg_pos + 1])] + else: + category = category.lower() + if category in VULN_TYPES: + sinks = VULN_TYPES[category] + if len(sinks) == 0: + sinks = VULN_TYPES["cmdi"] + return sinks + + @staticmethod + def load_env_dict(env_dict_path: str) -> Optional[Dict]: + if env_dict_path is not None: + env_dict = json.loads(Path(env_dict_path).read_text()) + else: + env_dict = {} + + return env_dict + + def update_rda_task(self, parent_name, parent_addr, child_name, child_addr): + self.progress.update( + self.rda_task, + description=f"Analyzing {parent_name}@{parent_addr}->{child_name}@{child_addr}", + advance=1, + ) + + def update_trace_task(self, parent_name, parent_addr, child_name, child_addr): + self.progress.update( + self.trace_task, + description=f"Tainting {parent_name}@{parent_addr}->{child_name}@{child_addr}", + advance=1, + ) + + def analyze(self): + self.analysis_start_time = time.time() + + try: + for sink, sink_function in self.get_sink_callsites(self.sinks): + + sink_start = time.time() + sink_count = 0 + atoms = Utils.get_atoms_from_function(sink_function, self.project.arch) + atoms = [ + atom + for idx, atom in enumerate(atoms) + if idx + 1 in sink.vulnerable_parameters + ] + + cfg = self.project.kb.cfgs.get_most_accurate() + self.sinks_found[sink_function.name] = len( + { + x.addr + for x in cfg.get_predecessors( + 
cfg.get_any_node(sink_function.addr) + ) + } + ) + traces = self.gen_traces_to_sink( + sink_function, self.max_depth + 1, atoms + ) + analysis_task = self.progress.add_task( + f"Tracing path to sink {sink_function.name}", total=None + ) + taint_start = time.time() + total_traces = 0 + for trace, white_list, trace_idx, total_traces in traces: + taint_time = time.time() - taint_start + + sink_count += 1 + observation_points = self.get_observation_points_from_trace(trace) + + rda_start = time.time() + rda, handler = self.run_analysis_on_trace( + trace, white_list, sink_function, atoms, observation_points + ) + rda_time = time.time() - rda_start + self.sink_time = time.time() - sink_start + callsite_tup = tuple( + [x.caller_func_addr for x in trace.callsites] + + [sink_function.addr] + ) + self.time_data[callsite_tup] = { + "taint_time": taint_time, + "rda_time": rda_time, + } + process_task = self.progress.add_task( + "Analyzing Results from RDA", total=None + ) + self.process_rda(rda, handler) + self.progress.update(process_task, visible=False) + taint_start = time.time() + self.progress.update( + analysis_task, total=total_traces, completed=trace_idx + ) + self.log.info( + f"Analyzed %s/%s for sink %s", + trace_idx, + total_traces, + sink_function.name, + ) + + self.progress.update( + analysis_task, total=total_traces, completed=total_traces + ) + self.analysis_time = time.time() - self.analysis_start_time + self.progress.stop() + self.save_results() + self.log.info("Finished Running Analysis") + temp_path = Path("/tmp/mango.out") + if temp_path.exists() and self.result_path: + self.result_path.mkdir(parents=True, exist_ok=True) + shutil.move(temp_path, self.log_path) + except Exception: + self.cleanup_timeout_proc() + self.progress.stop() + self.log.exception("OH NO MY MANGOES!!!") + exc_type, exc_value, exc_traceback = sys.exc_info() + if self.enable_breakpoint: + ipdb.post_mortem(exc_traceback) + exit(-1) + + def save_results(self): + """ + Save results of analysis + :return: + """ + pass + + def find_default_excluded_functions(self) -> Set[Tuple[int, Any]]: + # If `main` is present, don't let calltraces go beyond it: it is as good as if the entrypoint was reached. + if self.source is None: + return set() + + functions_before_source = set() + main_func = self.project.kb.functions.function(name="main") + if main_func: + functions_before_source = { + (x, None) + for x in self.project.kb.callgraph.predecessors(main_func.addr) + } + + # Due to the nature of its disassembly and CFG reconstitution, + # `angr` marks certain alignment blocks from the binary as functions + # and keeps them in the callgraph (see https://github.com/angr/angr/issues/2366 for an example); + # We don't want that to be part of our calltraces. + alignment_functions = { + (f.addr, None) for f in self.project.kb.functions.values() if f.alignment + } + + return functions_before_source | alignment_functions + + def get_sink_callsites(self, sinks: List[Sink]) -> List[Tuple[Sink, Function]]: + """ + :return: + A list of tuples, for each sink present in the binary, containing: + the representation of the itself, the representation. 
+ """ + + final_sinks = [] + for sink in sinks: + function = self.project.kb.functions.function(name=sink.name) + if function is None: + continue + + if function.calling_convention is None: + function.calling_convention = self._calling_convention_resolver.get_cc( + function.name + ) + if hasattr(function.calling_convention, "sim_func"): + function.prototype = function.calling_convention.sim_func + + if function.prototype is None: + function.prototype = self._calling_convention_resolver.get_prototype( + function.name + ) + + final_sinks.append((sink, function)) + + return final_sinks + + def gen_traces_to_sink( + self, sink: Function, max_depth: int, atoms: List[Atom] + ) -> Generator[Tuple[CallTrace, List[StoredFunction]], None, None]: + """ + :param sink: Function to build trace to + :param atoms: Target atoms in sink + :param max_depth: The maximum length of the path between the sink and the uncovered start point. + + :return: + A tuple containing: + - A boolean telling if every trace is as high as possible (from the sink to the entrypoint of the binary); + - A generator to run a for every start found under the given max_depth. + """ + + traces = [] + for depth in range(1, max_depth): + sub_traces: Set[CallTrace] = traces_to_sink( + sink, + self.project.kb.functions.callgraph, + depth, + self.excluded_functions[sink], + ) + for s_t in sub_traces: + callsites = { + (y.caller_func_addr, y.callee_func_addr) for y in s_t.callsites + } + for trace in traces.copy(): + if all( + (x.caller_func_addr, x.callee_func_addr) in callsites + for x in trace.callsites + ): + traces.remove(trace) + traces.append(s_t) + + if not self.forward_trace: + analyzed_traces = set() + total_traces = len(traces) + + for idx, trace in enumerate( + sorted(traces.copy(), key=lambda x: len(x.callsites)) + ): + new_trace, white_list = self.check_trace_rda(trace, sink, atoms) + + traces.remove(trace) + if not new_trace.callsites: + continue + # Don't bother analyzing anything that doesn't have a known_set location + # if not any(x.code_loc.ins_addr in valid_set_locations for x in white_list): + # continue + callsites = [ + {x.caller_func_addr for x in trace.callsites} for trace in traces + ] + if ( + new_trace.callsites + and {x.caller_func_addr for x in new_trace.callsites} + not in callsites + and not any( + self._is_trace_subset(new_trace, x) in {1, 0} + for x in analyzed_traces + ) + ): + analyzed_traces.add(new_trace) + + yield new_trace, white_list, idx, total_traces + + else: + for t_1 in traces.copy(): + for t_2 in traces.copy(): + if self._is_trace_subset(t_1, t_2) == 1: + traces.remove(t_2) + for idx, trace in enumerate(sorted(traces, key=lambda x: len(x.callsites))): + yield trace, [], idx, len(traces) + + @staticmethod + def _is_trace_subset(trace_1, trace_2): + callsites_1 = { + (x.caller_func_addr, x.callee_func_addr) for x in trace_1.callsites + } + callsites_2 = { + (x.caller_func_addr, x.callee_func_addr) for x in trace_2.callsites + } + + if callsites_1 < callsites_2: + return 1 + + if callsites_1 > callsites_2: + return -1 + + if callsites_1 == callsites_2: + return 0 + + return None + + # create a function to spin up a subprocess that kills the parent after a certain amount of time + def kill_parent_after_timeout(self, timeout=None): + if self.rda_timeout == 0 or timeout == 0: + return + signal.alarm(self.rda_timeout) + p = subprocess.Popen( + f"sleep {timeout if timeout is not None else self.rda_timeout}; while true; do sleep 1 || kill -9 {os.getpid()}; kill -14 {os.getpid()} || break; done", + 
shell=True, + ) + self.timeout_proc = p + + def cleanup_timeout_proc(self): + signal.alarm(0) + self.alarm_triggered = False + if self.timeout_proc is not None: + parent = psutil.Process(self.timeout_proc.pid) + children = list(parent.children(recursive=True)) + self.timeout_proc.kill() + self.timeout_proc.wait() + self.timeout_proc = None + + # kill all children + for child in children: + try: + child.send_signal(9) + except psutil.NoSuchProcess: + pass + + def get_observation_points_from_trace( + self, trace: CallTrace + ) -> Set[Tuple[str, int, int]]: + def _call_statement_in_node(node) -> Optional[int]: + """ + Assuming the node is the predecessor of a function start. + Returns the statement address of the `call` instruction. + """ + if node is None or node.block is None: + return None + + if ( + self.project.arch.branch_delay_slot + and node.block.disassembly.insns[-1].mnemonic == "nop" + ): + return node.block.instruction_addrs[-2] + return node.block.instruction_addrs[-1] + + observation_points = set() + cfg = self.project.kb.cfgs.get_most_accurate() + # Get final call to target function + for pred in cfg.get_any_node(trace.callsites[0].callee_func_addr).predecessors: + if pred.function_address != trace.callsites[0].caller_func_addr: + continue + + callsite = _call_statement_in_node(pred) + if callsite is None: + continue + + observation_points.add(("insn", callsite, OP_AFTER)) + observation_points.add(("node", callsite, OP_AFTER)) + observation_points.add(("node", callsite, OP_BEFORE)) + + return observation_points + + def check_trace_rda(self, trace: CallTrace, sink: Function, sink_atoms: List[Atom]): + white_list = set() + final_trace = CallTrace(trace.target) + + target_atoms = sink_atoms + WListFunc = NamedTuple("WListFunc", [("code_loc", CodeLocation)]) + + for call_idx, callsite in enumerate(trace.callsites): + callsite_tuple = tuple( + [x.caller_func_addr for x in reversed(trace.callsites[: call_idx + 1])] + + [trace.callsites[0].callee_func_addr] + ) + if sink in self.trace_dict and callsite_tuple in self.trace_dict[sink]: + final_trace.callsites.append(callsite) + data = self.trace_dict[sink][callsite_tuple] + white_list |= set(data["white_list"]) + if data["final"]: + break + else: + continue + + single_trace = CallTrace(callsite.callee_func_addr) + single_trace.callsites = [callsite] + function_address = single_trace.current_function_address() + function = self.project.kb.functions[function_address] + subject = CallTraceSubject(single_trace, function) + handler = self.Handler( + self.project, + sink, + target_atoms, + env_dict=self.env_dict, + taint_trace=True, + progress_callback=self.update_trace_task, + ) + + self.log.info( + "Running RDA Taint on function %s@%#x...", + function.name, + function_address, + ) + self.log.debug( + "Trace: %s", + "".join( + [ + f"{self.project.kb.functions[x.caller_func_addr].name}->" + for x in reversed(single_trace.callsites) + ] + + [self.project.kb.functions[single_trace.target].name] + ), + ) + + self.kill_parent_after_timeout() + timed_out = False + self.trace_task = self.progress.add_task(f"Tainting ...", total=None) + try: + self.RDA( + subject=subject, + function_handler=handler, + observation_points=set(), + init_context=(callsite.caller_func_addr,), + start_time=time.time(), + kb=self.project.kb, + dep_graph=DepGraph(), + rda_timeout=self.rda_timeout, + max_iterations=2, + ) + except TimeoutError: + timed_out = True + + self.cleanup_timeout_proc() + self.progress.update(self.trace_task, visible=False) + 
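+            # The transient "Tainting ..." task is dropped below; a timeout in
+            # this single-callsite RDA abandons the rest of the call chain.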
self.progress.remove_task(self.trace_task) + + if timed_out: + break + + if sink not in self.trace_dict: + self.trace_dict[sink] = {} + + sinks = [ + x for x in handler.white_list if x.function.addr == single_trace.target + ] + + self.trace_dict[sink][callsite_tuple] = { + "white_list": [ + WListFunc(code_loc=x.code_loc) for x in handler.white_list + ], + "final": False, + "valid_sinks": set(), + "constant": set(), + "input": None, + } + concrete = True + + for target_func in sinks: + for atom in target_func.atoms: + # TODO: This is a lazy fix, add better support for execve + if atom not in target_atoms and not target_func.name.startswith( + "exec" + ): + continue + if target_func.constant_data[atom] is None: + continue + + for d in target_func.constant_data[atom]: + if d is None: + concrete = False + break + elif ( + d.concrete_value == 0x0 + and atom in target_func.arg_vals + and str(d) not in str(target_func.arg_vals[atom]) + ): + concrete = False + break + if not concrete: + break + + has_only_constant_func = not handler.white_list + has_only_constant_func |= ( + len(handler.white_list) == 1 + and handler.white_list[0].function.addr == single_trace.target + ) + + # If the sink arguments are constant + # we can stop the trace and use the constant values as input for the rda + if ( + concrete + and has_only_constant_func + and not self.__class__.__name__ == "EnvAnalysis" + ): + if len(final_trace.callsites) > 0: + sink_val_list = [ + x + for x in handler.white_list + if x.function.addr == single_trace.target + ] + if len(sink_val_list) > 0: + sink_val = sink_val_list[0] + self.trace_dict[sink][callsite_tuple]["input"] = sink_val.state + self.trace_dict[sink][callsite_tuple]["final"] = True + final_trace.callsites.append(callsite) + break + + next_atoms = set() + has_arg_reference = False + has_internal_dependencies = False + for func in handler.white_list: + if func.function.addr == sink.addr: + self.trace_dict[sink][callsite_tuple]["valid_sinks"].add( + func.code_loc.ins_addr + ) + + if func.function.addr == single_trace.target: + if any( + defn not in func.definitions + for atom in [a for a in target_atoms if a in func.closures] + for defn in func.closures[atom] + ): + has_internal_dependencies = True + if func != handler.call_trace[0]: + valid_defns = [ + defn + for defn in func.definitions + if defn in handler.call_trace[0].definitions + ] + if valid_defns: + has_arg_reference = True + next_atoms |= set(d.atom for d in valid_defns) + continue + + valid_closures = [ + closure + for closure in func.closures.values() + if closure.intersection(handler.call_trace[0].definitions) + ] + if valid_closures: + has_arg_reference = True + next_atoms |= set( + d.atom + for closure in valid_closures + for d in closure + if d in handler.call_trace[0].definitions + ) + + target_atoms = list(next_atoms) + + if ( + single_trace.target == sink.addr + and not (has_internal_dependencies or has_arg_reference) + and not self.__class__.__name__ == "EnvAnalysis" + ): + # No dependency on parent func if reached + self.trace_dict[sink][callsite_tuple]["constant"] = True + break + + if has_internal_dependencies or has_arg_reference: + final_trace.callsites.append(callsite) + white_list |= set( + WListFunc(code_loc=x.code_loc) for x in handler.white_list + ) + + elif self.__class__.__name__ == "EnvAnalysis": + final_trace.callsites.append(callsite) + white_list |= set( + WListFunc(code_loc=x.code_loc) for x in handler.white_list + ) + self.trace_dict[sink][callsite_tuple]["final"] = True + break + + if not 
has_arg_reference or not handler.current_parent.atoms: + self.trace_dict[sink][callsite_tuple]["final"] = True + break + + return final_trace, white_list + + def timeout_handler(self, signum, frame): + time_diff = time.time() - self.vra_start_time + if time.time() - self.vra_start_time < 20 * 60 or self.alarm_triggered: + self.cleanup_timeout_proc() + self.kill_parent_after_timeout((20 * 60) - time_diff) + return + + self.alarm_triggered = True + raise TimeoutError("RDA Timeout") + + def run_analysis_on_trace( + self, + trace: CallTrace, + white_list: List[StoredFunction], + sink: Function, + sink_atoms: List[Atom], + observation_points: Set[Tuple[str, int, int]], + ): + """ + Generator to get RDA analyses for each start point of the CFG at a given depth. + :param traces: The set of CallTraces leading to the sink + :param sink: The sink + :param sink_atoms: The atoms and their respective types representing the arguments flowing into the subject (sink). + :param observation_points: Livedef states to preserve at address. + :param timeout: Seconds until RDA is cancelled + """ + + handler = self.Handler( + self.project, + sink, + sink_atoms, + env_dict=self.env_dict, + assumed_execution=self.assumed_execution, + forward_trace=self.forward_trace, + progress_callback=self.update_rda_task, + ) + if self.assumed_execution: + handler.white_list = white_list + + trace_tup = tuple( + [x.caller_func_addr for x in reversed(trace.callsites)] + + [trace.callsites[0].callee_func_addr] + ) + + init_state = None + if sink in self.trace_dict and trace_tup in self.trace_dict[sink]: + init_state = self.trace_dict[sink][trace_tup]["input"] + if init_state is not None: + trace.callsites.pop() + + function_address = trace.current_function_address() + function = self.project.kb.functions[function_address] + subject = CallTraceSubject(trace, function) + + self.log.info( + "Running RDA on function %s@%#x...", function.name, function_address + ) + self.log.debug( + "Trace: %s", + "".join( + [ + f"{self.project.kb.functions[x.caller_func_addr].name}->" + for x in reversed(trace.callsites) + ] + + [sink.name] + ), + ) + + all_callsites = set(Utils.get_all_callsites(self.project)) + all_callsites.update(observation_points) + + self.rda_task = self.progress.add_task(f"Analyzing ...", total=None) + + self.kill_parent_after_timeout() + try: + rda = self.RDA( + subject=subject, + observation_points=all_callsites, + function_handler=handler, + kb=self.project.kb, + dep_graph=DepGraph(), + rda_timeout=self.rda_timeout, + start_time=time.time(), + init_state=init_state, + max_iterations=2, + ) + except TimeoutError: + rda = None + self.log.critical( + "TIMEOUT FOR subject: %s, sink: %s", + subject.content.callsites, + sink.name, + ) + + self.cleanup_timeout_proc() + self.progress.update(self.rda_task, visible=False) + self.progress.remove_task(self.rda_task) + + handler.call_trace.clear() + + return rda, handler + + def process_rda(self, dep: CustomRDA, handler: HandlerBase): + + fully_resolved = True + + self.log.debug("Starting Post Analysis") + resolved = self.post_analysis(dep, handler) + fully_resolved &= resolved + + if resolved: + sink_function = self.project.kb.functions[handler._sink_function_addr] + self.exclude_future_traces(dep, sink_function) + + def post_analysis( + self, dep: ReachingDefinitionsAnalysis, handler: HandlerBase + ) -> bool: + """ + :param dep: Completed RDA + :param handler: Handler object + :return: Whether results have been fully resolved + """ + return False + + def contains_external(self, rda: 
ReachingDefinitionsAnalysis, unresolved_closures): + main_func = self.project.kb.functions.function(name="main") + has_main = main_func is not None and rda.subject.content.includes_function( + main_func.addr + ) + if has_main: + return True + + for closures in unresolved_closures.values(): + for closure in closures: + external_defs = closure.handler.analyzed_list[0].definitions + for sink_atom in closure.handler._sink_atoms: + if any( + defn in closure.sink_trace.closures[sink_atom] + for defn in external_defs + ): + return True + + def vulnerable_sinks_from_call_trace( + self, handler: HandlerBase + ) -> Dict[StoredFunction, LiveDefinitions]: + vulnerable_sinks = {} + sink_function = self.project.kb.functions[handler._sink_function_addr] + for ct in handler.analyzed_list: + if ct.function.addr != sink_function.addr: + continue + self.log.debug("Checking %s for closure", ct) + if not self.forward_trace: + final_callsite = ct.subject.content.callsites[0] + call_tup = ( + final_callsite.caller_func_addr, + final_callsite.callee_func_addr, + ) + if sink_function in self.trace_dict: + if ( + call_tup in self.trace_dict[sink_function] + and ct.code_loc.ins_addr + not in self.trace_dict[sink_function][call_tup]["valid_sinks"] + ): + continue + vulnerable_sinks[ct] = ct.definitions + return vulnerable_sinks + + def exclude_future_traces( + self, rda: ReachingDefinitionsAnalysis, sink_function: Function + ): + current_function_address = rda.subject.content.current_function_address() + self.log.info( + "Exclude function %#x from future slices since the data dependencies are fully resolved.", + current_function_address, + ) + subject_callsites = rda.subject.content.callsites + self.excluded_functions[sink_function].add( + ( + current_function_address, + frozenset( + (x.caller_func_addr, x.callee_func_addr) for x in subject_callsites + ), + ) + ) + + +def default_parser(): + parser = argparse.ArgumentParser() + + path_group = parser.add_argument_group( + "Path Args", "Deciding source and result destination" + ) + + run_group = parser.add_argument_group( + "Running", "Options that modify how mango runs" + ) + + output_group = parser.add_argument_group( + "Output", "Options to increase or modify output" + ) + + path_group.add_argument(dest="bin_path", help="Binary to analyze.") + path_group.add_argument( + "--results", + dest="result_path", + default=None, + help="Where to store the results of the analysis.", + ) + + run_group.add_argument( + "--min-depth", + default=1, + type=int, + help="The minimum callstack height the analysis can reach from each sink to consider.", + ) + + run_group.add_argument( + "--max-depth", + default=1, + type=int, + help="The maximum callstack height the analysis can reach from each sink.", + ) + + run_group.add_argument( + "--source", + dest="source_function", + default="main", + type=str, + help="Use the specified function source", + ) + + output_group.add_argument( + "--disable-progress", + dest="disable_progress_bar", + action="store_true", + default=False, + help="Disable CFG progress bar", + ) + + run_group.add_argument( + "--exclude", + dest="excluded_functions", + type=str, + nargs="+", + help="List of functions to exclude from analysis", + ) + + run_group.add_argument( + "--workers", + dest="workers", + type=int, + default=1, + help="Set amount of workers to run during VRA", + ) + + run_group.add_argument( + "--ld-paths", + dest="ld_paths", + nargs="+", + help="Run analysis with ld_paths", + ) + + run_group.add_argument( + "--rda-timeout", + dest="rda_timeout", + 
        type=int,
+        default=5 * 60,
+        help="Seconds before a single RDA run is timed out (0 disables the timeout)",
+    )
+
+    run_group.add_argument(
+        "--full-execution",
+        dest="full_exec",
+        action="store_true",
+        default=False,
+        help="Turns off assumed execution",
+    )
+
+    run_group.add_argument(
+        "--forward-trace",
+        dest="forward_trace",
+        action="store_true",
+        default=False,
+        help="Run the analysis forward from source to sink instead of backward from the sink",
+    )
+
+    output_group.add_argument(
+        "--loglevel",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        default="INFO",
+        dest="log_level",
+        help="Set the logging level",
+    )
+
+    output_group.add_argument(
+        "--enable-breakpoint",
+        dest="enable_breakpoint",
+        action="store_true",
+        default=False,
+        help="Enable breakpoint on error",
+    )
+
+    run_group.add_argument(
+        "--keyword-dict",
+        dest="keyword_dict",
+        default=None,
+        help="Path to a keyword dictionary used to tag resolved values with frontend keywords",
+    )
+
+    return parser, [path_group, run_group, output_group]
diff --git a/package/argument_resolver/analysis/env_resolve.py b/package/argument_resolver/analysis/env_resolve.py
new file mode 100644
index 0000000..2f25f95
--- /dev/null
+++ b/package/argument_resolver/analysis/env_resolve.py
@@ -0,0 +1,329 @@
+import hashlib
+import logging
+import json
+import subprocess
+
+from pathlib import Path
+from typing import Dict, List, Set
+
+from angr.analyses.reaching_definitions import (
+    LiveDefinitions,
+    ReachingDefinitionsAnalysis,
+)
+
+from argument_resolver.formatters.log_formatter import make_logger
+from argument_resolver.handlers.base import HandlerBase
+from argument_resolver.utils.stored_function import StoredFunction
+
+from argument_resolver.external_function.sink import ENV_SINKS, Sink
+from argument_resolver.utils.transitive_closure import get_constant_data
+
+from argument_resolver.utils.utils import Utils
+
+from argument_resolver.analysis.base import default_parser, ScriptBase
+from argument_resolver.utils.closure import Closure
+import re
+
+
+_l = make_logger()
+_l.setLevel(logging.DEBUG)
+
+
+# noinspection PyInterpreter
+class EnvAnalysis(ScriptBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.resolved_values = {}
+
+    @staticmethod
+    def load_sinks(custom_sink=None, arg_pos=None, category=None) -> List[Sink]:
+        return ENV_SINKS
+
+    def load_excluded_functions(self, excluded_functions=None):
+        excluded_functions = {}
+        default_excluded_functions = self.find_default_excluded_functions()
+        for f_sink in self.sinks:
+            if f_sink.name not in self.project.kb.functions:
+                continue
+            sink_func = self.project.kb.functions[f_sink.name]
+            excluded_functions[sink_func] = default_excluded_functions
+        return excluded_functions
+
+    @staticmethod
+    def resolve_sinks(
+        vulnerable_sinks: Set[StoredFunction],
+    ) -> Dict[StoredFunction, Dict]:
+        # Get only the closures of the vulnerable atoms containing non-constant data.
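+        # Every reaching definition behind each argument is concretized below;
+        # a single definition that stays symbolic (TOP) marks the whole sink
+        # call as unresolved via "is_resolved" = False.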
+ resolved_dict = {} + for trace in vulnerable_sinks: + all_resolved = True + resolved_dict[trace] = {} + for arg, mv in trace.arg_vals.items(): + resolved_dict[trace][arg] = set() + for defn in LiveDefinitions.extract_defs_from_mv(mv): + data = get_constant_data(defn, mv, trace.state) + resolved_dict[trace][arg].update(set(data) if data else {None}) + if data is None or any( + x is None or trace.state.is_top(x) for x in data + ): + all_resolved = False + break + if not all_resolved: + break + resolved_dict[trace]["is_resolved"] = all_resolved + + return resolved_dict + + @staticmethod + def strip_non_alphanumeric_from_ends(s): + # Strip non-alphanumeric characters from the beginning and end of the string + return re.sub(r"^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$", "", s) + + def post_analysis(self, dep: ReachingDefinitionsAnalysis, handler: HandlerBase): + depth = len(dep.subject.content.callsites) if dep is not None else 1 + + if depth not in self.resolved_values: + self.resolved_values[depth] = {} + + # Relevant results are the `LiveDefinitions` captured at the `observation_points`. + potential_sinks = self.vulnerable_sinks_from_call_trace(handler) + sink_function = self.project.kb.functions[handler._sink_function_addr] + if sink_function in self.trace_dict: + for call_tup, trace_info in self.trace_dict[sink_function].items(): + if trace_info["constant"]: + potential_sinks.update( + {t_i: t_i.definitions for t_i in trace_info["constant"]} + ) + self.trace_dict[sink_function][call_tup]["constant"] = set() + _l.info("Found %d potential sinks.", len(potential_sinks)) + + resolved_sinks: Dict[StoredFunction, Dict] = self.resolve_sinks( + set(potential_sinks.keys()) + ) + + self.save_results(resolved_sinks, handler=handler) + # Change contains_external_definition + if dep is not None: + transitive_closures = { + trace: {Closure(trace, dep, handler)} + for trace, val_dict in resolved_sinks.items() + if not val_dict["is_resolved"] + } + if ( + self.contains_external(dep, transitive_closures) + or len(transitive_closures) == 0 + ): + # fully resolved - we should exclude this function for future exploration + return True + + return False + + def save_results(self, resolved_sinks=None, handler=None): + self.save_resolved_values(resolved_sinks, handler) + + def save_resolved_values(self, resolved_sinks, handler): + if resolved_sinks is None: + return + + self.result_path.mkdir(parents=True, exist_ok=True) + env_json = self.result_path / "env.json" + if env_json.exists(): + prev_dict = json.loads(env_json.read_text()) + out_dict = prev_dict["results"] + else: + out_dict = {} + prev_dict = {} + + has_sinks = len(self.get_sink_callsites(self.sinks)) != 0 + if not has_sinks: + _l.critical("NO SINKS FOUND") + + curr_sink = None + frontend_strs = {str(k): v for k, v in handler.keyword_access.items()} + for sink, values in resolved_sinks.items(): + curr_sink = sink.name + if sink.function.name not in out_dict: + out_dict[sink.function.name] = {} + keys = [ + Utils.bytes_from_int(val).decode("latin-1") + if val is not None + else "TOP" + for atom in sink.args_atoms[0] + for val in (sink.constant_data[atom] or [None]) + ] + for key in [k for k in keys if len(k) > 1]: + if key not in out_dict[sink.function.name]: + out_dict[sink.function.name][key] = {"keywords": []} + + if len(sink.args_atoms) > 1: + sink_atoms = sink.args_atoms[1:] + else: + # This branch includes getenv first arg + sink_atoms = sink.args_atoms + + for idx, args in enumerate(sink_atoms): + idx = idx + 1 + if idx not in 
out_dict[sink.function.name][key]: + out_dict[sink.function.name][key][idx] = {} + for arg in args: + if ( + arg not in sink.constant_data + or sink.constant_data[arg] is None + ): + continue + for val in sink.constant_data[arg]: + if val is not None: + out_val = Utils.bytes_from_int(val).decode("latin-1") + out_val = ( + out_val[:-1] + if out_val.endswith("\x00") + else out_val + ) + else: + out_val = "TOP" + if self.keyword_dict: + strings = Utils.get_strings_from_pointers( + sink.arg_vals[arg], sink.state, sink.code_loc + ) + for string in Utils.get_values_from_multivalues( + strings + ): + string_str = str(string) + for k, v in frontend_strs.items(): + if k in string_str: + out_dict[sink.function.name][key][ + "keywords" + ].extend(v) + self.log.info( + "Adding %s: %s @(%s)", + key, + out_val, + hex(sink.code_loc.ins_addr), + ) + if out_val in out_dict[sink.function.name][key][idx]: + out_dict[sink.function.name][key][idx][out_val] = list( + set(out_dict[sink.function.name][key][idx][out_val]) + | {hex(sink.code_loc.ins_addr)} + ) + else: + out_dict[sink.function.name][key][idx][out_val] = [ + hex(sink.code_loc.ins_addr) + ] + + file = Path(self.project.filename) + + final_dict = { + "results": out_dict, + "name": file.name, + "path": str(file), + "error": None, + "cfg_time": self.cfg_time, + "vra_time": self.vra_time, + "analysis_time": self.analysis_time, + "has_sinks": has_sinks, + "sink_times": {}, + } + + if not prev_dict: + with file.open("rb") as f: + final_dict["sha256"] = hashlib.file_digest(f, "sha256").hexdigest() + else: + final_dict["sha256"] = prev_dict["sha256"] + if "sink_times" in prev_dict: + final_dict["sink_times"] = prev_dict["sink_times"] + else: + final_dict["sink_times"][curr_sink] = self.sink_time + + if curr_sink is not None: + final_dict["sink_times"][curr_sink] = self.sink_time + + with env_json.open("w+") as f: + json.dump(final_dict, f, indent=4) + + @staticmethod + def merge(directory: Path, result_path: Path): + out_dict = {} + env_files = ( + subprocess.check_output( + ["find", str(directory.resolve()), "-type", "f", "-name", "env.json"] + ) + .decode() + .strip() + .split("\n") + ) + env_files = [Path(x) for x in env_files if Path(x).is_file()] + if result_path in env_files: + env_files.remove(result_path) + + for env_file in env_files: + try: + data = json.loads(env_file.read_text()) + except json.decoder.JSONDecodeError: + continue + + if not data or "error" not in data or not isinstance(data["results"], dict): + continue + for sink, val_dict in data["results"].items(): + for key, values in val_dict.items(): + if key not in out_dict: + out_dict[key] = {} + if sink not in out_dict[key]: + out_dict[key][sink] = {} + keywords = values.pop("keywords") + if data["name"] not in out_dict[key][sink]: + out_dict[key][sink][data["name"]] = { + "keywords": keywords, + "sha256": data["sha256"], + "values": [], + } + out_dict[key][sink][data["name"]]["keywords"] = list( + set(keywords) + | set(out_dict[key][sink][data["name"]]["keywords"]) + ) + for position, arg_vals in values.items(): + for val, code_loc in arg_vals.items(): + locations = [ + hex(x) if isinstance(x, int) else x for x in code_loc + ] + if val == "": + val = "TOP" + val_dict = { + "value": val, + "locations": locations, + "pos": position, + } + out_dict[key][sink][data["name"]]["values"].append(val_dict) + + with result_path.open("w+") as f: + json.dump(out_dict, f, indent=4) + + +def get_cli_args(): + parser, groups = default_parser() + path_group, run_group, output_group = groups + + 
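+    # With --merge, bin_path is reinterpreted as a directory that is scanned
+    # for previously produced env.json files instead of a binary to analyze
+    # (see main() below).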
run_group.add_argument( + "--merge", + dest="merge", + action="store_true", + default=False, + help="Merge all env.json in a directory", + ) + + parser.set_defaults(max_depth=2, result_path="results") + + return parser.parse_args() + + +def main(): + args = get_cli_args() + if not args.merge: + args.__dict__.pop("merge") + analyzer = EnvAnalysis(**args.__dict__) + analyzer.analyze() + else: + EnvAnalysis.merge(Path(args.bin_path), Path(args.result_path)) + + +if __name__ == "__main__": + main() diff --git a/package/argument_resolver/analysis/mango.py b/package/argument_resolver/analysis/mango.py new file mode 100644 index 0000000..89fe406 --- /dev/null +++ b/package/argument_resolver/analysis/mango.py @@ -0,0 +1,772 @@ +import subprocess + +import time +import json + +from pathlib import Path +from typing import Dict, Tuple, List + +import networkx + +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions + +from angr.analyses.reaching_definitions.reaching_definitions import ( + ReachingDefinitionsAnalysis, +) +from angr.code_location import ExternalCodeLocation +from angr.knowledge_plugins.key_definitions.tag import ReturnValueTag, ParameterTag +from angr.knowledge_plugins.key_definitions.atoms import Register + +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.handlers.local_handler import LocalHandler + +from argument_resolver.utils.closure import Closure +from argument_resolver.utils.utils import Utils +from argument_resolver.utils.rank import get_rank + +from argument_resolver.formatters.results_formatter import save_closure +from argument_resolver.analysis.base import ScriptBase, default_parser +from argument_resolver.external_function import ( + is_an_external_input_function, +) +from argument_resolver.external_function.sink import VULN_TYPES +from argument_resolver.external_function import KEY_BEACONS +from argument_resolver.utils.closure import SkeletonClosure + + +class MangoAnalysis(ScriptBase): + def __init__(self, *args, **kwargs): + self.concise = kwargs.pop("concise") + super().__init__(*args, **kwargs) + self.all_unresolved_closures = {} + self.execv_dict = {} + + def save_results(self, check_if_sanitized=False): + for sink, defn_dict in self.all_unresolved_closures.items(): + output = [] + for closures in defn_dict.values(): + for closure in closures.values(): + output.extend(closure["output"]) + self.result_formatter.log_closures_for_sink(output, sink, self.log) + self.save_results_to_file(None, None) + + @staticmethod + def similar_closure(closure_1, closure_2) -> bool: + if isinstance(closure_1, SkeletonClosure): + code_loc_1 = closure_1.code_loc + else: + code_loc_1 = closure_1.sink_trace.code_loc + + if isinstance(closure_2, SkeletonClosure): + code_loc_2 = closure_2.code_loc + else: + code_loc_2 = closure_2.sink_trace.code_loc + + return code_loc_1.ins_addr == code_loc_2.ins_addr + + @staticmethod + def get_sources_from_closure(closure) -> Dict: + sources = {"likely": {}, "possibly": {}, "tags": {}, "valid_funcs": set()} + used_sources = {"likely": set(), "possibly": set()} + for atom in closure.handler._sink_atoms: + sink_str = Utils.get_strings_from_pointers( + closure.sink_trace.arg_vals[atom], + closure.sink_trace.state, + closure.sink_trace.code_loc, + ) + if "ARGV" in str(sink_str): + sources["likely"]["ARGV"] = ['"ARGV"'] + used_sources["likely"].add("ARGV") + + for env_var in closure.handler.env_access | set( + 
closure.handler.keyword_access + ): + name = env_var._encoded_name.decode("latin-1") + key = name[name.find('"') + 1 : name.rfind('"')] + if not key: + continue + + loc = name.rfind("@") + if loc != -1: + under = name.find("_", loc) + if under != -1: + name = name[:under] + if name in used_sources: + continue + + if key in str(sink_str) and key != "TOP": + if loc != -1: + addr = int(name[loc + 1 :].split("_")[0], 16) + sources["valid_funcs"].add(addr) + if key in sources["likely"]: + sources["likely"][key].append(name) + else: + sources["likely"][key] = [name] + used_sources["likely"].add(name) + else: + if key in sources["possibly"]: + sources["possibly"][key].append(name) + else: + sources["possibly"][key] = [name] + used_sources["possibly"].add(name) + + # for keyword_var, keywords in closure.handler.keyword_access.items(): + + for func, instances in closure.handler.fd_tracker.items(): + if isinstance(func, int): + continue + for fd_dict in instances: + local_sources = [] + input_name = fd_dict["val"]._encoded_name.decode("latin-1") + if '"' in input_name: + input_key = input_name[ + input_name.find('"') + 1 : input_name.rfind('"') + ] + input_key = input_key.split(",")[0].strip().strip('"') + else: + l_paren = input_name.rfind("(") + l_paren = input_name[:l_paren].rfind("(") + r_paren = input_name.find(")") + r_paren += input_name[r_paren + 1 :].find(")") + input_key = input_name[l_paren + 1 : r_paren] + + if input_key == input_name[:-1]: + input_key = input_name[ + input_name.find("(") + 1 : input_name.rfind(")") + ] + loc = input_name.rfind("@") + if loc != -1: + under = input_name.find("_", loc) + if under != -1: + input_name = input_name[:under] + + if input_name in used_sources: + continue + parents = fd_dict["parent"] + local_sources.append(str(input_name)) + while parents is not None and len(parents) > 0: + parent = parents.pop() + if parent in closure.handler.fd_tracker: + name = closure.handler.fd_tracker[parent][ + "val" + ]._encoded_name.decode("latin-1") + local_sources.insert(0, name) + if closure.handler.fd_tracker[parent]["parent"] is not None: + parents = ( + closure.handler.fd_tracker[parent]["parent"] + + parents + ) + if input_key in str(sink_str) and input_key != "": + for f in local_sources: + loc = f.rfind("@") + if loc != -1: + addr = int(f[loc + 1 :].split("_")[0], 16) + sources["valid_funcs"].add(addr) + if input_key in sources["likely"]: + sources["likely"][input_key].append(local_sources[-1]) + else: + sources["likely"][input_key] = local_sources + used_sources["likely"].add(input_name) + else: + if input_key in sources["possibly"]: + sources["possibly"][input_key].append(local_sources[-1]) + else: + sources["possibly"][input_key] = local_sources + used_sources["possibly"].add(input_name) + sources["tags"] = used_sources + return sources + + def save_closure(self, sink, defn, closure): + has_sinks = len(self.get_sink_callsites(self.sinks)) != 0 + if not has_sinks: + self.log.critical("NO SINKS FOUND") + + analyzed_list = closure.handler.analyzed_list + is_sanitized = False + + sources = self.get_sources_from_closure(closure) + + k = "likely" if sources["likely"] else "possibly" + if sink in self.all_unresolved_closures: + for defn, closures_dict in self.all_unresolved_closures[sink].items(): + for c, c_dict in [ + (k, v) + for k, v in closures_dict.items() + if self.similar_closure(closure, k) + ]: + input_keys = set(c_dict["external_input"]["sources"][k].keys()) + input_keys.discard("ARGV") + if input_keys and input_keys <= set(sources[k]): + 
self.log.warning( + "Not Saving Closure, %s matches %s", + set(sources[k]), + input_keys, + ) + return None + + possible_ranks = get_rank(sources["tags"]["possibly"]) + likely_ranks = get_rank(sources["tags"]["likely"]) + + dict_keys = self.env_dict | self.keyword_dict + + for keyword in [s for s in sources["likely"] if s in dict_keys]: + hit = False + for rank in [r for r in likely_ranks if keyword in r]: + hit = True + likely_ranks[rank] *= 10 + if hit: + break + + for beacon in [x for x in KEY_BEACONS if x in sources["likely"]]: + hit = False + for rank in [r for r in likely_ranks if beacon in r]: + hit = True + likely_ranks[rank] *= 10 + if hit: + break + + rank = max(likely_ranks.values() or [0]) + max(possible_ranks.values() or [0]) + + # if any("frontend_param" in y for x in sources["likely"].values() for y in x): + # rank = 7 + + all_sources = {"sources": sources, "rank": rank} + + valid_closure = { + "analyzed_list": analyzed_list, + "sanitized": is_sanitized, + "call_locs": closure.get_call_locations(), + "external_input": all_sources, + "sink_loc": closure.sink_trace.code_loc.ins_addr, + } + + if ( + sink in self.all_unresolved_closures + and defn in self.all_unresolved_closures[sink] + ): + for o_closure in self.all_unresolved_closures[sink][defn].values(): + if ( + valid_closure["call_locs"] == o_closure["call_locs"] + and valid_closure["sink_loc"] == o_closure["sink_loc"] + ): + return None + + output = self.result_formatter.format_unresolved_closures( + Path(self.project.filename).name, + closure, + valid_closure, + defn, + self.find_default_excluded_functions(), + all_sources, + env_dict=self.env_dict, + keyword_dict=self.keyword_dict, + limit_output=self.concise, + ) + self.result_formatter.log_closures_for_sink(output, sink, self.log) + valid_closure["output"] = output + del valid_closure["analyzed_list"] + + if sink not in self.all_unresolved_closures: + self.all_unresolved_closures[sink] = {} + + if defn not in self.all_unresolved_closures[sink]: + self.all_unresolved_closures[sink][defn] = {} + + self.all_unresolved_closures[sink][defn][ + SkeletonClosure(closure) + ] = valid_closure + + self.analysis_time = time.time() - self.analysis_start_time + self.save_results_to_file(closure, valid_closure) + return all_sources + + def save_results_to_file(self, closure, closure_info): + has_sinks = len(self.get_sink_callsites(self.sinks)) != 0 + if self.result_path is not None: + self.result_path.mkdir(parents=True, exist_ok=True) + + save_closure( + project=self.project, + cfg_time=self.cfg_time, + vra_time=self.vra_time, + mango_time=self.analysis_time, + closure=closure, + closure_info=closure_info, + execv_dict=self.execv_dict, + result_path=self.result_path, + time_data=self.time_data, + total_sinks=self.sinks_found, + has_sinks=has_sinks, + category=self.category, + sink_time=self.sink_time, + ) + + def post_analysis( + self, dep: ReachingDefinitionsAnalysis, handler: HandlerBase + ) -> bool: + if dep is None: + self.log.error("RDA Failed Due to Timeout") + return False + sink_function = self.project.kb.functions[handler._sink_function_addr] + + self.log.debug("Finding vulnerable sinks.") + potential_sinks = self.vulnerable_sinks_from_call_trace(handler) + self.log.info("Found %d potential sinks.", len(potential_sinks)) + + unresolved_closures = self.trim_resolved_values( + sink_function, dep, potential_sinks, handler + ) + self.log.info( + "Found %d unresolved vulnerable definitions.", len(unresolved_closures) + ) + + all_closures = set() + input_locations = set() + for 
defn, closures in unresolved_closures.items(): + for closure in closures: + self.log.debug("Saving Closure: %s", closure) + sources = self.save_closure(sink_function, defn, closure) + if sources: + input_locations |= sources["sources"]["valid_funcs"] + if sources is not None: + all_closures.add(closure) + + if input_locations: + callsites = {x.caller_func_addr for x in dep.subject.content.callsites} + remove_addrs = set() + for addr in input_locations: + if addr in callsites: + remove_addrs.add(addr) + else: + depth = 2 + prev_parent = None + while depth > 1: + idx, func = next( + iter( + (idx, x) + for idx, x in enumerate(handler.analyzed_list) + if x.code_loc.ins_addr == addr + ) + ) + parent_idx, parent = next( + iter( + (new_idx, x) + for new_idx, x in enumerate( + reversed(handler.analyzed_list[:idx]) + ) + if x.depth < func.depth + ) + ) + depth = parent.depth + if prev_parent is not None and prev_parent.name == parent.name: + remove_addrs = callsites + break + + prev_parent = parent + if parent.function.addr in callsites: + remove_addrs.add(parent.function.addr) + break + if len(callsites - remove_addrs) > 0: + final_callsite = None + subj_callsites = dep.subject.content.callsites + for rev_idx, callsite in enumerate(reversed(subj_callsites)): + if callsite.caller_func_addr not in remove_addrs: + final_callsite = tuple( + [ + x.caller_func_addr + for x in reversed(subj_callsites[: rev_idx + 1]) + ] + + [subj_callsites[0].callee_func_addr] + ) + else: + break + if ( + sink_function in self.trace_dict + and final_callsite in self.trace_dict[sink_function] + ): + self.log.warning( + "Found Unused Callsites: %s", + [hex(x) for x in callsites - remove_addrs], + ) + self.log.warning("Setting parent to final") + if len(final_callsite) > 2: + final_callsite = final_callsite[1:] + self.trace_dict[sink_function][final_callsite]["final"] = True + + resolved = not self.contains_external(dep, unresolved_closures) + + for closure in all_closures: + del closure.handler.analyzed_list[1:] + closure.handler.call_trace.clear() + closure.handler.call_stack.clear() + + closure.rda._function_handler = None + return resolved + + @staticmethod + def contains_external_input(closure: Closure): + contains_external, valid_funcs, _ = MangoAnalysis.search_for_external_input( + closure, closure.sink_trace + ) + if ( + contains_external + and valid_funcs + and closure.handler.analyzed_list[0] not in valid_funcs + ): + caller_addrs = { + x.caller_func_addr for x in closure.rda.subject.content.callsites + } + valid_funcs |= { + x + for x in closure.handler.analyzed_list + if x.function.addr in caller_addrs + } + valid_funcs.add(closure.sink_trace) + new_analyzed_list = [ + x for x in closure.handler.analyzed_list if x in valid_funcs + ] + return True, new_analyzed_list + return False, closure.handler.analyzed_list + + def value_from_pointer_atoms( + self, atoms: List["Atom"], state, code_loc + ) -> Tuple[List[str], bool]: + values = [] + contains_unresolved = False + for atom in atoms: + bv = Utils.get_bv_from_atom(atom, state.arch) + strings = Utils.get_strings_from_pointer(bv, state, code_loc) + for s in Utils.get_values_from_multivalues(strings): + if s.concrete: + values.append(Utils.bytes_from_int(s).decode("latin-1")) + else: + values.append(str(s)) + contains_unresolved = True + + return values, contains_unresolved + + def handle_exec(self, closure: Closure): + vals = closure.sink_trace.arg_vals[next(iter(closure.sink_trace.args_atoms[1]))] + for pointer in Utils.get_values_from_multivalues(vals): + try: + sp 
= closure.sink_trace.state.get_sp() + except AssertionError: + sp = None + if not Utils.is_pointer(pointer, sp, self.project): + continue + base_atom = closure.sink_trace.state.deref( + pointer, + closure.sink_trace.state.arch.bytes, + endness=closure.sink_trace.state.arch.memory_endness, + ) + args = {} + vulnerable_args = [] + if closure.sink_trace.name.startswith("execv"): + count = 0 + while count < 10: + pointer = closure.sink_trace.state.deref( + base_atom, closure.sink_trace.state.arch.bytes + ) + + if len(pointer) == 1 and next(iter(pointer)).addr == 0: + break + + arg_strings, vulnerable = self.value_from_pointer_atoms( + pointer, closure.sink_trace.state, closure.sink_trace.code_loc + ) + args[count] = arg_strings + if vulnerable: + vulnerable_args.append(count) + base_atom.addr.offset += closure.sink_trace.state.arch.bytes + count += 1 + elif closure.sink_trace.name.startswith("execl"): + state = closure.sink_trace.state + for idx, atoms in enumerate(closure.sink_trace.args_atoms[1:]): + arg_strings = [] + vulnerable = False + for arg_atom in atoms: + pointer = state.deref( + arg_atom, closure.sink_trace.state.arch.bytes + ) + values, vuln = self.value_from_pointer_atoms( + pointer, state, closure.sink_trace.code_loc + ) + arg_strings.extend(values) + vulnerable |= vuln + + args[idx] = arg_strings + if vulnerable: + vulnerable_args.append(idx) + + if len(vulnerable_args) > 0: + for name in args[0]: + name = Path(name).name + if name not in self.execv_dict: + self.execv_dict[name] = [] + self.execv_dict[name].append( + { + "args": args, + "vulnerable_args": vulnerable_args, + "addr": closure.sink_trace.code_loc.ins_addr, + } + ) + + def trim_resolved_values( + self, sink, dep, vulnerable_sinks, handler + ) -> Dict[Definition, Tuple[networkx.DiGraph, LiveDefinitions, LocalHandler]]: + # Get only the closures of the vulnerable atoms containing non-constant data. 
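+        # Closures subsumed by a newer, more complete one (closure < new_closure)
+        # are evicted both from the local dict and from the cached
+        # all_unresolved_closures, so only maximal call chains survive.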
+ unresolved_closures: Dict[Definition] = {} + for trace, defns in vulnerable_sinks.items(): + for atom in handler._sink_atoms: + if atom not in trace.constant_data: + continue + constant = trace.constant_data[atom] is not None and all( + d is not None for d in trace.constant_data[atom] + ) + new_closure = Closure(trace, dep, handler) + + if ( + constant + and sink.name.startswith("exec") + and sink.name != "execFormatCmd" + ): + self.handle_exec(new_closure) + + for defn in sorted( + LiveDefinitions.extract_defs_from_mv(trace.arg_vals[atom]), + key=lambda x: (x.codeloc.ins_addr or x.codeloc.block_addr) + if x.codeloc + else 0, + ): + if not constant: + if defn not in unresolved_closures: + unresolved_closures[defn] = set() + else: + for closure in unresolved_closures[defn].copy(): + if closure < new_closure: + unresolved_closures[defn].remove(closure) + if ( + sink in self.all_unresolved_closures + and defn in self.all_unresolved_closures[sink] + ): + for closure in self.all_unresolved_closures[sink][ + defn + ].copy(): + if closure < new_closure: + self.all_unresolved_closures[sink][defn].pop( + closure + ) + if all( + y != new_closure + for x in unresolved_closures.values() + for y in x + ): + unresolved_closures[defn].add(new_closure) + break + else: + output, _ = self.result_formatter.log_function(trace) + self.log.info("[blue]Resolved call to %s:", sink.name) + for line in output: + self.log.info(line) + if sink not in self.all_unresolved_closures: + continue + + if defn not in self.all_unresolved_closures[sink]: + continue + + for closure in self.all_unresolved_closures[sink][defn].copy(): + if closure < new_closure: + self.all_unresolved_closures[sink][defn].pop(closure) + + return unresolved_closures + + @staticmethod + def search_for_external_input( + closure, stored_func, valid_funcs=None, explored_funcs=None + ): + contains_external = False + if valid_funcs is None: + valid_funcs = set() + + if explored_funcs is None: + explored_funcs = set() + + if stored_func in valid_funcs or stored_func in explored_funcs: + return True, valid_funcs, explored_funcs + + explored_funcs.add(stored_func) + if is_an_external_input_function(stored_func.name): + valid_funcs.add(stored_func) + return True, valid_funcs, explored_funcs + + parent_functions = set() + for defn in { + x + for defn_set in stored_func.closures.values() + for x in defn_set | stored_func.definitions | stored_func.return_definitions + }: + if not any( + isinstance(tag, (ParameterTag, ReturnValueTag)) for tag in defn.tags + ): + continue + + if isinstance(defn.codeloc, ExternalCodeLocation): + if ( + not isinstance(defn.atom, Register) + or defn.atom.reg_offset != stored_func.state.arch.sp_offset + ): + contains_external = True + valid_funcs.add(stored_func) + func_addrs = { + tag.function for tag in defn.tags if hasattr(tag, "function") + } + for func in closure.handler.analyzed_list: + if func.function.addr in func_addrs: + valid_funcs.add(func) + else: + func_idx = closure.handler.analyzed_list.index(stored_func) + parent_functions |= { + x + for x in closure.handler.analyzed_list[:func_idx] + if defn in x.definitions | x.return_definitions + } + parent_functions.discard(closure.handler.analyzed_list[0]) + + for func in parent_functions: + ret = MangoAnalysis.search_for_external_input( + closure, func, valid_funcs, explored_funcs + ) + parent_contains_external, parent_valid_funcs, parent_explored = ret + explored_funcs |= parent_explored + if parent_contains_external: + contains_external = True + valid_funcs |= 
parent_valid_funcs + valid_funcs.add(stored_func) + + return contains_external, valid_funcs, explored_funcs + + @staticmethod + def merge_execve(directory: Path, result_path: Path): + out_dict = {} + execv_files = ( + subprocess.check_output( + ["find", str(directory.resolve()), "-type", "f", "-name", "execv.json"] + ) + .decode() + .strip() + .split("\n") + ) + execv_files = [Path(x) for x in execv_files if Path(x).is_file()] + if result_path in execv_files: + execv_files.remove(result_path) + + for execv_file in execv_files: + data = json.loads(execv_file.read_text()) + if not data: + continue + for bin_name, instances in data["execv"].items(): + if bin_name not in out_dict: + out_dict[bin_name] = { + "args": {}, + "vulnerable_args": [], + "parent_bins": [], + } + for val_dict in instances: + for pos, values in val_dict["args"].items(): + if pos not in out_dict[bin_name]["args"]: + out_dict[bin_name]["args"][pos] = [] + out_dict[bin_name]["args"][pos] = list( + set(out_dict[bin_name]["args"][pos] + values) + ) + + out_dict[bin_name]["vulnerable_args"] = list( + set( + out_dict[bin_name]["vulnerable_args"] + + val_dict["vulnerable_args"] + ) + ) + for x in out_dict[bin_name]["parent_bins"]: + if x["sha256"] == data["sha256"]: + x["addrs"] = list(set(x["addrs"] + [val_dict["addr"]])) + break + else: + out_dict[bin_name]["parent_bins"].append( + { + "name": data["name"], + "sha256": data["sha256"], + "addrs": [val_dict["addr"]], + } + ) + + with result_path.open("w+") as f: + json.dump(out_dict, f, indent=4) + + +def get_cli_args(): + parser, groups = default_parser() + path_group, run_group, output_group = groups + + run_group.add_argument( + "--arg", + dest="arg_pos", + default=0, + type=int, + help="The argument position in a sink function", + ) + run_group.add_argument( + "-c", + "--category", + dest="sink_category", + default="cmdi", + type=str, + choices=list(VULN_TYPES.keys()), + help="The category of sink to search for", + ) + run_group.add_argument( + "--sink", + dest="sink", + default="", + type=str, + help="Use the specified function sink", + ) + + run_group.add_argument( + "--env-dict", + dest="env_dict", + default=None, + help="Where to store the results of the analysis.", + ) + + run_group.add_argument( + "--merge-execve", + dest="merge", + action="store_true", + default=False, + help="Merge all execv.json in a directory", + ) + + output_group.add_argument( + "--concise", + dest="concise", + action="store_true", + default=False, + help="Speeds up overall analysis by outputting only the sink function", + ) + + parser.set_defaults(max_depth=7) + return parser.parse_args() + + +def main(): + args = get_cli_args() + if args.merge: + MangoAnalysis.merge_execve(Path(args.bin_path), Path(args.result_path)) + else: + args.__dict__.pop("merge") + analyzer = MangoAnalysis(**args.__dict__) + analyzer.analyze() + + +if __name__ == "__main__": + main() diff --git a/package/argument_resolver/concretization/__init__.py b/package/argument_resolver/concretization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/package/argument_resolver/concretization/concrete_runner.py b/package/argument_resolver/concretization/concrete_runner.py new file mode 100644 index 0000000..8c040c9 --- /dev/null +++ b/package/argument_resolver/concretization/concrete_runner.py @@ -0,0 +1,43 @@ +import shlex + +from angr import Project, PointerWrapper +from angr.knowledge_plugins.functions import Function +from angr.knowledge_plugins.key_definitions.atoms import SimStackArg + + +class 
ConcreteRunner: + + def __init__(self, project: Project, function: Function, vuln_pos: int, func_args): + self.shell_string = b";;;; `echo 'Hello World!'`" + pointer = PointerWrapper(self.shell_string, buffer=True) + func_args[vuln_pos] = pointer + vuln_reg = function.arguments[vuln_pos] + + self.function = function + self.ret_addr = 0xffffffff + self.init_state = project.factory.call_state(self.function.addr, *func_args, ret_addr=self.ret_addr, add_options={"ZERO_FILL_UNCONSTRAINED_MEMORY", "ZERO_FILL_UNCONSTRAINED_REGISTERS"}) + self.string_memloc = vuln_reg.get_value(self.init_state) + self.sm = project.factory.simulation_manager(self.init_state) + + def check_if_escaped(self) -> bool: + self.sm.explore(find=self.ret_addr) + for state in self.sm.found: + shell_string = state.memory.concrete_load(self.string_memloc, 800).tobytes().strip(b"\x00") + if shell_string != self.shell_string and self.is_escaped(shell_string): + return True + + reg_val = state.solver.eval(getattr(state.regs, self.function.calling_convention.RETURN_VAL.reg_name)) + shell_string = state.memory.concrete_load(reg_val, 800).tobytes().strip(b"\x00") + if len(shell_string) >= len(self.shell_string) and shell_string != self.shell_string and self.is_escaped(shell_string): + return True + return False + + @staticmethod + def is_escaped(string: bytes): + lexed = list(shlex.shlex(string.decode())) + if 'echo' not in lexed: + # Probably format changed somehow + return False + elif ';' in lexed or '`' in lexed: + return False + return True diff --git a/package/argument_resolver/external_function/__init__.py b/package/argument_resolver/external_function/__init__.py new file mode 100644 index 0000000..6b4acfe --- /dev/null +++ b/package/argument_resolver/external_function/__init__.py @@ -0,0 +1,7 @@ +from .input_functions import INPUT_EXTERNAL_FUNCTIONS, KEY_BEACONS +from .sink import VULN_TYPES, Sink +from .function_declarations import CUSTOM_DECLS + + +def is_an_external_input_function(function_name: str) -> bool: + return any(function_name == x for x in INPUT_EXTERNAL_FUNCTIONS) diff --git a/package/argument_resolver/external_function/function_declarations/__init__.py b/package/argument_resolver/external_function/function_declarations/__init__.py new file mode 100644 index 0000000..2fb68a4 --- /dev/null +++ b/package/argument_resolver/external_function/function_declarations/__init__.py @@ -0,0 +1,5 @@ +from .nvram import libnvram_decls +from .win32 import winreg_decls +from .custom import custom_decls + +CUSTOM_DECLS = {**libnvram_decls, **winreg_decls, **custom_decls} diff --git a/package/argument_resolver/external_function/function_declarations/custom.py b/package/argument_resolver/external_function/function_declarations/custom.py new file mode 100644 index 0000000..f2f0978 --- /dev/null +++ b/package/argument_resolver/external_function/function_declarations/custom.py @@ -0,0 +1,141 @@ +from angr.sim_type import ( + SimTypeFunction, + SimTypeInt, + SimTypePointer, + SimTypeChar, +) + + +custom_decls = { + "dprintf": SimTypeFunction( + [ + SimTypePointer(SimTypeInt(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + arg_names=["stream", "template"], + variadic=True, + ), + "twsystem": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "execFormatCmd": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_cmd": SimTypeFunction( + 
[SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "tp_systemEx": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "___system": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "bstar_system": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "doSystemCmd": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "doShell": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "CsteSystem": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "cgi_deal_popen": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "ExeCmd": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "ExecShell": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_shell_popen": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_shell_popen_str": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_shell_async": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_shell_sync": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "exec_shell_sync2": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeInt(signed=True), + arg_names=["command"], + ), + "nflog_get_payload": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + ), + "query_param_parser": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + ), + "GetValue": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + ), + "SetValue": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + ), + "httpSetEnv": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + ), +} diff --git a/package/argument_resolver/external_function/function_declarations/nvram.py b/package/argument_resolver/external_function/function_declarations/nvram.py new file mode 100644 index 0000000..9ca2531 --- /dev/null +++ b/package/argument_resolver/external_function/function_declarations/nvram.py @@ -0,0 +1,143 @@ +from angr.sim_type import ( + SimTypeFunction, + SimTypeInt, + SimTypePointer, + SimTypeChar, + SimTypeBottom, +) + + +libnvram_decls = { + # + # Taken from: https://github.com/firmadyne/libnvram/blob/v1.0c/nvram.c . 
+ # + # int nvram_set(char *key, char *val); + "nvram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + arg_names=["key", "val"], + ), + # char *nvram_get(char *key); + "nvram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypePointer(SimTypeChar(), offset=0), + arg_names=["key"], + ), + # char *nvram_safe_get(char *key); + "nvram_safe_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypePointer(SimTypeChar(), offset=0), + arg_names=["key"], + ), + # int nvram_set(char *key, char *val); + "nvram_safe_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + arg_names=["key", "val"], + ), + # + # Taken from: https://github.com/firmadyne/libnvram/blob/v1.0c/alias.c . + # + # int acosNvramConfig_set(char *key, char *val) + "acosNvramConfig_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeInt(signed=True), + arg_names=["key", "val"], + ), + # char *acosNvramConfig_get(char *key) + "acosNvramConfig_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypePointer(SimTypeChar(), offset=0), + arg_names=["key"], + ), + "acosNvramConfig_read": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeInt(), offset=0), + arg_names=["key"], + ), + #"acosNvramConfig_write": SimTypeFunction( + # [SimTypePointer(SimTypeChar(), offset=0)], + # SimTypePointer(SimTypeChar(), offset=0), + # SimTypePointer(SimTypeInt(), offset=0), + # arg_names=["key"], + #), + # + # Custom Definitions + # + # + "bcm_nvram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeBottom(label="void"), + arg_names=["name"], + ), + "bcm_nvram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeBottom(label="void"), + arg_names=["name", "value"], + ), + "envram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeBottom(label="void"), + arg_names=["name"], + ), + "envram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeBottom(label="void"), + arg_names=["name", "value"], + ), + "wlcsm_nvram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeBottom(label="void"), + arg_names=["name"], + ), + "wlcsm_nvram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeBottom(label="void"), + arg_names=["name", "value"], + ), + "dni_nvram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeBottom(label="void"), + arg_names=["name"], + ), + "dni_nvram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeBottom(label="void"), + arg_names=["name", "value"], + ), + "PTI_nvram_get": SimTypeFunction( + [SimTypePointer(SimTypeChar(), offset=0)], + SimTypeBottom(label="void"), + arg_names=["name"], + ), + "PTI_nvram_set": SimTypeFunction( + [ + SimTypePointer(SimTypeChar(), offset=0), + SimTypePointer(SimTypeChar(), offset=0), + ], + SimTypeBottom(label="void"), + arg_names=["name", "value"], + ), +} diff --git 
new file mode 100644
index 0000000..b1ef358
--- /dev/null
+++ b/package/argument_resolver/external_function/function_declarations/win32.py
@@ -0,0 +1,21 @@
+from angr.sim_type import SimTypeFunction, SimTypeLong
+
+
+winreg_decls = {
+    #
+    # Win32 registry API (winreg) declarations.
+    #
+    "RegOpenKeyExW": SimTypeFunction(
+        [
+            SimTypeLong(signed=True),
+            SimTypeLong(signed=True),
+            SimTypeLong(signed=True),
+            SimTypeLong(signed=True),
+            SimTypeLong(signed=True),
+        ],
+        SimTypeLong(signed=True),
+    ),
+    "RegCloseKey": SimTypeFunction(
+        [SimTypeLong(signed=True)], SimTypeLong(signed=True)
+    ),
+}
diff --git a/package/argument_resolver/external_function/input_functions.py b/package/argument_resolver/external_function/input_functions.py
new file mode 100644
index 0000000..becaa33
--- /dev/null
+++ b/package/argument_resolver/external_function/input_functions.py
@@ -0,0 +1,42 @@
+from typing import Set
+
+from .sink import VULN_TYPES
+
+# External functions that can be used to provide input to the program.
+INPUT_EXTERNAL_FUNCTIONS: Set[str] = {
+    "read",
+    "fread",
+    "fgets",
+    "recv",
+    "recvfrom",
+    "custom_param_parser",
+} | {x.name for x in VULN_TYPES["getter"]}
+
+
+KEY_BEACONS = {
+    "REQUEST_METHOD",
+    "REQUEST_URI",
+    "QUERY_STRING",
+    "CONTENT_TYPE",
+    "CONTENT_LENGTH",
+    "PATH_INFO",
+    "SCRIPT_NAME",
+    "DOCUMENT_URI",
+    "HTTP_ACCEPT_LANGUAGE",
+    "HTTP_AUTH",
+    "HTTP_AUTHORIZATION",
+    "HTTP_CALLBACK",
+    "HTTP_COOKIE",
+    "HTTP_HNAP_AUTH",
+    "HTTP_HOST",
+    "HTTP_MTFWU_ACT",
+    "HTTP_MTFWU_AUTH",
+    "HTTP_NT",
+    "HTTP_REFERER",
+    "HTTPS",
+    "HTTP_SID",
+    "HTTP_SOAPACTION",
+    "HTTP_ST",
+    "HTTP_TIMEOUT",
+    "HTTP_USER_AGENT",
+}
diff --git a/package/argument_resolver/external_function/sink/__init__.py b/package/argument_resolver/external_function/sink/__init__.py
new file mode 100644
index 0000000..24b2117
--- /dev/null
+++ b/package/argument_resolver/external_function/sink/__init__.py
@@ -0,0 +1,24 @@
+from .sink_lists import (
+    COMMAND_INJECTION_SINKS,
+    PATH_TRAVERSAL_SINKS,
+    STRING_FORMAT_SINKS,
+    BUFFER_OVERFLOW_SINKS,
+    ENV_SINKS,
+    GETTER_SINKS,
+    SETTER_SINKS,
+    STRCAT_SINKS,
+    MEMCPY_SINKS,
+    Sink,
+)
+
+VULN_TYPES = {
+    "cmdi": COMMAND_INJECTION_SINKS,
+    "path": PATH_TRAVERSAL_SINKS,
+    "strfmt": STRING_FORMAT_SINKS,
+    "overflow": BUFFER_OVERFLOW_SINKS,
+    "strcat": STRCAT_SINKS,
+    "env": ENV_SINKS,
+    "getter": GETTER_SINKS,
+    "setter": SETTER_SINKS,
+    "memcpy": MEMCPY_SINKS,
+}
diff --git a/package/argument_resolver/external_function/sink/sink_lists.py b/package/argument_resolver/external_function/sink/sink_lists.py
new file mode 100644
index 0000000..b27a658
--- /dev/null
+++ b/package/argument_resolver/external_function/sink/sink_lists.py
@@ -0,0 +1,103 @@
+from typing import List
+
+from dataclasses import dataclass
+
+
+@dataclass
+class Sink:
+    name: str
+    vulnerable_parameters: List[int]
+
+
+COMMAND_INJECTION_SINKS: List[Sink] = [
+    Sink(name="system", vulnerable_parameters=[1]),
+    Sink(name="twsystem", vulnerable_parameters=[1]),
+    Sink(name="execFormatCmd", vulnerable_parameters=[1]),
+    Sink(name="exec_cmd", vulnerable_parameters=[1]),
+    Sink(name="___system", vulnerable_parameters=[1]),
+    Sink(name="bstar_system", vulnerable_parameters=[1]),
+    Sink(name="doSystemCmd", vulnerable_parameters=[1]),
+    Sink(name="doShell", vulnerable_parameters=[1]),
+    Sink(name="CsteSystem",
vulnerable_parameters=[1]), + Sink(name="cgi_deal_popen", vulnerable_parameters=[1]), + Sink(name="ExeCmd", vulnerable_parameters=[1]), + Sink(name="ExecShell", vulnerable_parameters=[1]), + Sink(name="exec_shell_popen", vulnerable_parameters=[1]), + Sink(name="exec_shell_popen_str", vulnerable_parameters=[1]), + Sink(name="popen", vulnerable_parameters=[1]), + Sink(name="execl", vulnerable_parameters=[1]), + Sink(name="execlp", vulnerable_parameters=[1]), + Sink(name="execle", vulnerable_parameters=[1]), + Sink(name="execv", vulnerable_parameters=[1]), + Sink(name="execvp", vulnerable_parameters=[1]), + Sink(name="execvpe", vulnerable_parameters=[1]), + Sink(name="execve", vulnerable_parameters=[1]), + Sink(name="tp_systemEx", vulnerable_parameters=[1]), + Sink(name="exec_shell_async", vulnerable_parameters=[1]), + Sink(name="exec_shell_sync", vulnerable_parameters=[1]), + Sink(name="exec_shell_sync2", vulnerable_parameters=[1]), + Sink(name="SLIBCSystem", vulnerable_parameters=[1]), + Sink(name="SLIBCExecl", vulnerable_parameters=[2]), + Sink(name="SLIBCExec", vulnerable_parameters=[1]), + Sink(name="SLIBCExecv", vulnerable_parameters=[1]), + Sink(name="SLIBCPopen", vulnerable_parameters=[1]), + Sink(name="pegaSystem", vulnerable_parameters=[1]), +] + +PATH_TRAVERSAL_SINKS: List[Sink] = [ + Sink(name="popen", vulnerable_parameters=[1]), + Sink(name="fopen", vulnerable_parameters=[1]), +] +# Sink(name="openat", vulnerable_parameters=[1]), +# Sink(name="creat", vulnerable_parameters=[1]), + +BUFFER_OVERFLOW_SINKS: List[Sink] = [ + # Sink(name="strcat", vulnerable_parameters=[2]), + Sink(name="strcpy", vulnerable_parameters=[2]), + # Sink(name="memcpy", vulnerable_parameters=[2]), + # Sink(name="gets", vulnerable_parameters=[1]), +] + +STRCAT_SINKS: List[Sink] = [ + Sink(name="strcat", vulnerable_parameters=[2]), +] + +MEMCPY_SINKS: List[Sink] = [ + Sink(name="memcpy", vulnerable_parameters=[2]), +] + +STRING_FORMAT_SINKS: List[Sink] = [ + Sink(name="sprintf", vulnerable_parameters=[2]), + Sink(name="snprintf", vulnerable_parameters=[3]), +] + +GETTER_SINKS: List[Sink] = [ + Sink(name="getenv", vulnerable_parameters=[1]), + Sink(name="GetValue", vulnerable_parameters=[1]), + Sink(name="acosNvramConfig_get", vulnerable_parameters=[1]), + Sink(name="acosNvramConfig_read", vulnerable_parameters=[1]), + Sink(name="nvram_get", vulnerable_parameters=[1]), + Sink(name="nvram_safe_get", vulnerable_parameters=[1]), + Sink(name="bcm_nvram_get", vulnerable_parameters=[1]), + Sink(name="envram_get", vulnerable_parameters=[1]), + Sink(name="wlcsm_nvram_get", vulnerable_parameters=[1]), + Sink(name="dni_nvram_get", vulnerable_parameters=[1]), + Sink(name="PTI_nvram_get", vulnerable_parameters=[1]), +] + +SETTER_SINKS: List[Sink] = [ + Sink(name="setenv", vulnerable_parameters=[2]), + Sink(name="SetValue", vulnerable_parameters=[1]), + Sink(name="httpSetEnv", vulnerable_parameters=[1]), + Sink(name="acosNvramConfig_set", vulnerable_parameters=[2]), + Sink(name="acosNvramConfig_write", vulnerable_parameters=[2]), + Sink(name="nvram_set", vulnerable_parameters=[2]), + Sink(name="nvram_safe_set", vulnerable_parameters=[2]), + Sink(name="bcm_nvram_set", vulnerable_parameters=[2]), + Sink(name="envram_set", vulnerable_parameters=[2]), + Sink(name="wlcsm_nvram_set", vulnerable_parameters=[2]), + Sink(name="dni_nvram_set", vulnerable_parameters=[2]), + Sink(name="PTI_nvram_set", vulnerable_parameters=[2]), +] + +ENV_SINKS: List[Sink] = GETTER_SINKS + SETTER_SINKS diff --git 
a/package/argument_resolver/formatters/__init__.py b/package/argument_resolver/formatters/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/package/argument_resolver/formatters/closure_formatter.py b/package/argument_resolver/formatters/closure_formatter.py
new file mode 100644
index 0000000..67f4288
--- /dev/null
+++ b/package/argument_resolver/formatters/closure_formatter.py
@@ -0,0 +1,592 @@
+import string
+from argument_resolver.external_function.sink.sink_lists import (
+    GETTER_SINKS,
+    SETTER_SINKS,
+)
+
+from typing import List
+
+from angr.knowledge_plugins.key_definitions.atoms import Atom, Register, SpOffset
+from angr.knowledge_plugins.key_definitions.definition import Definition
+from angr.knowledge_plugins.key_definitions.live_definitions import (
+    LiveDefinitions,
+    DerefSize,
+)
+
+import claripy
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+
+from argument_resolver.formatters.log_formatter import CustomFormatter
+from argument_resolver.utils.stored_function import StoredFunction
+from argument_resolver.utils.utils import Utils
+from argument_resolver.utils.call_trace import traces_to_sink
+import re
+
+
+class ClosureFormatter:
+    def __init__(self, project, cc_resolver):
+        self.project = project
+        self.calling_convention_resolver = cc_resolver
+        self.depth_colors = [
+            CustomFormatter.grey,
+            CustomFormatter.green,
+            CustomFormatter.blue,
+            CustomFormatter.yellow,
+        ]
+        self.func = None
+
+    def log_function(
+        self,
+        stored_func: StoredFunction,
+        target_atom: Atom = None,
+        target_defn: Definition = None,
+    ):
+        reg_strs = []
+        all_resolved = True
+        depth = stored_func.depth
+
+        function = stored_func.function
+        call_insn = stored_func.code_loc.ins_addr or stored_func.code_loc.block_addr
+
+        start = function.name + "("
+        spacing = len(start)
+        strs, resolved = self.args_to_str(
+            stored_func, spacing, target_atom=target_atom, target_defn=target_defn
+        )
+        all_resolved &= resolved
+        reg_strs.extend(strs)
+        log_output = [start]
+        for idx, reg_str in enumerate(reg_strs):
+            # Append a trailing comma to every argument except the last one.
+            log_output.append(
+                reg_str + ("," if idx != len(reg_strs) - 1 else "")
+            )
+
+        ret_str, _ = self.get_ret_str(stored_func)
+        out_str = " " * spacing + f") @ {hex(call_insn)}"
+        out_str += f" -> {ret_str}" if ret_str else ""
+        log_output.append(out_str)
+        depth_str = ""
+        for i in range(depth - 1):
+            depth_str += f"{self.depth_colors[i % len(self.depth_colors)]}|"
+        if depth_str != "":
+            depth_str += CustomFormatter.reset
+
+        log_output = [depth_str + CustomFormatter.grey + x for x in log_output]
+        return log_output, all_resolved
+
+    def sort_args(self, arg):
+        if isinstance(arg, Register):
+            if self.func is not None:
+                cc = self.func.function.calling_convention
+                if cc is not None:
+                    int_args = [
+                        self.func.state.arch.registers[x.reg_name][0]
+                        for x in cc.int_args
+                    ]
+                    return int_args.index(arg.reg_offset)
+
+            return arg.reg_offset
+        else:
+            if isinstance(arg.addr, SpOffset):
+                val = arg.addr.offset * -1
+            else:
+                val = arg.addr
+            return 0x1000 + val
+
+    def args_to_str(
+        self, stored_func: StoredFunction, spacing: int, target_atom, target_defn
+    ):
+        reg_strs = []
+        all_resolved = True
+        self.func = stored_func
+        for atom in sorted(stored_func.atoms, key=self.sort_args):
+            if isinstance(atom, Register):
+                if target_atom == atom:
+                    reg_str = f"{CustomFormatter.reset}{CustomFormatter.blue}{stored_func.state.arch.register_names[atom.reg_offset]}{CustomFormatter.reset}: "
+                else:
+                    reg_str = (
f"{stored_func.state.arch.register_names[atom.reg_offset]}: " + ) + else: + reg_str = f"{atom}: " + + if ( + target_defn is not None + and isinstance(target_atom, Register) + and target_atom != atom + ): + vals, resolved = self.format_multivalue_output(stored_func, atom) + else: + vals, resolved = self.format_multivalue_output( + stored_func, atom, target_defn=target_defn + ) + if not resolved: + all_resolved = False + reg_str += " | ".join( + val if isinstance(val, str) else val.decode("latin-1") for val in vals + ) + reg_str = " " * spacing + reg_str + reg_strs.append(reg_str) + return reg_strs, all_resolved + + def get_ret_str(self, stored_func: StoredFunction): + if ( + stored_func.function.prototype is None + or stored_func.function.prototype.returnty is None + or stored_func.ret_val is None + ): + return "", 0 + + values = Utils.get_values_from_multivalues(stored_func.ret_val) + ret_val = list(values) if len(values) > 1 else values[0] + ret_str = f"{CustomFormatter.yellow}{ret_val}{CustomFormatter.reset}" + length = len(str(ret_val)) + return ret_str, length + + @staticmethod + def format_multivalue_output(stored_func, atom, target_defn=None): + return_vals = [] + resolved = True + mv = stored_func.arg_vals[atom] + defns = list(LiveDefinitions.extract_defs_from_mv(mv)) + if len(defns) == 0: + resolved = False + return [str(x) for x in Utils.get_values_from_multivalues(mv)], resolved + + for defn in defns: + new_mv = MultiValues() + for offset, vals in mv.items(): + for val in vals: + val_defns = set(LiveDefinitions.extract_defs(val)) + if defn not in val_defns: + continue + new_mv.add_value(offset, val) + if atom not in stored_func.constant_data: + constant_data = None + else: + constant_data = stored_func.constant_data[atom] + + if ( + 0 in new_mv + and constant_data is not None + and all(x is not None for x in constant_data) + ): + pointer = new_mv.one_value() + is_pointer = True + try: + sp = Utils.get_sp(stored_func.state) + except AssertionError: + sp = stored_func.state.arch.initial_sp + if pointer is not None and not Utils.is_pointer( + pointer, sp, stored_func.state.analysis.project + ): + is_pointer = False + + for value in constant_data: + str_bytes = Utils.bytes_from_int(value) + if all( + x in bytes(string.printable, "ascii") or x == 0 + for x in str_bytes + ) and not all(x == 0 for x in str_bytes): + if is_pointer and pointer is not None: + output_str = str(pointer) + f' -> {CustomFormatter.green}"' + else: + output_str = '"' + output_str += ( + str_bytes.decode("latin-1") + .replace("\n", "\\n") + .replace("\r", "\\r") + ) + output_str += f'"{CustomFormatter.grey}' + return_vals.append(output_str) + else: + if is_pointer: + output_str = f"{pointer} -> {CustomFormatter.green}{value}{CustomFormatter.reset}{CustomFormatter.grey}" + else: + output_str = f"{pointer}" + return_vals.append(output_str) + else: + resolved = False + + for val in Utils.get_values_from_multivalues(mv): + if defn in {x for x in LiveDefinitions.extract_defs(val)}: + arg_str = str(val) + if stored_func.state.is_stack_address(val): + offset = stored_func.state.get_stack_offset(val) + if offset is None: + if hex(0xDEADC0DE) in str(val): + arg_str = f"{CustomFormatter.light_blue}ARGV{CustomFormatter.reset}[?]" + + elif offset < 0: + offset += 2**stored_func.state.arch.bits + + if ( + offset is not None + and 0xDEADC0DE < offset < 0xDEADC0DE + 0x100 * 11 + ): + idx = (offset - 0xDEADC0DE) // 0x100 + change = offset - (0xDEADC0DE + 0x100 * idx) + if change == 0: + arg_str = 
f"{CustomFormatter.light_blue}ARGV{CustomFormatter.reset}[{idx - 1}]" + else: + arg_str = f"{CustomFormatter.light_blue}ARGV{CustomFormatter.reset}[{idx - 1}] + {hex(change)}" + + return_vals.append(arg_str) + + if Utils.is_pointer( + val, + sp=stored_func.state.arch.initial_sp, + project=stored_func.state.analysis.project, + ): + symbolic_vals = Utils.get_strings_from_pointer( + val, stored_func.state, stored_func.code_loc + ) + + str_list = [] + if symbolic_vals.count() > 5: + arg_str = str(symbolic_vals) + else: + for v in Utils.get_values_from_multivalues( + symbolic_vals, pretty=True + ): + if v.symbolic: + if len(v.args) > 1: + arg_list = "" + if all( + isinstance(arg, claripy.ast.Base) + for arg in v.args + ): + for arg in v.args: + if arg.concrete: + arg_list += ( + Utils.bytes_from_int( + arg + ).decode("latin-1") + ) + else: + arg_list += str(arg) + + else: + arg_list = str(v) + str_list.append(arg_list) + else: + str_list.append(str(v)) + else: + str_list.append( + Utils.bytes_from_int(v).decode("latin-1") + ) + + if str_list: + arg_str += ( + f" -> {CustomFormatter.blue}" + + " | ".join('"' + x + '"' for x in str_list) + + CustomFormatter.reset + ) + elif not isinstance(val.args[0], str): + arg_str = ( + '"' + + "".join( + [ + Utils.bytes_from_int(x).decode("latin-1") + if isinstance(x, claripy.ast.BV) and x.concrete + else str(x) + for x in val.args + ] + ) + + '"' + ) + + return_vals.append(arg_str) + return set(return_vals), resolved + + @staticmethod + def filter_trace(closure: "Closure") -> List[StoredFunction]: + sink_closure_defns = { + defn + for atom in closure.handler._sink_atoms + for defn in closure.sink_trace.closures[atom] + } + caller_addrs = { + x.caller_func_addr for x in closure.rda.subject.content.callsites + } + trace_list = [] + for stored_func in closure.handler.analyzed_list[::-1]: + if stored_func.function.addr in caller_addrs: + trace_list.append(stored_func) + elif any( + defn in sink_closure_defns + for defn in stored_func.definitions | stored_func.return_definitions + ): + trace_list.append(stored_func) + sink_closure_defns |= { + defn for defns in stored_func.closures.values() for defn in defns + } + + return trace_list[::-1] + + @staticmethod + def strip_non_letters_from_ends(s): + return re.sub(r"^[^a-zA-Z]+|[^a-zA-Z]+$", "", s) + + @staticmethod + def get_value_from_env(key, func_name, env_dict, keyword_dict): + func_name = ( + func_name.replace("get", "set") + .replace("read", "write") + .replace("Get", "Set") + ) + sources = [] + if key == "ARGV": + return sources + elif key == "stdin": + return sources + + bad_key = False + if not env_dict or key not in env_dict or func_name not in env_dict[key]: + bad_key = True + if keyword_dict and key in keyword_dict: + bad_key = False + + if bad_key: + return ["Keywords: None", "UNKNOWN"] + + keywords = [] + values = [] + if ( + key in env_dict + and func_name in env_dict[key] + and func_name != "frontend_param" + ): + for bin_name, value_dict in env_dict[key][func_name].items(): + key_vals = [ + f"{bin_name} - {func_name}({key})@{', '.join(val['locations'])}" + for val in value_dict["values"] + if val["value"] == "TOP" + ] + if key_vals: + keywords = list(set(keywords) | set(value_dict["keywords"])) + values.extend(key_vals) + else: + sources += [f"Keyword Source: {key} - {keyword_dict[key]}"] + + for keyword in keywords: + if keyword in keyword_dict: + sources += [f"Keyword Source: {keyword} - {keyword_dict[keyword]}"] + # sources += [f"Keywords: {', '.join(keywords) if keywords else 'None'}"] + 
sources.extend(values) + return sources + + def get_source_from_env_dict(self, stored_func, env_dict, keyword_dict): + sources = [] + key = set() + + if not env_dict: + return sources, key + + setter_name = ( + stored_func.function.name.replace("get", "set") + .replace("read", "write") + .replace("Get", "Set") + ) + if all(x.name != setter_name for x in SETTER_SINKS): + return sources, key + + for atom, values in stored_func.constant_data.items(): + if values is None: + continue + for val in values: + if val is None: + continue + try: + val_string = Utils.bytes_from_int(val).decode() + except UnicodeDecodeError: + continue + + key.add(val_string) + sources.extend( + self.get_value_from_env( + val_string, setter_name, env_dict, keyword_dict + ) + ) + + return sources, key + + def format_unresolved_closures( + self, + bin_name, + closure, + c_dict, + defn, + excluded_functions, + input_sources, + env_dict, + keyword_dict, + limit_output=False, + ): + output_list = [] + analyzed_list = c_dict["analyzed_list"] + trace_output = [] + for stored_func in analyzed_list: + if stored_func != closure.sink_trace: + if limit_output: + continue + output, all_resolved = self.log_function(stored_func) + sources, key = self.get_source_from_env_dict( + stored_func, env_dict, keyword_dict + ) + if sources: + pre_str = output[0][: output[0].index(stored_func.function.name)] + for source in sources[1:]: + output.insert( + 0, + pre_str + + ClosureFormatter.set_line_color( + source, CustomFormatter.blue, 0 + ), + ) + + output.insert( + 0, + pre_str + + ClosureFormatter.set_line_color( + f"SOURCES: {key}", CustomFormatter.blue, 0 + ), + ) + + else: # Found the sink + output, all_resolved = self.log_function(stored_func, target_defn=defn) + offset = output[0].find(stored_func.function.name) + output = [ + ClosureFormatter.set_line_color( + line, CustomFormatter.bold_red, offset + ) + for line in output + ] + trace_output.append(output) + break + + trace_output.append(output) + + trace_output.append(["", f"BINARY: {bin_name}", "INPUT SOURCES:"]) + likely_sources = input_sources["sources"]["likely"] + possibly_sources = input_sources["sources"]["possibly"] + if likely_sources or possibly_sources: + input_strings = [] + input_strings.append("[bold purple]Likely:") + if likely_sources: + for key, group in likely_sources.items(): + keys = key.strip('"').split(" | ") + func = group[-1].split("(")[0] + input_strings.append("[bold purple]" + "-" * 10) + input_strings.append("[bold purple]" + f'KEY: "{", ".join(keys)}"') + for sub_key in keys: + for idx, source in enumerate( + self.get_value_from_env( + sub_key, func, env_dict, keyword_dict + ) + ): + if source.startswith("Key"): + input_strings.append( + f"[bold purple]{source}" + ) + else: + input_strings.append( + f"[bold purple]Binary Source: {source}" + ) + input_strings.extend( + "[bold purple]" + f"Sink: {x}" for x in group + ) + input_strings.append("[bold purple]" + "-" * 10) + else: + input_strings.append("[bold purple]" + "NONE") + + if not limit_output: + input_strings.append("[bold #808080]Possibly:") + if possibly_sources: + for key, group in possibly_sources.items(): + keys = key.strip('"').split(" | ") + func = group[-1].split("(")[0] + input_strings.append("[bold #808080]" + "-" * 10) + input_strings.append( + "[bold #808080]" + f'KEY: "{", ".join(keys)}"' + ) + for sub_key in keys: + for idx, source in enumerate( + self.get_value_from_env( + sub_key, func, env_dict, keyword_dict + ) + ): + if idx == 0: + input_strings.append(f"[bold #808080]{source}") + 
else: + input_strings.append( + "[bold #808080]" + f"Binary Source - {source}" + ) + input_strings.extend("[bold #808080]" + x for x in group) + input_strings.append("[bold #808080]" + "-" * 10) + else: + input_strings.append("[bold #808080]NONE") + + input_strings.append( + CustomFormatter.yellow + f"RANK: {input_sources['rank']:.3f}" + ) + else: + input_strings = [CustomFormatter.bold_blue + "UNKNOWN"] + trace_output.append(input_strings) + output_list.append([y for x in trace_output for y in x]) + project = closure.rda.project + if ( + "main" in project.kb.functions + and closure.handler.analyzed_list[0].function.addr + != project.kb.functions["main"].addr + ): + traces = traces_to_sink( + closure.sink_trace.function, + project.kb.functions.callgraph, + max_depth=12, + excluded_functions=excluded_functions, + ) + traces = { + t + for t in traces + if all(x in t.callsites for x in closure.rda.subject.content.callsites) + } + output_list[-1].insert(0, "^" * 50) + for trace in traces: + output_list[-1].insert( + 0, + f"TRACE: {'->'.join(project.kb.functions[x.caller_func_addr].name for x in reversed(trace.callsites))}->{closure.sink_trace.function.name}", + ) + + return output_list + + @staticmethod + def set_line_color(line: str, color, offset): + if offset == 0: + return color + line + CustomFormatter.reset + else: + reset = "" + max_val = "" + max_idx = 0 + for fmt_color in CustomFormatter.__dict__.values(): + if not isinstance(fmt_color, str): + continue + color_idx = line.rfind(fmt_color, 0, offset) + if color_idx > max_idx: + max_val = fmt_color + max_idx = color_idx + if max_val != CustomFormatter.reset: + reset = CustomFormatter.reset + return line[:offset] + reset + color + line[offset:] + CustomFormatter.reset + + @staticmethod + def log_closures_for_sink(output_list: List[List[str]], sink, logger): + if output_list: + logger.critical(CustomFormatter.bold_red + "*" * 50) + logger.critical( + "%sUNRESOLVED CLOSURES to %s: ", CustomFormatter.bold_red, sink.name + ) + for output in output_list: + logger.critical(CustomFormatter.bold_red + "-" * 50) + for line in output: + logger.critical(line) diff --git a/package/argument_resolver/formatters/log_formatter.py b/package/argument_resolver/formatters/log_formatter.py new file mode 100644 index 0000000..59ab90a --- /dev/null +++ b/package/argument_resolver/formatters/log_formatter.py @@ -0,0 +1,119 @@ +import datetime +import logging + +from pathlib import Path + +from rich.logging import RichHandler +from rich.highlighter import NullHighlighter +from rich.console import Console +from rich.progress import TextColumn +from rich.text import Text + + +class CustomFormatter(logging.Formatter): + + grey = "[white]" + green = "[green]" + blue = "[blue]" + bold_blue = "[bold blue]" + light_blue = "[#00ffff]" + yellow = "[yellow]" + red = "[red]" + bold_red = "[bold red]" + reset = "[/]" + # format = "%(levelname)s | %(asctime)s | %(name)s | %(message)s" + + FORMATS = { + logging.DEBUG: green, + logging.INFO: grey, + logging.WARNING: yellow, + logging.ERROR: red, + logging.CRITICAL: bold_red, + } + + def format(self, record): + log_color = self.FORMATS.get(record.levelno) + level = record.levelname + " " + log_str = level.ljust(10) + log_str += f"| {datetime.datetime.fromtimestamp(record.created).strftime('%Y-%m-%d %H:%M:%S')}," + log_str += str(record.lineno).ljust(3, " ") + log_str += f" |{log_color} " + log_str += record.name + log_str += f" {self.reset}|{log_color} " + log_str += record.getMessage() + log_str += self.reset + return log_str 
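+
+# Example of a rendered record (a sketch; the timestamp and line number are
+# illustrative). For an INFO message on the logger created by make_logger()
+# below, format() produces rich markup tags rather than raw ANSI escapes:
+#
+#   INFO      | 2024-01-01 12:00:00,42  |[white] FastFRUIT [/]|[white] analyzing binary[/]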
+ + +class CustomColorRichHandler(RichHandler): + def render_message(self, record, message: str) -> "ConsoleRenderable": + msg = super().render_message(record, message) + color = ( + CustomFormatter.FORMATS.get(record.levelno) + .replace("[", "") + .replace("]", "") + ) + msg.style = color + return msg + + +class CustomPathRichHandler(CustomColorRichHandler): + def emit(self, record): + record.pathname = "" + super().emit(record) + + +def make_logger(log_level=logging.INFO, should_debug=False): + log = logging.getLogger("FastFRUIT") + log.setLevel(logging.DEBUG) + if log.handlers: + return log + + log.propagate = False + debug_file = Path("/tmp/mango.out") + + if should_debug: + if debug_file.exists(): + debug_file.unlink() + console = Console(file=debug_file.open("a+"), force_terminal=True) + log.addHandler( + CustomPathRichHandler( + level=log_level, + console=console, + highlighter=NullHighlighter(), + markup=True, + rich_tracebacks=True, + keywords=[], + ) + ) + log.addHandler( + CustomPathRichHandler( + level=log_level, + highlighter=NullHighlighter(), + markup=True, + keywords=[], + rich_tracebacks=True, + ) + ) + + return log + + +class CustomTextColumn(TextColumn): + """A column containing text.""" + + def render(self, task: "Task") -> Text: + if task.total is None: + if task.completed == 0: + _text = "" + else: + _text = f"{task.completed}" + else: + _text = self.text_format.format(task=task) + if self.markup: + text = Text.from_markup(_text, style=self.style, justify=self.justify) + else: + text = Text(_text, style=self.style, justify=self.justify) + if self.highlighter: + self.highlighter.highlight(text) + return text diff --git a/package/argument_resolver/formatters/results_formatter.py b/package/argument_resolver/formatters/results_formatter.py new file mode 100644 index 0000000..d4b89c3 --- /dev/null +++ b/package/argument_resolver/formatters/results_formatter.py @@ -0,0 +1,181 @@ +import hashlib +import json +import os +import pickle +import re +from pathlib import Path +from typing import Dict, Set, Tuple + +import networkx +from networkx.drawing.nx_agraph import write_dot +from networkx.exception import NetworkXNoPath +from rich.console import Console + +import angr +from angr.code_location import ExternalCodeLocation +from angr.knowledge_plugins.functions.function import Function +from angr.knowledge_plugins.key_definitions.atoms import Register +from angr.knowledge_plugins.key_definitions.definition import Definition +from argument_resolver.utils.closure import Closure +from argument_resolver.utils.stored_function import StoredFunction + + +def save_closure( + project: angr.Project, + cfg_time: float, + vra_time: float, + mango_time: float, + closure: Closure, + closure_info: Dict, + execv_dict: Dict, + result_path: Path, + time_data: Dict[Tuple[int], Dict[str, float]], + total_sinks=None, + has_sinks=True, + sink_time=0, + category=None, +): + res_name = f"{category}_results.json" if category is not None else "results.json" + result_file = result_path / res_name + + file = Path(project.filename) + + if result_file.exists(): + closure_dict = json.loads(result_file.read_text()) + else: + with file.open("rb") as f: + sha_sum = hashlib.file_digest(f, "sha256").hexdigest() + + closure_dict = { + "closures": [], + "cfg_time": cfg_time, + "vra_time": vra_time, + "path": str(file.absolute()), + "name": file.name, + "has_sinks": has_sinks, + "sha256": sha_sum, + "sink_times": {}, + "error": None, + } + + closure_dict["mango_time"] = ( + sum(mango_time) if isinstance(mango_time, 
list) else mango_time + ) + closure_dict["sinks"] = total_sinks + curr_sink = None + if closure is not None and closure_info is not None: + curr_sink = closure.sink_trace.name + c_d = closure_to_dict(closure, closure_info["external_input"]) + closure_dict["closures"].append(c_d) + closure_dir = result_path / f"{category}_closures" + closure_dir.mkdir(parents=True, exist_ok=True) + parent_func = closure.handler.analyzed_list[0].name + parent_addr = hex(closure.handler.analyzed_list[0].code_loc.block_addr) + sink_addr = hex(closure.sink_trace.code_loc.ins_addr) + file_name = ( + closure_dir + / f"{c_d['rank']:.2f}_{parent_func}_{parent_addr}_{closure.sink_trace.name}_{sink_addr}" + ) + console = Console(file=open(file_name, "w+"), force_terminal=True) + closure_dict["closures"][-1]["reachable_from_main"] = ( + closure.handler.analyzed_list[0].name == "main" + ) + closure_dict["closures"][-1]["sanitized"] = closure_info["sanitized"] + if ( + "main" in project.kb.functions + and closure.handler.analyzed_list[0].name != "main" + ): + try: + path = networkx.shortest_path( + project.kb.callgraph, + project.kb.functions["main"].addr, + closure.handler.analyzed_list[0].function.addr, + ) + closure_dict["closures"][-1]["reachable_from_main"] = True + console.print( + "->".join( + project.kb.functions[x].name + for x in path + + [ + x.caller_func_addr + for x in closure.rda.subject.content.callsites + ] + ) + + "->" + + closure.sink_trace.name + ) + console.print("-" * 50 + "\n") + except NetworkXNoPath: + closure_dict["closures"][-1]["reachable_from_main"] = False + for chunk in closure_info["output"]: + for line in chunk: + console.print(line) + + console.file.close() + + closure_dict["time_data"] = { + " -> ".join(hex(x) for x in k): v for k, v in time_data.items() + } + if curr_sink is not None: + if "sink_times" not in closure_dict: + closure_dict["sink_times"] = {} + closure_dict["sink_times"][curr_sink] = sink_time + + with open(result_file, "w+") as f: + json.dump(closure_dict, f, indent=4) + os.chmod(result_file, 0o666) + + if category == "cmdi": + with open(result_path / "execv.json", "w+") as f: + json.dump( + { + "execv": execv_dict, + "name": file.name, + "sha256": closure_dict["sha256"], + }, + f, + indent=4, + ) + + +def closure_to_dict(closure: Closure, input_sources): + caller_addrs = {x.caller_func_addr for x in closure.rda.subject.content.callsites} + subject_funcs = [ + stored_func + for stored_func in closure.handler.analyzed_list + if stored_func.function.addr in caller_addrs + ] + sink = closure.sink_trace + trace = [_stored_func_to_dict(stored_func) for stored_func in subject_funcs] + + closure_dict = { + "trace": trace, + "sink": _stored_func_to_dict(sink), + "depth": sink.depth - 1, + "inputs": {k: list(v) for k, v in input_sources["sources"].items()}, + "rank": input_sources["rank"], + } + return closure_dict + + +def _stored_func_to_dict(stored_func: StoredFunction): + return { + "function": stored_func.function.name, + "string": str(stored_func), + "ins_addr": hex( + stored_func.code_loc.ins_addr or stored_func.code_loc.block_addr + ), + } + + +def save_graph(graph: networkx.DiGraph, filename: str, result_path: Path): + """ + Save a graph on disk under two representations: serialized, and as an image. 
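+
+    Rendering the .svg shells out to Graphviz's `dot`, which is assumed to be
+    installed on the host.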
+    """
+    path_and_filename = str(result_path / filename)
+
+    with open(f"{path_and_filename}.pickle", "wb") as result_file:
+        pickle.dump(graph, result_file)
+
+    write_dot(graph, f"{path_and_filename}.dot")
+    os.system(f"dot -Tsvg -o {path_and_filename}.svg {path_and_filename}.dot")
diff --git a/package/argument_resolver/handlers/README.md b/package/argument_resolver/handlers/README.md
new file mode 100644
index 0000000..e46e4c9
--- /dev/null
+++ b/package/argument_resolver/handlers/README.md
@@ -0,0 +1,30 @@
+## Adding New Handlers
+Create an `x.py` file in the following format:
+```python
+import logging
+
+from .base import HandlerBase
+
+LOGGER = logging.getLogger("handlers.yourclass.h")
+
+class YourClass(HandlerBase):
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_funcname(
+        self,
+        state: "ReachingDefinitionsState",
+        codeloc: "CodeLocation",
+    ):
+        """
+        :param ReachingDefinitionsState state: Register and memory definitions and uses
+        :param CodeLocation codeloc: Code location of the call
+        """
+        ...
+        ...
+        ...
+        return True, state
+
+```
+Each handler must be a method named `handle_<funcname>`, e.g. `handle_strcmp` for `strcmp`.
+Each handler must return a tuple of the analysis result (`True` if the call was analyzed) and the resulting `state`.
\ No newline at end of file
diff --git a/package/argument_resolver/handlers/__init__.py b/package/argument_resolver/handlers/__init__.py
new file mode 100644
index 0000000..eb8d6f6
--- /dev/null
+++ b/package/argument_resolver/handlers/__init__.py
@@ -0,0 +1,9 @@
+from .local_handler import handler_factory, LibraryHandler
+
+from .nvram import NVRAMHandlers
+from .stdio import StdioHandlers
+from .stdlib import StdlibHandlers
+from .string import StringHandlers
+from .unistd import UnistdHandlers
+from .network import NetworkHandlers
+from .url_param import URLParamHandlers
diff --git a/package/argument_resolver/handlers/base.py b/package/argument_resolver/handlers/base.py
new file mode 100644
index 0000000..487c7a8
--- /dev/null
+++ b/package/argument_resolver/handlers/base.py
@@ -0,0 +1,168 @@
+import functools
+
+from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, List, Set
+
+from angr.analyses.reaching_definitions.function_handler import FunctionHandler, FunctionCallData
+
+from angr.knowledge_plugins.key_definitions.atoms import Atom, Register
+from angr.knowledge_plugins.functions import Function
+
+from argument_resolver.utils.calling_convention import CallingConventionResolver
+from argument_resolver.utils.utils import Utils
+from argument_resolver.formatters.log_formatter import make_logger
+
+from argument_resolver.utils.stored_function import StoredFunction
+import claripy
+
+if TYPE_CHECKING:
+    from archinfo import Arch
+    from angr import Project
+    from angr.analyses.reaching_definitions.rd_state import (
+        ReachingDefinitionsState,
+        Definition,
+    )
+    from angr.code_location import CodeLocation
+
+
+def get_arg_vals(arg_atoms: List[Atom], state: "ReachingDefinitionsState"):
+    vals = {}
+    for atom in arg_atoms:
+        value = state.live_definitions.get_value_from_atom(atom)
+        if value is not None:
+            vals[atom] = value
+        else:
+            vals[atom] = Utils.unknown_value_of_unknown_size(state, atom, state.current_codeloc)
+    return vals
+
+
+class HandlerBase(FunctionHandler):
+
+    MAX_READ_SIZE = 0x20
+
+    def __init__(
+        self,
+        project: "Project",
+        sink_function: "Function" = None,
+        sink_atoms: List[Atom] = None,
+        env_dict: Dict = None,
+        assumed_execution: bool = True,
+        taint_trace: bool = False,
+        forward_trace: bool = False,
max_local_call_depth: int = 3, + progress_callback=None + ): + """ + :param project: + :param sink_function: + :param sink_atoms: + """ + self._project = project + self._calling_convention_resolver = None + self._rda = None + self._sink_function_addr = sink_function.addr if sink_function else None + self.call_trace = [] + self.call_stack = [] + self.analyzed_list = [] + self.env_dict = env_dict + self.current_parent: StoredFunction = None + self.in_local_handler = False + self.assumed_execution = assumed_execution + self.taint_trace = taint_trace + self.forward_trace = forward_trace + self.first_run = True + self.max_local_call_depth = max_local_call_depth + self.progress_callback = progress_callback + self.fd_tracker = { + 0: {"val": claripy.BVS('"stdin"', self._project.arch.bits, explicit_name=True), "parent": None, "ins_addr": None}, + 1: {"val": claripy.BVS('"stdout"', self._project.arch.bits, explicit_name=True), "parent": None, "ins_addr": None}, + 2: {"val": claripy.BVS('"stderr"', self._project.arch.bits, explicit_name=True), "parent": None, "ins_addr": None}, + } + for fd_dict in self.fd_tracker.values(): + fd_dict["val"].variables = frozenset(set(fd_dict["val"].variables) | {"TOP"}) + + self.env_access = set() + self.keyword_access = {} + + self._sink_atoms = sink_atoms + self.sink_atom_defs: Dict[Atom, Set["Definition"]] = defaultdict(set) + self.log = make_logger() + + def gen_fd(self): + return max(x for x in self.fd_tracker if isinstance(x, int)) + 1 + + @staticmethod + def _balance_stack_before_returning( + state: "ReachingDefinitionsState", codeloc: "CodeLocation" + ) -> None: + arch: "Arch" = state.arch + if arch.call_pushes_ret: + # pops ret + sp_atom = Register(arch.sp_offset, arch.bytes) + sp_defs = state.get_definitions(sp_atom) + if sp_defs: + sp_def = next(iter(sp_defs)) + sp_data = state.registers.load( + sp_def.atom.reg_offset, size=arch.bytes + ) + state.kill_and_add_definition(sp_atom, sp_data) + + @staticmethod + def returns(func): + @functools.wraps(func) + def wrapped_func( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + *args, + **kwargs, + ): + + analysed, new_state, ret_val = func(self, state, stored_func, *args, **kwargs) + stored_func.handle_ret(new_state=new_state, value=ret_val) + + return analysed, new_state + + return wrapped_func + + @staticmethod + def tag_parameter_definitions(func): + """ + Add a `ParameterTag` to the definitions of the arguments of the function simulated by the handler. 
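+
+        If the callsite has no resolvable function (`data.function is None`),
+        the wrapped handler is skipped and the state is returned unchanged.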
+ """ + + @functools.wraps(func) + def wrapper(self, state: "ReachingDefinitionsState", data: FunctionCallData): + if data.function is None: + return False, state + stored_func = self.call_trace[-1] + stored_func.tag_params(first_run=self.first_run) + return func(self, state, stored_func) + + return wrapper + + def hook(self, rda): + self._rda = rda + self._calling_convention_resolver = CallingConventionResolver( + rda.project, + rda.project.arch, + rda.kb.functions, + ) + return self + + def handle_external_function(self, state: "ReachingDefinitionsState", data: FunctionCallData): + self.handle_local_function(state, data) + + def handle_function(self, state: "ReachingDefinitionsState", data: FunctionCallData): + depth = self.current_parent.depth if self.current_parent else 0 + stored_func = StoredFunction(state, data, self.call_stack, depth) + if stored_func.function is None: + return + self.call_trace.append(stored_func) + was_first_run = self.first_run + super().handle_function(state, data) + stored_func.return_definitions = state.analysis.function_calls[data.callsite_codeloc].ret_defns + if was_first_run: + stored_func.definitions = set().union(*[state.get_definitions(atom) for atom in stored_func.atoms]) + elif not (stored_func.function.is_plt or stored_func.function.is_simprocedure): + stored_func.definitions = set().union(*state.analysis.function_calls[data.callsite_codeloc].args_defns) diff --git a/package/argument_resolver/handlers/functions/__init__.py b/package/argument_resolver/handlers/functions/__init__.py new file mode 100644 index 0000000..0db5e2d --- /dev/null +++ b/package/argument_resolver/handlers/functions/__init__.py @@ -0,0 +1,16 @@ +from typing import Union + +from .constant_function import ConstantFunction + +CONSTANT_FUNCTIONS = [ + ConstantFunction( + "uname", param_num=1, is_pointer=True, val=b"A" * (0x400 - 1) + b"\x00" + ) +] + + +def get_constant_function(function_name: str) -> Union[None, ConstantFunction]: + for func in CONSTANT_FUNCTIONS: + if func.name == function_name: + return func + return None diff --git a/package/argument_resolver/handlers/functions/constant_function.py b/package/argument_resolver/handlers/functions/constant_function.py new file mode 100644 index 0000000..2883fb0 --- /dev/null +++ b/package/argument_resolver/handlers/functions/constant_function.py @@ -0,0 +1,67 @@ +from claripy import BVV + +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation + +from argument_resolver.utils.calling_convention import cc_to_rd +from argument_resolver.utils.utils import Utils + + +class ConstantFunction: + """ + Represent a function that should return a constant value either through a parameter or return value. + """ + + def __init__( + self, + name: str, + param_num=None, + is_ret_val=False, + is_pointer=False, + val=b"CONSTANT", + ): + """ + :param name: The name of the function. + :param param_num: The index of the parameter (starting from 1) that points to the return value. + :param is_ret_val: The index of the parameter (starting from 1) that points to the return value. + :param is_pointer: If the return value should be stored in a memory location. + :param val: The value to be returned or inserted (defaults to "CONSTANT"). 
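+
+        Example (the `uname` entry in CONSTANT_FUNCTIONS):
+        ConstantFunction("uname", param_num=1, is_pointer=True,
+        val=b"A" * (0x400 - 1) + b"\x00") models uname() filling the buffer
+        pointed to by its first parameter with a fixed value.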
+ """ + assert param_num is not None or is_ret_val, "Must have one or the other" + + self.name = name + self.param_num = param_num + self.is_ret_val = is_ret_val + self.is_pointer = is_pointer + self.val = val + self.cc = None + + def set_cc(self, calling_convention): + self.cc = calling_convention + + def constant_handler(self, state, stored_func): + assert self.cc is not None + mv = MultiValues(BVV(self.val)) + + if self.is_ret_val: + if state.arch.memory_endness == "Iend_LE": + self.val = reversed(self.val) + stored_func.handle_ret(new_state=state, value=self.val) + return True, state + + if self.param_num: + for _ in range(self.param_num): + sim_arg = self.cc.get_next_arg() + arg = cc_to_rd(sim_arg, state.arch) + values = Utils.get_values_from_cc_arg(sim_arg, state, state.arch) + sources = set() + for val in Utils.get_values_from_multivalues(values): + mem_loc = MemoryLocation(val, Utils.get_size_from_multivalue(mv)) + sources.add(mem_loc) + stored_func.depends(mem_loc, value=mv) + + stored_func.depends(arg, *sources, value=values) + return True, state + + def __repr__(self): + return f"ConstantFunction: {self.name} Constant Param: {self.param_num} Val: {self.val}" diff --git a/package/argument_resolver/handlers/local_handler.py b/package/argument_resolver/handlers/local_handler.py new file mode 100644 index 0000000..23a2dad --- /dev/null +++ b/package/argument_resolver/handlers/local_handler.py @@ -0,0 +1,601 @@ +import networkx as nx + +from typing import Optional, Tuple, Set, List, Type, Union + +from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState +from angr.analyses.reaching_definitions.reaching_definitions import ReachingDefinitionsAnalysis + +from angr.analyses.analysis import AnalysisFactory +from angr.analyses.reaching_definitions.dep_graph import DepGraph +from angr.analyses.reaching_definitions.subject import Subject +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.constants import OP_AFTER, OP_BEFORE +from angr.code_location import CodeLocation + +from .base import HandlerBase +from argument_resolver.utils.stored_function import StoredFunction + +from .nvram import NVRAMHandlers +from .stdio import StdioHandlers +from .stdlib import StdlibHandlers +from .string import StringHandlers +from .unistd import UnistdHandlers + +from ..utils.rda import CustomRDA +from ..utils.call_trace_visitor import CallTraceSubject +from .functions import get_constant_function +import time +from collections import deque + +LibraryHandler = Union[ + NVRAMHandlers, + StdioHandlers, + StdlibHandlers, + StringHandlers, + UnistdHandlers, +] + + +class LocalHandler(HandlerBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.white_list = [] + self.ReachingDefinitionsAnalysis = AnalysisFactory(self._project, CustomRDA) + self.triggered = False + self.external_input = False + self.external_list = [] + + @HandlerBase.tag_parameter_definitions + def handle_local_function( + self, + state: ReachingDefinitionsState, + current_func: StoredFunction, + first_run: bool = False, + ): + """ + Handles local functions during RDA + :return: StateChange: bool, state, visited_blocks, dep_graph + """ + if self._rda.rda_timeout != 0 and self._rda.start_time is not None and (time.time() - self._rda.start_time) > self._rda.rda_timeout: + raise TimeoutError("RDA Timeout") + + if self.first_run: + self.first_run = False + return + + elif self.current_parent is None and not self.taint_trace: 
+ return + + elif current_func.name == self.call_trace[-2].name and current_func.function.is_simprocedure and self.call_trace[-2].function.is_plt: + self.call_trace.pop(-1) + return + + if self.progress_callback and self.current_parent is not None: + self.progress_callback(self.current_parent.name, hex(self.current_parent.code_loc.ins_addr or self.current_parent.function.addr), current_func.name, hex(current_func.code_loc.ins_addr or current_func.function.addr)) + + # Either going one callsite deeper or hit the final sink + if self.hit_depth_change(current_func): + self.log.debug("Hit Depth Change at: %s", current_func) + rda_tuple = self.attempt_reanalysis(current_func) + if rda_tuple is not None: + return rda_tuple + + if not current_func.exit_site_addresses and not first_run: + current_func.handle_ret(current_func.state) + return current_func.success_tuple + + should_analyze, next_subject, analyzed_idx = self.should_run_analysis(current_func) + + if should_analyze: + if analyzed_idx: + self.analyzed_list[analyzed_idx] = current_func + else: + self.analyzed_list.append(current_func) + if self.forward_trace: + for idx, x in enumerate(self.white_list.copy()): + if x.code_loc.ins_addr == current_func.code_loc.ins_addr: + self.white_list.pop(idx) + break + + current_func.save_constant_arg_data(state) + rda_tup = self.run_rda(current_func, next_subject, first_run=first_run) + if self.call_stack[-1] == current_func: + self.call_stack.pop() + return rda_tup + + current_func.handle_ret() + return current_func.success_tuple + + def generate_taint_list(self, stored_func: StoredFunction): + white_list = self.generate_whitelist(stored_func)[0] + self.white_list += white_list + stored_func.save_constant_arg_data(stored_func.state) + analyzed = set() + analyze_queue = [(x, 0) for x in white_list.copy()] + for x in white_list: + x.save_closures() + while analyze_queue: + func, depth = analyze_queue.pop() + if func in analyzed or depth > 2: + continue + + analyzed.add(func) + + if func.function.addr in {stored_func.function.addr, self.current_parent.function.addr if self.current_parent is not None else 0}: + continue + + if self.is_handled(func.name): + continue + + new_graph = DepGraph() + old_graph = func.state.dep_graph + func.state.analysis._dep_graph = new_graph + last_idx = len(self.call_trace) + self.run_rda(func, CallTraceSubject(func.subject.content, func.function), first_run=False) + + old_parent = self.current_parent + self.current_parent = func + + new_trace = self.call_trace[last_idx:] + new_trace.insert(0, func) + old_trace = self.call_trace[:last_idx] + starting_idx = 0 + while starting_idx < len(old_trace): + for idx in range(starting_idx, len(old_trace)): + if old_trace[idx] == func: + old_trace = old_trace[:idx + 1] + new_trace[1:] + old_trace[idx + 1:] + starting_idx = idx + len(new_trace) + break + else: + break + + self.call_trace = new_trace + + for defn in func.definitions: + if defn not in new_graph.graph: + for node in [x for x in new_graph.graph.nodes() if + x.codeloc.ins_addr == func.code_loc.ins_addr and x.atom == defn.atom]: + new_graph.add_edge(defn, node) + white_list = self.generate_whitelist(func)[0] + white_list.remove(func) + for x in white_list: + x.save_closures() + analyze_queue.extend([(x, depth + 1) for x in white_list]) + self.white_list += white_list + self.current_parent = old_parent + self.call_trace = old_trace + old_graph.graph.add_nodes_from(new_graph.graph.nodes()) + old_graph.graph.add_edges_from(new_graph.graph.edges()) + + func.state.analysis._dep_graph 
= old_graph + + def should_run_analysis(self, stored_func: StoredFunction) -> Tuple[bool, Subject, Optional[int]]: + if self.taint_trace: + if any(x.code_loc.ins_addr == stored_func.code_loc.ins_addr for x in self.call_stack): + return False, stored_func.subject, None + self.log.debug("Tainting: %s", stored_func) + if stored_func.function.addr == stored_func.subject.content.target: + self.generate_taint_list(stored_func) + + if any(x.caller_func_addr == stored_func.function.addr for x in stored_func.subject.content.callsites) and not any(x.function.addr == stored_func.function.addr for x in self.call_stack): + self.call_stack.append(stored_func) + return True, stored_func.subject, None + else: + return False, stored_func.subject, None + + if not self.assumed_execution: + self.log.debug("Analyzing %s", stored_func) + if hasattr(self, f"handle_{stored_func.function.name}"): + return True, CallTraceSubject(stored_func.subject.content, self.current_parent.function), None + else: + return True, CallTraceSubject(stored_func.subject.content, stored_func.function), None + + white_list_func = [x for x in self.white_list if x.code_loc.ins_addr == stored_func.code_loc.ins_addr] + if white_list_func: + try: + analyzed_idx = self.analyzed_list.index(white_list_func[0]) + except ValueError: + analyzed_idx = None + if any(x.code_loc.ins_addr == stored_func.code_loc.ins_addr for x in self.call_stack): + self.log.debug("Avoiding Recursion %s", stored_func) + return False, stored_func.subject, None + + self.log.debug("Analyzing %s", stored_func) + if hasattr(self, f"handle_{stored_func.function.name}"): + if self.current_parent is None: + cfg = self._project.kb.cfgs.get_most_accurate() + node = cfg.get_any_node(stored_func.code_loc.block_addr) + func = self._project.kb.functions[node.function_address] + else: + func = self.current_parent.function + return True, CallTraceSubject(stored_func.subject.content, func), analyzed_idx + else: + return True, CallTraceSubject(stored_func.subject.content, stored_func.function), None + self.log.debug("Skipping %s", stored_func) + return False, stored_func.subject, None + + def is_handled(self, function_name: str) -> bool: + return hasattr(self, f"handle_{function_name.replace('__isoc99_', '')}") + + def run_rda(self, stored_func: StoredFunction, subject: Subject, is_reanalysis=False, first_run=False): + old_rda = self._rda + prev_parent = self.current_parent + + visited_blocks = stored_func.visited_blocks + dep_graph = stored_func.state.dep_graph + if len(self.get_trimmed_callstack(stored_func)) >= self.max_local_call_depth: + return stored_func.failed_tuple + + constant_func = get_constant_function(stored_func.name) + if constant_func is not None: + self.call_stack.append(stored_func) + constant_func.set_cc(self._calling_convention_resolver.get_cc(stored_func.name)) + return constant_func.constant_handler(stored_func.state, stored_func) + + elif not is_reanalysis and self.is_handled(stored_func.name): + stored_func._data.effects = [] + return self.handle_simprocedure_function(stored_func, stored_func.state, visited_blocks) + + observation_points = self._rda._observation_points | {("insn", x, OP_AFTER) for x in stored_func.exit_site_addresses} + + rda = self.ReachingDefinitionsAnalysis( + kb=self._rda.kb, + init_state=stored_func.state, + observation_points=observation_points, + subject=subject, + function_handler=self, + start_time=self._rda.start_time, + rda_timeout=self._rda.rda_timeout, + visited_blocks=visited_blocks, + dep_graph=dep_graph, + 
prev_observed=self._rda.observed_results, + is_reanalysis=is_reanalysis, + ) + + rda_tuple = None + if not first_run and not is_reanalysis and not hasattr(self, f"handle_{stored_func.function.name}") and not self._rda.should_abort: + if not self.assumed_execution and stored_func.function.addr not in {x.caller_func_addr for x in subject.content.callsites}: + prev_white_list = self.white_list.copy() + self.white_list.clear() + rda_tuple = self.attempt_reanalysis(stored_func) + self.white_list = prev_white_list + + if rda_tuple is None: + self._update_old_rda(stored_func.state, old_rda, rda, stored_func) + rda_tuple = (True, stored_func.state, rda.visited_blocks, dep_graph) + + self.current_parent = prev_parent + + if not hasattr(self, f"handle_{stored_func.function.name}") and not is_reanalysis: + stored_func.handle_ret(rda_tuple[1]) + return rda_tuple + + def _update_old_rda(self, + state: ReachingDefinitionsState, + old_rda: ReachingDefinitionsAnalysis, + rda: ReachingDefinitionsAnalysis, + stored_func: StoredFunction): + + self.hook(old_rda) + old_rda.observed_results.update(rda.observed_results) + old_rda.function_calls.update(rda.function_calls) + old_rda._dep_graph = rda.dep_graph + state.analysis = old_rda + try: + old_sp = state.registers.load( + state.arch.sp_offset, size=state.arch.bytes + ) + all_exit_states = [rda.model.observed_results.get(("node", x, OP_AFTER), None) for x in stored_func.exit_site_addresses] + all_exit_states = [x for x in all_exit_states if x is not None] + if len(all_exit_states) > 0: + merged_state = all_exit_states[0] + if len(all_exit_states) > 1: + merged_state = merged_state.merge(all_exit_states[1:]) + state.live_definitions = merged_state + state.registers.store(state.arch.sp_offset, old_sp, size=state.arch.bytes) + except AttributeError: + pass + + def get_trimmed_callstack(self, stored_func): + if self.current_parent is None: + return [] + + if any(x.caller_func_addr == self.current_parent.function.addr for x in stored_func.subject.content.callsites): + return [] + idx = 0 + for idx, func in enumerate(self.call_stack[::-1]): + if any(x.caller_func_addr == func.function.addr for x in stored_func.subject.content.callsites): + break + return [x.function.addr for x in self.call_stack[len(self.call_stack)-idx:]] + + def handle_simprocedure_function(self, + stored_func: StoredFunction, + state: ReachingDefinitionsState, + visited_blocks: set): + + handler = getattr(self, f"handle_{stored_func.function.name.replace('__isoc99_', '')}", None) + self.call_stack.append(stored_func) + if handler is None: + return stored_func.failed_tuple + else: + analyzed, new_state = handler(state, stored_func) + #HandlerBase._balance_stack_before_returning(new_state, stored_func) + if not analyzed: + return stored_func.failed_tuple + + return True, new_state, visited_blocks, new_state.dep_graph + + def attempt_reanalysis(self, stored_func: StoredFunction) -> Optional[Tuple]: + if not self.forward_trace and stored_func.function.addr in {x.caller_func_addr for x in stored_func.subject.content.callsites}: + return None + + if any(x.function.addr == stored_func.function.addr for x in self.call_stack[:-1]): + return None + + if self.white_list or not self.assumed_execution: + if any(x not in self.analyzed_list for x in self.white_list): + return None + + self.white_list, in_subject = self.generate_whitelist(stored_func) + # Re-run analysis on current parent function but with the whitelist this time + if set(self.white_list) < set(self.analyzed_list): + self.white_list = [] + 
return None + + if set(self.white_list) == {stored_func}: + return None + + if len(self.white_list) > 0: + resume_func, max_index = self.find_resume_function() + self.call_trace = self.call_trace[:len(self.call_trace) - max_index] + + new_subject = CallTraceSubject(self.current_parent.subject.content, self.current_parent.function) + new_subject.visitor.mark_nodes_as_visited({block for block in resume_func.visited_blocks if block.addr != resume_func.code_loc.block_addr}) + resume_func.state._subject = new_subject + self.log.debug("Re-analyzing %s from %s", self.current_parent, resume_func) + + if resume_func == self.current_parent: + for callsite in new_subject.content.callsites: + if callsite.caller_func_addr == self.current_parent.function.addr: + callsite.block_addr = stored_func.code_loc.block_addr + + resume_addr = resume_func.function.addr + else: + resume_addr = resume_func.code_loc.block_addr + + try: + old_state = self._rda.get_reaching_definitions_by_node(resume_addr, OP_BEFORE) + resume_func.state.live_definitions = old_state + except KeyError: + try: + arch = resume_func.state.arch + old_sp = self.call_trace[-2].state.registers.load(arch.sp_offset, size=arch.bytes) + resume_func.state.registers.store(arch.sp_offset, old_sp, size=arch.bytes) + except (AssertionError, IndexError): + pass + rda_tup = self.run_rda(resume_func, new_subject, is_reanalysis=True) + self._rda.abort() + return rda_tup + + return None + + def find_resume_function(self) -> Tuple[StoredFunction, int]: + """ + Go through the white list of functions and find the earliest point to resume from. + Also rewind the call_trace until that point + :return: + """ + reversed_trace = self.call_trace[::-1] + idx = reversed_trace.index(self.current_parent) + reversed_trace = reversed_trace[:idx+1] + max_index = -1 + resume_func = None + for func in self.white_list: + try: + index = reversed_trace.index(func) + if index > max_index: + max_index = index + resume_func = func + except ValueError: + pass + + return resume_func, max_index + + def hit_depth_change(self, stored_func: StoredFunction) -> bool: + """ + Determines if a depth change occurs as we descend deeper into the calltrace + :param stored_func: + :return: has_changed: bool + """ + assert isinstance(self._rda.subject, CallTraceSubject) + if self.taint_trace: + return False + + if stored_func.function.addr == self._sink_function_addr: + all_constant = True + for sink_atom in self._sink_atoms: + if sink_atom not in stored_func.constant_data: + all_constant = False + elif stored_func.constant_data[sink_atom] is None or any(x is None for x in stored_func.constant_data[sink_atom]): + all_constant = False + if all_constant: + return False + + if stored_func.function.addr != self._sink_function_addr and any(x.code_loc.ins_addr == stored_func.code_loc.ins_addr for x in self.analyzed_list): + return False + + for callsite in stored_func.subject.content.callsites: + if self.current_parent is None: + continue + + if callsite.caller_func_addr == self.current_parent.function.addr \ + and callsite.callee_func_addr == stored_func.function.addr: + if callsite.block_addr is None: + return True + if callsite.block_addr == stored_func.code_loc.block_addr: + callsite.block_addr = None + return True + return False + + def generate_whitelist(self, stored_func: StoredFunction) -> Tuple[List[StoredFunction], bool]: + if stored_func.function.addr == self._sink_function_addr: + target_defns = {defn for atom in self._sink_atoms for defn in stored_func.state.get_definitions(atom)} + 
in_subject = True + else: + target_defns = stored_func.definitions + in_subject = stored_func.function.addr == stored_func.subject.content.target + in_subject |= any(stored_func.function.addr == x.caller_func_addr for x in stored_func.subject.content.callsites) + + if in_subject: + graph = stored_func.state.dep_graph + else: + graph = self.call_trace[-1].state.dep_graph + + # white_list = [x[0] for x in self.get_dependent_definitions(graph, stored_func, target_defns=target_defns, in_subject=in_subject)] + white_list = self.get_dependent_definitions(graph, stored_func, target_defns=target_defns, in_subject=in_subject) + if not white_list or white_list == [self.current_parent]: + white_list.append(stored_func) + + return white_list, in_subject + + @staticmethod + def bfs_with_stop_nodes(graph, start_nodes, stop_nodes): + visited = set() + queue = deque(start_nodes) + + while queue: + current_node = queue.popleft() + + if current_node in stop_nodes: + continue # Skip and do not explore from this node + + if current_node not in visited: + visited.add(current_node) + if current_node in graph: + neighbors = list(graph[current_node]) + queue.extend(neighbors) + + return visited + + def get_dependent_definitions(self, graph: DepGraph, stored_func: StoredFunction, target_defns: Set[Definition], in_subject) -> List[StoredFunction]: + """ + Recursively get all definitions that our target depends on + :param stored_func: + :param target_atoms: + :return: + """ + + # Get all root nodes of the dependency tree based on the target definitions + + if self.current_parent is not None: + parent_idx = self.call_trace.index(self.current_parent) + else: + parent_idx = 0 + truncated_trace = self.call_trace[parent_idx:] + dependent_defns: Set[Definition] = set() + + # Get all nodes reachable from the root nodes + func_queue = [stored_func] + white_list = [] + while func_queue: + root_defns: Set[Definition] = set() + func = func_queue.pop(0) + white_list.append(func) + + if func == stored_func: + if in_subject: + for defn in target_defns: + closure_graph = graph.transitive_closure(defn) + root_defns |= {node for node in closure_graph if closure_graph.in_degree[node] == 0} + else: + root_defns = target_defns + else: + for defn in func.definitions: + closure_graph = graph.transitive_closure(defn) + root_defns |= {node for node in closure_graph if closure_graph.in_degree[node] == 0} + + new_defns = set() + for defn in {x for x in root_defns if x in graph.graph}: + if defn not in dependent_defns: + try: + new_defns |= set(nx.dfs_preorder_nodes(graph.graph, source=defn)) + except KeyError: + all_nodes = set(graph.graph.nodes()) + remove_edges = set() + remove_nodes = set() + for u, v in graph.graph.edges: + if u not in all_nodes: + remove_nodes.add(u) + remove_edges.add((u, v)) + if v not in all_nodes: + remove_nodes.add(v) + remove_edges.add((u, v)) + graph.graph.remove_nodes_from(remove_nodes) + graph.graph.remove_edges_from(remove_edges) + new_defns |= set(nx.dfs_preorder_nodes(graph.graph, source=defn)) + dependent_defns |= new_defns + + valid_funcs = [] + for t_f in truncated_trace.copy(): + if t_f in white_list: + continue + if any(d in new_defns for d in t_f.definitions | t_f.return_definitions) and t_f.function is not None: + valid_funcs.append(t_f) + truncated_trace.remove(t_f) + + func_queue.extend(valid_funcs) + + return white_list + + @staticmethod + def get_nodes_to_revisit(pred: StoredFunction, desc: StoredFunction): + earlier_blocks = pred._visited_blocks + earlier_blocks.discard(next(x for x in 
earlier_blocks if x.addr == pred.code_loc.block_addr)) + revisit_nodes = desc._visited_blocks - pred._visited_blocks + return revisit_nodes + + def local_func_wrapper(self, function, state, code_loc): + return self.handle_local_function(state, + function.addr, + call_stack=[], + max_local_call_depth=self._rda._maximum_local_call_depth, + visited_blocks=self._rda.visited_blocks, + dep_graph=state.dep_graph, + codeloc=code_loc, + ) + + def handle_external_function_name( + self, + state: ReachingDefinitionsState, + ext_func_name: str, + src_codeloc: Optional[CodeLocation] = None, + ) -> Tuple[bool, ReachingDefinitionsState]: + handler_name = f"handle_{ext_func_name}" + function = self._project.kb.functions[ext_func_name] + if ext_func_name and hasattr(self, handler_name): + if function.is_simprocedure: + analyzed, state, _, _ = self.local_func_wrapper(function, state, src_codeloc) + return analyzed, state + else: + return getattr(self, handler_name)(state, src_codeloc) + else: + self.log.debug("No handler for external function %s(), falling back to generic handler", ext_func_name) + if self.call_trace[-1].function.name == function.name: + return False, state + analyzed, state, _, _ = self.local_func_wrapper(function, state, src_codeloc) + return analyzed, state + + +def handler_factory( + handlers: Optional[List[Type[LibraryHandler]]] = None, +) -> Type[LocalHandler]: + """ + Generate a `Handler` inheriting from the given handlers. + + :param handlers: The list of library handlers to inherit behavior from. + :return: A `FunctionHandler` to be used during an analysis. + """ + handlers = handlers or [] + handler_cls = type("Handler", (LocalHandler, *handlers), {}) + return handler_cls diff --git a/package/argument_resolver/handlers/network.py b/package/argument_resolver/handlers/network.py new file mode 100644 index 0000000..3e32c34 --- /dev/null +++ b/package/argument_resolver/handlers/network.py @@ -0,0 +1,228 @@ +import logging + +import claripy +import socket + +from typing import List +# +from angr.knowledge_plugins.key_definitions.atoms import Atom, SpOffset, HeapAddress +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues + +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.utils.utils import Utils +from argument_resolver.utils.stored_function import StoredFunction + +from archinfo import Endness + + +class NetworkHandlers(HandlerBase): + """ + Handlers for network functions + inet_ntoa, + """ + + def __init__(self, *args, **kwargs): + self.ntoa_buf = None + super().__init__(*args, **kwargs) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_accept( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + cc = self._calling_convention_resolver.get_cc(stored_func.name) + + sock_fd = cc.get_next_arg() + fd_val = Utils.get_values_from_cc_arg(sock_fd, state, state.arch) + out_val = [str(x.concrete_value) if x.concrete else str(x) for x in Utils.get_values_from_multivalues(fd_val)] + + ret_fd = self.gen_fd() + possible_parents = [x.concrete_value for x in Utils.get_values_from_multivalues(fd_val) if x.concrete] + self.fd_tracker[ret_fd] = {"val": claripy.BVS(f"{stored_func.name}(fd: {' | '.join(sorted(out_val))})@0x{stored_func.code_loc.ins_addr:x}", state.arch.bits), "parent": possible_parents, "ins_addr": None} + + return True, state, MultiValues(claripy.BVV(ret_fd, 
state.arch.bits)) + + def _handle_recv( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + cc = self._calling_convention_resolver.get_cc(stored_func.name) + + sock_fd = cc.get_next_arg() + buf = cc.get_next_arg() + len_ = cc.get_next_arg() + + fd_val = Utils.get_values_from_cc_arg(sock_fd, state, state.arch) + buf_addr = Utils.get_values_from_cc_arg(buf, state, state.arch) + len_val = Utils.get_values_from_cc_arg(len_, state, state.arch) + len_max = state.arch.bytes + for len_v in Utils.get_values_from_multivalues(len_val): + if len_v.concrete: + len_max = max(len_max, len_v.concrete_value) + + len_max = min(len_max if len_max > 0 else 0, self.MAX_READ_SIZE) + mem_locs = [] + for ptr in Utils.get_values_from_multivalues(buf_addr): + try: + sp = state.get_sp() + except AssertionError: + sp = state.arch.initial_sp + + if not Utils.is_pointer(ptr, sp, self._project): + continue + + if state.is_stack_address(ptr): + offset = state.get_stack_offset(ptr) + if offset is None: + continue + mem_locs.append(Atom.mem(SpOffset(state.arch.bits, offset), len_max, endness=Endness.BE)) + elif state.is_heap_address(ptr): + offset = state.get_heap_offset(ptr) + if offset is None: + continue + mem_locs.append(Atom.mem(HeapAddress(offset), len_max, endness=Endness.BE)) + elif ptr.concrete: + mem_locs.append(Atom.mem(ptr.concrete_value, len_max, endness=Endness.BE)) + + parent_fds = [] + parent = None + for val in Utils.get_values_from_multivalues(fd_val): + if val.concrete and val.concrete_value in self.fd_tracker: + parent_fds.append(val.concrete_value) + if parent is None: + parent = self.fd_tracker[val.concrete_value]["val"] + else: + parent = parent.concat(self.fd_tracker[val.concrete_value]["val"]) + + if parent is not None: + parent_name = next(iter(x for x in parent.variables if x != "TOP")) + else: + parent_name = "?" 
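+        # Register one tainted buffer per resolved destination so later sinks can trace data back to this recv().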
+ + if stored_func.name not in self.fd_tracker: + self.fd_tracker[stored_func.name] = [] + + for mem in mem_locs: + buf_bvs = claripy.BVS( + f"{stored_func.name}({parent_name})@0x{stored_func.code_loc.ins_addr:x}", + mem.size * 8) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + self.fd_tracker[stored_func.name].append({"val": buf_bvs, "parent": parent_fds, "ins_addr": stored_func.code_loc.ins_addr}) + stored_func.depends(mem, *stored_func.atoms, value=MultiValues(buf_bvs), apply_at_callsite=True) + + return True, state, MultiValues(claripy.BVV(len_max, state.arch.bits)) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_recv( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + return self._handle_recv(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_recvfrom( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + return self._handle_recv(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_nflog_get_payload( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + return self._handle_recv(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_socket( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + cc = self._calling_convention_resolver.get_cc(stored_func.name) + + domain = cc.get_next_arg() + sock_type = cc.get_next_arg() + protocol = cc.get_next_arg() + + domain_val = Utils.get_values_from_cc_arg(domain, state, state.arch) + type_val = Utils.get_values_from_cc_arg(sock_type, state, state.arch) + protocol_val = Utils.get_values_from_cc_arg(protocol, state, state.arch) + + known_domain = { + socket.AF_UNIX: "AF_UNIX", + socket.AF_INET: "AF_INET", + socket.AF_INET6: "AF_INET6", + } + + known_type = { + socket.SOCK_STREAM: "SOCK_STREAM", + socket.SOCK_DGRAM: "SOCK_DGRAM", + socket.SOCK_RAW: "SOCK_RAW", + } + + def get_val_list(val, val_dict) -> List[str]: + out_vals = [] + for v in Utils.get_values_from_multivalues(val): + if v.concrete: + if v.concrete_value in val_dict: + out_vals.append(val_dict[v.concrete_value]) + else: + out_vals.append(hex(v.concrete_value)) + else: + out_vals.append(str(v)) + return out_vals + + domain_vals = get_val_list(domain_val, known_domain) + type_vals = get_val_list(type_val, known_type) + protocol_vals = [str(x.concrete_value) if x.concrete else str(x) for x in Utils.get_values_from_multivalues(protocol_val)] + + ret_fd = self.gen_fd() + + self.fd_tracker[ret_fd] = {"val": claripy.BVS(f"{stored_func.name}({' | '.join(sorted(domain_vals))}, {' | '.join(sorted(type_vals))}, {' | '.join(sorted(protocol_vals))})", state.arch.bits), "parent": None, "ins_addr": None} + + return True, state, MultiValues(claripy.BVV(ret_fd, state.arch.bits)) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_inet_ntoa( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + handler_name: str = "inet_ntoa", + ): + """ + Hard codes the return address of 127.1.1.1 + """ + + self.log.debug("RDA: %s(), ins_addr=%#x", handler_name, stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc(handler_name) + if self.ntoa_buf is None: + val = MultiValues(claripy.BVV("127.1.1.1")) + + size = Utils.get_size_from_multivalue(val) + 
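+            # inet_ntoa() reuses a single static buffer, so allocate the string once and cache the pointer.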
heap_addr = state.heap_allocator.allocate(size) + memloc = Atom.mem(heap_addr, size, endness=Endness.BE) + stored_func.depends(memloc, value=val) + heap_mv = MultiValues(Utils.gen_heap_address(heap_addr.value, state.arch)) + self.ntoa_buf = heap_mv + + return True, state, self.ntoa_buf \ No newline at end of file diff --git a/package/argument_resolver/handlers/nvram.py b/package/argument_resolver/handlers/nvram.py new file mode 100644 index 0000000..2491683 --- /dev/null +++ b/package/argument_resolver/handlers/nvram.py @@ -0,0 +1,401 @@ +import logging + +import claripy +# +from angr.calling_conventions import SimRegArg, SimStackArg +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues +from angr.code_location import ExternalCodeLocation + +from argument_resolver.formatters.log_formatter import make_logger +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.utils.calling_convention import cc_to_rd +from argument_resolver.utils.utils import Utils +from argument_resolver.utils.stored_function import StoredFunction +from argument_resolver.utils.transitive_closure import get_constant_data + +from archinfo import Endness + + +class NVRAMHandlers(HandlerBase): + """ + Handlers for NVRAM functions + nvram_set, acosNvramConfig_set, nvram_get, nvram_safe_get, acosNvramConfig_get, + """ + + def _handle_nvram_set( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + """ + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + :param str handler_name: Name of the handler + """ + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + # TODO + return False, state, None + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_nvram_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int nvram_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_SetValue( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int nvram_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_nvram_safe_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. 
sourcecode:: c + + int nvram_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_acosNvramConfig_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int acosNvramConfig_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_wlcsm_nvram_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int acosNvramConfig_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set( + state, stored_func + ) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_envram_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int acosNvramConfig_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_bcm_nvram_set( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int acosNvramConfig_set(const char *name, const char *value); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_set(state, stored_func) + + + def _handle_nvram_get( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + """ + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + :param str handler_name: Name of the handler + """ + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + default = False + set_key = None + cc = self._calling_convention_resolver.get_cc(stored_func.name) + key_arg = cc.get_next_arg() + key_ptr = Utils.get_values_from_cc_arg(key_arg, state, state.arch) + key_ptr_definitions = LiveDefinitions.extract_defs_from_mv(key_ptr) + keys = [] + default_values = { + "ifname": b"wlan0", + "netmask": b"0.0.0.0", + "ipaddr": b"1.1.1.1", + "last_auto_ip": b"2.2.2.2", + "gateway": b"3.3.3.3", + "PHYSDEVDRIVER": b"PHYSDEVDRIVER" + } + + # Mark parameter as used. 
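+        # Resolve every concrete key the name pointer may reference; unresolvable keys fall back to "TOP".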
+ for def_ in key_ptr_definitions: + state.add_use_by_def(def_, stored_func.code_loc) + resolved_keys = get_constant_data(def_, key_ptr, state) + if resolved_keys is None: + resolved_keys = ["TOP"] + keys.extend(Utils.bytes_from_int(x).decode() if isinstance(x, claripy.ast.Base) else "TOP" for x in resolved_keys) + keys = [x[:-1] if x.endswith("\x00") else x for x in keys] + + if self.env_dict is None: + default = True + + if "get" in stored_func.name: + set_key = stored_func.name.replace("get", "set").replace("Get", "Set") + elif "read" in stored_func.name: + set_key = stored_func.name.replace("read", "write") + + out_mv = MultiValues() + values = [] + found = False + for key in keys: + default_value = f"{stored_func.name}(\"{key}\")@{hex(stored_func.code_loc.ins_addr)}" + for df, df_val in default_values.items(): + if df in key: + values.append(df_val) + found = True + break + if not default and key in self.env_dict: + if found: + continue + if set_key in self.env_dict[key]: + key_vals = [y["value"] for x in self.env_dict[key][set_key].values() for y in x["values"] if y["pos"] == "1"] + for val in key_vals: + if val != "TOP": + values.append(val.encode()) + else: + values.append(default_value) + else: + values.append(default_value) + else: + values.append(default_value) + + if len(values) == 0: + default_value = f"{stored_func.name}(\"TOP\")@{hex(stored_func.code_loc.ins_addr)}" + values.append(default_value) + + for v in values: + if isinstance(v, str): + new_val = claripy.BVS(v, self.MAX_READ_SIZE*8, explicit_name=True) + new_val.variables = frozenset(set(new_val.variables) | {"TOP"}) + out_mv.add_value(0, new_val) + self.env_access.add(new_val) + else: + out_mv.add_value(0, claripy.BVV(v + b"\x00")) + + size = Utils.get_size_from_multivalue(out_mv) + if stored_func.name in {"acosNvramConfig_read", "GetValue"}: + arg_dst = cc.get_next_arg() + dst_ptrs = Utils.get_values_from_cc_arg(arg_dst, state, state.arch) + dst_valid_ptrs = [x for x in Utils.get_values_from_multivalues(dst_ptrs) if not state.is_top(x)] + for dst_ptr in dst_valid_ptrs: + sources = stored_func.atoms - {cc_to_rd(arg_dst, state.arch, state)} + memloc = MemoryLocation(dst_ptr, size, endness=Endness.BE) + stored_func.depends(memloc, *sources, value=out_mv) + return True, state, None + else: + heap_addr = state.heap_allocator.allocate(size) + memloc = MemoryLocation(heap_addr, size, endness=Endness.BE) + stored_func.depends(memloc, value=out_mv, apply_at_callsite=True) + heap_mv = MultiValues(Utils.gen_heap_address(heap_addr.value, state.arch)) + + return True, state, heap_mv + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_nvram_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *nvram_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_GetValue( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. 
sourcecode:: c + + char *nvram_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_nvram_safe_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *nvram_safe_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_acosNvramConfig_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *acosNvramConfig_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_acosNvramConfig_read( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *acosNvramConfig_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_bcm_nvram_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *acosNvramConfig_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_envram_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + char *acosNvramConfig_get(const char *name); + + :param ReachingDefinitionsState state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + return self._handle_nvram_get(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_wlcsm_nvram_get( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. 
sourcecode:: c
+
+            char *wlcsm_nvram_get(const char *name);
+
+        :param ReachingDefinitionsState state: Register and memory definitions and uses
+        :param Codeloc codeloc: Code location of the call
+        """
+        return self._handle_nvram_get(state, stored_func)
diff --git a/package/argument_resolver/handlers/static.py b/package/argument_resolver/handlers/static.py
new file mode 100644
index 0000000..d256604
--- /dev/null
+++ b/package/argument_resolver/handlers/static.py
@@ -0,0 +1,22 @@
+import logging
+from .base import HandlerBase
+
+from argument_resolver.formatters.log_formatter import make_logger
+
+
+class StaticHandlers(HandlerBase):
+    """
+    Handlers for functions that should return static values for our purposes:
+    uname
+    """
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_uname(self, state: "ReachingDefinitionsState", stored_func: "StoredFunction"):
+        """
+        Process the impact of the function's execution on register and memory definitions and uses.
+        .. sourcecode:: c
+            int uname(struct utsname *buf);
+        """
+        self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr)
+        return False, state, None
\ No newline at end of file
diff --git a/package/argument_resolver/handlers/stdio.py b/package/argument_resolver/handlers/stdio.py
new file mode 100644
index 0000000..85ee3aa
--- /dev/null
+++ b/package/argument_resolver/handlers/stdio.py
@@ -0,0 +1,655 @@
+import claripy
+import itertools
+
+from typing import List, Union
+
+from angr.calling_conventions import SimRegArg, SimStackArg
+from angr.code_location import CodeLocation
+from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, SpOffset
+
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+from angr.knowledge_plugins.key_definitions.tag import (
+    ParameterTag,
+    ReturnValueTag,
+    SideEffectTag,
+)
+from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions
+from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState
+
+from argument_resolver.formatters.log_formatter import make_logger
+from argument_resolver.handlers.base import HandlerBase
+from argument_resolver.utils.calling_convention import cc_to_rd
+from argument_resolver.utils.utils import Utils
+from argument_resolver.utils.stored_function import StoredFunction
+
+from archinfo import Endness
+
+
+class StdioHandlers(HandlerBase):
+    """
+    Handlers for <stdio.h>'s functions.
+    """
+
+    # TODO Handle strstr
+
+    def _handle_sprintf(
+        self,
+        state: "ReachingDefinitionsState",
+        stored_func: StoredFunction,
+        concretize_nums: bool = True,
+    ):
+        """
+        :param state: Register and memory definitions and uses
+        :param stored_func: Stored function data for the call being handled
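+        :param concretize_nums: When True, numeric format arguments without a concrete value are concretized to the placeholder 1337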
+ """ + arch = state.arch + + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + in_place = stored_func.name in ["doSystemCmd", "twsystem", "exec_cmd", "execFormatCmd"] + if in_place: + cc = self._calling_convention_resolver.get_cc("printf") + else: + cc = self._calling_convention_resolver.get_cc(stored_func.name) + # Get sim function arguments + + if not in_place: + arg_dst = cc.get_next_arg() + + if stored_func.name in ["sprintf", "asprintf", "vsprintf"] or in_place: + arg_fmt = cc.get_next_arg() + # num_fixed_args = 2 + elif stored_func.name in ["snprintf", "vsnprintf"]: + cc.get_next_arg() # args[1]: size + arg_fmt = cc.get_next_arg() + # num_fixed_args = 3 + elif stored_func.name == "__sprintf_chk": + cc.get_next_arg() # args[1]: flag + cc.get_next_arg() # args[2]: strlen + arg_fmt = cc.get_next_arg() + # num_fixed_args = 4 + elif stored_func.name == "__snprintf_chk": + cc.get_next_arg() # args[1]: maxlen + cc.get_next_arg() # args[2]: flag + cc.get_next_arg() # args[3]: strlen + arg_fmt = cc.get_next_arg() + # num_fixed_args = 5 + else: + raise ValueError(stored_func.name) + + # Generate a single MultiValues that includes all possible sources / destinations + if not in_place: + dst_ptrs = Utils.get_values_from_cc_arg(arg_dst, state, arch) + fmt_ptrs = Utils.get_values_from_cc_arg(arg_fmt, state, arch) + + # Get all concrete format strings + fmt_strs = Utils.get_strings_from_pointers(fmt_ptrs, state, stored_func.code_loc) + + cont = True + if len(list(Utils.get_values_from_multivalues(fmt_strs))) == 0: + cont = False + self.log.debug("RDA: %s(): No (concrete) format string found", stored_func.name) + + # Get all concrete destination pointer + if cont and not in_place: + dst_int_ptrs = [x for x in Utils.get_values_from_multivalues(dst_ptrs) if not state.is_top(x)] + else: + dst_int_ptrs = [] + + if not in_place and len(dst_int_ptrs) == 0: + cont = False + self.log.debug("RDA: %s(): No (concrete) destination found", stored_func.name) + + formatted_strs = None + if cont: + fmt_args = {} + for fmt_str in [x for x in Utils.get_values_from_multivalues(fmt_strs) if x.concrete]: + fmt_prototypes = Utils.get_prototypes_from_format_string( + Utils.bytes_from_int(fmt_str) + ) + num_prototypes = len(fmt_prototypes) + if num_prototypes == 0: + # Handle format string w/o format prototypes + if formatted_strs is None: + formatted_strs = MultiValues(fmt_str) + else: + formatted_strs.add_value(0, fmt_str) + else: + # res describes the result of a format string. Each element is a list which either includes a + # static part of the format string or the resolved values of a prototype. E.g. 'ls %s' leads to + # res = [['ls '], []]. 
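+                    # The cartesian product of these lists later yields every candidate formatted string.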
+ res: List[List[Union[str, claripy.ast.BV]]] = [] + + # Extract static part in front of the first prototype + if fmt_prototypes[0].position and fmt_str.concrete: + prologue_len = ((fmt_str.size() // 8) - fmt_prototypes[0].position) * 8 + res.append([fmt_str[:prologue_len]]) + + # Process each prototype and the consecutive static part of the format string + for i in range(num_prototypes): + fmt_prototype = fmt_prototypes[i] + fmt_prototype = fmt_prototype.decode() if isinstance(fmt_prototype, bytes) else fmt_prototype + + # Prototype + values = [] + # noinspection SpellCheckingInspection + if fmt_prototype.specifier in "diuoxX": + if i not in fmt_args: + arg = cc.get_next_arg() + fmt_args[i] = arg + mv = Utils.get_values_from_cc_arg(fmt_args[i], state, arch) + for value in Utils.get_values_from_multivalues(mv): + if value.concrete: + values.append( + claripy.BVV( + str(value._model_concrete.value).encode() + ) + ) + elif concretize_nums: + values.append(claripy.BVV(b"1337")) + else: + values.append(value) + elif fmt_prototype.specifier in "s": + if i not in fmt_args: + arg = cc.get_next_arg() + fmt_args[i] = arg + + src_ptrs = Utils.get_values_from_cc_arg( + fmt_args[i], state, arch + ) + + # get the actual values + strings = Utils.get_strings_from_pointers( + src_ptrs, state, stored_func.code_loc + ) + string_values = Utils.get_values_from_multivalues(strings) + if any(len(x.annotations) == 0 for x in strings[0]): + mem_loc = MemoryLocation( + src_ptrs.one_value(), + Utils.get_size_from_multivalue(strings) // 8, + endness=Endness.BE + ) + stored_func.depends(mem_loc, value=strings, apply_at_callsite=True) + + for val in string_values: + if not state.is_top(val) and Utils.has_unknown_size( + val + ): + val.length = state.arch.bytes * 8 + values.append(val) + + elif fmt_prototype.specifier in "c": + values.append(claripy.BVV(b"|")) + else: + self.log.debug( + "RDA: %s(): Specifier %%%s not supported", + stored_func.name, + fmt_prototype.specifier, + ) + + if values: + res.append(values) + + # Static part + if i < num_prototypes - 1: + end = fmt_prototypes[i + 1].position + else: + end = fmt_str.size() // 8 + s = Utils.bytes_from_int(fmt_str)[ + fmt_prototype.position + len(fmt_prototype.prototype) : end + ].decode("latin-1") + if s not in ("", "\x00"): + res.append([s]) + + # Create a DataRelation for each permutation of res + for combinations in list(itertools.product(*res)): + out_str = None + for str_ in combinations: + if isinstance(str_, str): + str_ = claripy.BVV(str_.encode("latin-1")) + if out_str is None: + out_str = str_ + else: + out_str = out_str.concat(str_) + + if formatted_strs is None: + formatted_strs = MultiValues(out_str) + else: + formatted_strs = formatted_strs.merge(MultiValues(out_str)) + + # Add definition of resolved format string for all concrete destinations + if formatted_strs is not None and formatted_strs.count() > 0 and not in_place: + for dst_ptr in dst_int_ptrs: + sources = stored_func.atoms - {cc_to_rd(arg_dst, state.arch, state)} + if stored_func.name == "asprintf": + alloc_size = Utils.get_size_from_multivalue(formatted_strs) + heap_addr = state.heap_allocator.allocate(alloc_size) + memloc = MemoryLocation(heap_addr, alloc_size, endness=Endness.BE) + stored_func.depends(memloc, *stored_func.atoms, value=formatted_strs) + heap_ptr = MultiValues(Utils.gen_heap_address(heap_addr.value, state.arch)) + store_loc = Utils.get_memory_location_from_bv(dst_ptr, state, state.arch.bytes) + stored_func.depends(store_loc, memloc, value=heap_ptr) + else: + memloc = 
MemoryLocation(dst_ptr, Utils.get_size_from_multivalue(formatted_strs), endness=Endness.BE) + stored_func.depends(memloc, *sources, value=formatted_strs) + elif in_place and formatted_strs is not None: + first_arg = cc_to_rd(arg_fmt, state.arch, state) + stored_func._arg_vals[first_arg] = formatted_strs + vals = Utils.get_values_from_multivalues(formatted_strs) + stored_func.constant_data[first_arg] = vals if all(v.concrete for v in vals) else None + return True, state, MultiValues(claripy.BVV(0, 32)) + + # Add definition for return value + if cont and formatted_strs is not None: + number_of_symbols = self._number_of_symbols_for_formatted_strings(formatted_strs, state) + if number_of_symbols.count() == 0: + number_of_symbols = MultiValues(state.top(state.arch.bits)) + else: + number_of_symbols = MultiValues(state.top(state.arch.bits)) + + return True, state, number_of_symbols + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_sprintf( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int sprintf ( char * str, const char * format, ... ); + + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_vsprintf( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + Process the impact of the function's execution on register and memory definitions and uses. + + .. sourcecode:: c + + int sprintf ( char * str, const char * format, ... ); + + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_snprintf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. sourcecode:: c + + int snprintf(char *str, size_t size, const char *format, ...); + """ + self.log.debug("RDA: snprintf(): Using sprintf(). Size n is ignored.") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_vsnprintf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. sourcecode:: c + + int snprintf(char *str, size_t size, const char *format, ...); + """ + self.log.debug("RDA: vsnprintf(): Using sprintf(). Size n is ignored.") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_asprintf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. sourcecode:: c + + int snprintf(char *str, size_t size, const char *format, ...); + """ + self.log.debug("RDA: asprintf(): Using sprintf().") + return self._handle_sprintf(state, stored_func) + + # TODO Handle __sprintf_chk __snprintf_chk strstr + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle___sprintf_chk(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. sourcecode:: c + + int __sprintf_chk(char *str, int flag, size_t strlen, const char *format, ...); + """ + self.log.debug("RDA: __sprintf_chk(): Using sprintf(). 
Size n is ignored.") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle___snprintf_chk(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. sourcecode:: c + + int __snprintf_chk(char *str, size_t maxlen, int flag, size_t strlen, const char *format, ...); + """ + self.log.debug("RDA: __snprintf_chk(): Using sprintf(). Size n is ignored.") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_printf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + return False, state, None + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_twsystem(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + + self.log.debug("RDA: Using sprintf() to handle twsystem") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_exec_cmd(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + + self.log.debug("RDA: Using sprintf() to handle %s", stored_func.name) + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_doSystemCmd(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + + self.log.debug("RDA: Using sprintf() to handle doSystemCmd") + return self._handle_sprintf(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_dprintf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + return False, state, None + + def _handle_scanf(self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + concretize_nums: bool = True): + arch = state.arch + + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc(stored_func.name) + + # Get sim function arguments + arg_src = None + if stored_func.name.replace('__isoc99_', '') in {"sscanf", "fscanf"}: + arg_src = cc.get_next_arg() + # num_fixed_args = 2 + arg_fmt = cc.get_next_arg() + + fmt_ptrs = Utils.get_values_from_cc_arg(arg_fmt, state, arch) + + # Get all concrete format strings + fmt_strs = Utils.get_strings_from_pointers(fmt_ptrs, state, stored_func.code_loc) + + if fmt_strs.count() == 0: + self.log.debug("RDA: %s(): No (concrete) format string found", stored_func.name) + return False, state + + # Get all concrete destination pointer + fmt_args = {} + for fmt_str in [x for x in Utils.get_values_from_multivalues(fmt_strs) if x.concrete]: + fmt_prototypes = Utils.get_prototypes_from_format_string( + Utils.bytes_from_int(fmt_str) + ) + num_prototypes = len(fmt_prototypes) + if num_prototypes == 0: + # Handle format string w/o format prototypes + continue + + # Process each prototype and the consecutive static part of the format string + for i, fmt_prototype in enumerate(fmt_prototypes): + if '*' in fmt_prototype.prototype: + continue + + if i not in fmt_args: + arg = cc.get_next_arg() + fmt_args[i] = {"arg": arg, "value": None} + + mv = MultiValues(state.top(state.arch.bits)) + if fmt_prototype.specifier in "diuoxX": + mv = MultiValues(claripy.BVV(0x1337, state.arch.bits)) + elif arg_src: + src_ptrs = Utils.get_values_from_cc_arg( + arg_src, state, arch + ) + new_mv = MultiValues() + for src_ptr in Utils.get_values_from_multivalues(src_ptrs): + if state.is_top(src_ptr): + 
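+                            # An unresolvable source pointer stays TOP and is propagated unchanged.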
new_mv.add_value(0, src_ptr) + else: + if new_mv.count() > 0: + new_mv = new_mv.merge(Utils.get_strings_from_pointers(src_ptrs, state, stored_func.code_loc)) + else: + new_mv = Utils.get_strings_from_pointers(src_ptrs, state, stored_func.code_loc) + mv = new_mv + + if fmt_args[i]["value"] is None: + fmt_args[i]["value"] = mv + else: + fmt_args[i]["value"] = fmt_args[i]["value"].merge(mv) + + for val_dict in fmt_args.values(): + arg = val_dict["arg"] + val = val_dict["value"] + dst_ptrs = Utils.get_values_from_cc_arg(arg, state, arch) + + # Add definition of resolved format string for all concrete destinations + dst_int_ptrs = [x for x in Utils.get_values_from_multivalues(dst_ptrs) if not state.is_top(x)] + for dst_ptr in dst_int_ptrs: + if not state.is_top(dst_ptr): + memloc = MemoryLocation(Utils.get_store_method_from_ptr(dst_ptr, state), Utils.get_size_from_multivalue(val)) + sources = {cc_to_rd(arg_src, state.arch, state)} if arg_src else set() + stored_func.depends(memloc, *sources, value=val) + else: + self.log.debug("Failed to store to %s: Unresolvable Destination", val) + + number_of_symbols = MultiValues(claripy.BVV(len(fmt_args), state.arch.bits)) + return True, state, number_of_symbols + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_sscanf(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + return self._handle_scanf(state, stored_func) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_fgets(self, state: ReachingDefinitionsState, stored_func: StoredFunction): + """ + Process read and marks it as taint + .. sourcecode:: c + char *fgets(char *s, int size, FILE *stream); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: fgets(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + arch = state.arch + cc = self._calling_convention_resolver.get_cc("fgets") + + arg_buf = cc.get_next_arg() + arg_count = cc.get_next_arg() + arg_stream = cc.get_next_arg() + + buf_ptrs = Utils.get_values_from_cc_arg(arg_buf, state, arch) + size = Utils.get_values_from_cc_arg(arg_count, state, arch) + stream = Utils.get_values_from_cc_arg(arg_stream, state, arch) + + parent_fds = [] + parent = None + for val in Utils.get_values_from_multivalues(stream): + if val.concrete and val.concrete_value in self.fd_tracker: + parent_fds.append(val.concrete_value) + if parent is None: + parent = self.fd_tracker[val.concrete_value]["val"] + else: + parent = parent.concat(self.fd_tracker[val.concrete_value]["val"]) + + for ptr in Utils.get_values_from_multivalues(buf_ptrs): + # sp = reach_def.get_sp() + size_val = Utils.get_concrete_value_from_int(size) + size_val = max(size_val) if size_val is not None else state.arch.bytes + size_val = min(size_val, self.MAX_READ_SIZE) + memloc = MemoryLocation(ptr, size_val) + if parent is not None: + parent_name = next(iter(x for x in parent.variables if x != "TOP")) + else: + parent_name = "?" 
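+            # Name the symbolic buffer after its stream so reports show where the tainted data entered.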
+ buf_bvs = claripy.BVS( + f"{stored_func.name}({parent_name})@0x{stored_func.code_loc.ins_addr:x}", + memloc.size*8) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + + if stored_func.name not in self.fd_tracker: + self.fd_tracker[stored_func.name] = [] + + self.fd_tracker[stored_func.name].append( + {"val": buf_bvs, "parent": parent_fds, "ins_addr": stored_func.code_loc.ins_addr}) + stored_func.depends(memloc, *stored_func.atoms, value=MultiValues(buf_bvs)) + + return True, state, buf_ptrs + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_fopen(self, state: ReachingDefinitionsState, stored_func: StoredFunction): + """ + Process read and marks it as taint + .. sourcecode:: c + FILE* fopen(char *path, const char *mode); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: fopen(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + arch = state.arch + cc = self._calling_convention_resolver.get_cc("fopen") + + arg_path = cc.get_next_arg() + arg_mode = cc.get_next_arg() + + path_ptrs = Utils.get_values_from_cc_arg(arg_path, state, arch) + mode_ptrs = Utils.get_values_from_cc_arg(arg_mode, state, arch) + + path = Utils.get_strings_from_pointers(path_ptrs, state, stored_func.code_loc) + paths = [] + for p in Utils.get_values_from_multivalues(path): + if p.concrete: + paths.append(f'"{Utils.bytes_from_int(p).decode("latin-1")}"') + else: + paths.append(f'"{p}"') + + mode = Utils.get_strings_from_pointers(mode_ptrs, state, stored_func.code_loc) + modes = [] + for m in Utils.get_values_from_multivalues(mode): + if m.concrete: + modes.append(f'"{Utils.bytes_from_int(m).decode("latin-1")}"') + else: + modes.append(f'"{m}"') + + fd = self.gen_fd() + buf_bvs = claripy.BVS( + f"{stored_func.name}({' | '.join(sorted(paths))}, {' | '.join(sorted(modes))})@0x{stored_func.code_loc.ins_addr:x}", + state.arch.bits) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + + self.fd_tracker[fd] = {"val": buf_bvs, "parent": None, "ins_addr": None} + + return True, state, MultiValues(claripy.BVV(fd, state.arch.bits)) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_fread(self, state: ReachingDefinitionsState, stored_func: StoredFunction): + """ + Process read and marks it as taint + .. 
sourcecode:: c
+            size_t fread(void *ptr, size_t size, size_t nmemb, FILE *stream);
+        :param state: Register and memory definitions and uses
+        :param codeloc: Code location of the call
+        """
+        self.log.debug("RDA: fread(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        arch = state.arch
+        cc = self._calling_convention_resolver.get_cc("fread")
+
+        arg_buf = cc.get_next_arg()
+        arg_size = cc.get_next_arg()
+        arg_nmemb = cc.get_next_arg()
+        arg_stream = cc.get_next_arg()
+
+        buf_ptrs = Utils.get_values_from_cc_arg(arg_buf, state, arch)
+        size = Utils.get_values_from_cc_arg(arg_size, state, arch)
+        nmemb = Utils.get_values_from_cc_arg(arg_nmemb, state, arch)
+        stream = Utils.get_values_from_cc_arg(arg_stream, state, arch)
+
+        size_val = Utils.get_concrete_value_from_int(size)
+        size_val = max(size_val) if size_val is not None else state.arch.bytes
+
+        nmemb_val = Utils.get_concrete_value_from_int(nmemb)
+        nmemb_val = max(nmemb_val) if nmemb_val is not None else state.arch.bytes
+
+        parent_fds = []
+        parent = None
+        for val in Utils.get_values_from_multivalues(stream):
+            if val.concrete and val.concrete_value in self.fd_tracker:
+                parent_fds.append(val.concrete_value)
+                if parent is None:
+                    parent = self.fd_tracker[val.concrete_value]["val"]
+                else:
+                    parent = parent.concat(self.fd_tracker[val.concrete_value]["val"])
+
+        if stored_func.name not in self.fd_tracker:
+            self.fd_tracker[stored_func.name] = []
+
+        for ptr in Utils.get_values_from_multivalues(buf_ptrs):
+            # Cap the modelled read at 0x1000 bytes to keep memory objects bounded.
+            size = min(size_val*nmemb_val, 0x1000)
+            memloc = MemoryLocation(ptr, size)
+
+            if parent is not None:
+                parent_name = next(iter(x for x in parent.variables if x != "TOP"))
+            else:
+                parent_name = "?"
+            buf_bvs = claripy.BVS(
+                f"{stored_func.name}({parent_name})@0x{stored_func.code_loc.ins_addr:x}",
+                memloc.size*8)
+            buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"})
+            mv = MultiValues(buf_bvs)
+            self.fd_tracker[stored_func.name].append(
+                {"val": buf_bvs, "parent": parent_fds, "ins_addr": stored_func.code_loc.ins_addr})
+            stored_func.depends(memloc, *stored_func.atoms, value=mv, apply_at_callsite=True)
+
+        return True, state, MultiValues(state.top(state.arch.bits))
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_popen(self, state: ReachingDefinitionsState, stored_func: StoredFunction):
+        return False, state, None
+
+    def _number_of_symbols_for_formatted_strings(
+        self, formatted_strings: MultiValues, state: ReachingDefinitionsState
+    ) -> MultiValues:
+        data_number_of_symbols = MultiValues()
+        for s in Utils.get_values_from_multivalues(formatted_strings):
+            if isinstance(s, claripy.String):
+                data_number_of_symbols.add_value(
+                    0, claripy.BVV(s.string_length - 1, state.arch.bits)
+                )
+            elif state.is_top(s):
+                return MultiValues(offset_to_values={0: {state.top(state.arch.bits)}})
+            else:
+                data_number_of_symbols.add_value(
+                    0, claripy.BVV(len(s), state.arch.bits)
+                )
+
+        return data_number_of_symbols
diff --git a/package/argument_resolver/handlers/stdlib.py b/package/argument_resolver/handlers/stdlib.py
new file mode 100644
index 0000000..8a442ab
--- /dev/null
+++ b/package/argument_resolver/handlers/stdlib.py
@@ -0,0 +1,265 @@
+import claripy
+from typing import TYPE_CHECKING
+import logging
+
+from angr.calling_conventions import SimRegArg, SimStackArg
+
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation
+from angr.knowledge_plugins.key_definitions.heap_address import HeapAddress
+from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.knowledge_plugins.key_definitions.tag import ReturnValueTag +from angr.knowledge_plugins.key_definitions.undefined import Undefined + +from argument_resolver.formatters.log_formatter import make_logger +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.utils.utils import Utils +from argument_resolver.utils.stored_function import StoredFunction +from argument_resolver.utils.calling_convention import cc_to_rd + +from archinfo import Endness + + +if TYPE_CHECKING: + from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState + + +class StdlibHandlers(HandlerBase): + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_malloc(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + void *malloc(size_t size); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: malloc(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("malloc") + + arg_size = cc.get_next_arg() + + size = Utils.get_values_from_cc_arg(arg_size, state, state.arch) + size_ints = Utils.get_concrete_value_from_int(size) + + alloc_size = max(size_ints) if size_ints is not None else 0x64 + alloc_size = max(alloc_size, 0x20) + if size_ints and len(size_ints) >= 2: + self.log.debug( + "RDA: malloc(): Found multiple values for size: %s, used %d", + ", ".join([str(i) for i in size_ints]), + alloc_size, + ) + else: + self.log.debug("RDA: malloc(): No concrete size found") + + heap_addr = state.heap_allocator.allocate(alloc_size) + memloc = MemoryLocation(heap_addr, alloc_size) + heap_val = MultiValues(claripy.BVV(0x0, alloc_size*8)) + stored_func.depends(memloc, *stored_func.atoms, value=heap_val) + ptr = MultiValues(Utils.gen_heap_address(heap_addr.value, state.arch)) + + return True, state, ptr + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_calloc(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. 
sourcecode:: c + void *calloc(size_t nmemb, size_t size); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: calloc(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("calloc") + + arg_nmemb = cc.get_next_arg() + arg_size = cc.get_next_arg() + + nmemb_values = Utils.get_values_from_cc_arg(arg_nmemb, state, state.arch) + size_values = Utils.get_values_from_cc_arg(arg_size, state, state.arch) + + nmemb_ints = Utils.get_concrete_value_from_int(nmemb_values) + size_ints = Utils.get_concrete_value_from_int(size_values) + + nmemb = max(nmemb_ints) if nmemb_ints is not None else 1 + size = max(size_ints) if size_ints is not None else 0x64 + + chunk_size = max(nmemb*size, state.arch.bytes*2) + addr = state.heap_allocator.allocate(chunk_size) + ptr = claripy.BVV(addr.value, state.arch.bits) + location = MemoryLocation(addr, chunk_size) + + stored_func.depends(location, *stored_func.atoms, value=MultiValues(claripy.BVV(0, chunk_size*8))) + + return True, state, MultiValues(ptr) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_free(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + void *free(void *ptr); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: free(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + #cc = self._calling_convention_resolver.get_cc("free") + #ptr_argument = cc.get_next_arg() + + #ptr_data = Utils.get_values_from_cc_arg(ptr_argument, state, state.arch) + + #for pointer_value in Utils.get_values_from_multivalues(ptr_data): + # if Utils.is_heap_address(pointer_value): + # heap_offset = Utils.get_heap_offset(pointer_value) + # state.heap_allocator.free(HeapAddress(heap_offset)) + # elif state.is_top(pointer_value): + # state.heap_allocator.free(Undefined()) + # else: + # self.log.debug("RDA: free(): Unexpected Pointer Value, got %s", pointer_value) + return False, state, None + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_rand(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + int rand(void); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: rand(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + return True, state, MultiValues(claripy.BVV(0xDEADBEEF, state.arch.bits)) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_system(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + int system(const char *command); + :param state: Register and memory definitions and uses + :param stored_func: Stored Function data + """ + self.log.debug("RDA: system(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + # Add definition for return value + # Let's return 0 by default, assuming everything went fine: + # > The value returned is -1 on error (e.g., fork(2) failed), and the return status of the command otherwise. 
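+        # Returning a concrete 0 keeps callers that branch on system()'s status on the success path.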
+ + return True, state, MultiValues(claripy.BVV(0, state.arch.bits)) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_getenv(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + char *getenv(const char *name); + :param state: Register and memory definitions and uses + :param codeloc: Code location of the call + """ + self.log.debug("RDA: getenv(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + arch = state.arch + cc = self._calling_convention_resolver.get_cc("getenv") + + name_argument = cc.get_next_arg() + name_pointers = Utils.get_values_from_cc_arg(name_argument, state, arch) + name_values = Utils.get_strings_from_pointers(name_pointers, state, stored_func.code_loc) + + + return_values = MultiValues() + for offset in name_values.keys(): + for name in name_values[offset]: + if name.concrete: + concrete_name = Utils.bytes_from_int(name).decode("latin-1") + buf_bvs = claripy.BVS(f'{stored_func.name}("{concrete_name}")@0x{stored_func.code_loc.ins_addr:x}', self.MAX_READ_SIZE*8) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + self.env_access.add(buf_bvs) + ret_val, has_unknown = state.environment.get({concrete_name}) + if ret_val == {Undefined()}: + addr = state.heap_allocator.allocate(self.MAX_READ_SIZE) + atom = MemoryLocation(addr, buf_bvs.size()//8, endness=Endness.BE) + ret_val = MultiValues(Utils.gen_heap_address(addr.value, state.arch)) + stored_func.depends(atom, *stored_func.atoms, value=buf_bvs) + else: + for val in ret_val: + atom = MemoryLocation(val, state.arch.bytes) + stored_func.depends(atom, *stored_func.atoms) + ret_val = MultiValues(offset_to_values={0: ret_val}) + else: + addr = state.heap_allocator.allocate(self.MAX_READ_SIZE) + buf_bvs = claripy.BVS(f"{stored_func.name}({name})@0x{stored_func.code_loc.ins_addr:x}", self.MAX_READ_SIZE*8) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + atom = MemoryLocation(addr, buf_bvs.size()//8, endness=Endness.BE) + ret_val = MultiValues(Utils.gen_heap_address(addr.value, state.arch)) + stored_func.depends(atom, *stored_func.atoms, value=buf_bvs) + return_values = return_values.merge(ret_val) + + return True, state, return_values + + def _setenv(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. 
sourcecode:: c
+            int setenv(const char *name, const char *value, int overwrite);
+        :param state: Register and memory definitions and uses
+        :param stored_func: Stored Function data
+        """
+        self.log.debug("RDA: setenv(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        arch = state.arch
+        cc = self._calling_convention_resolver.get_cc(stored_func.name)
+
+        if stored_func.name == "httpSetEnv":
+            # discard first arg
+            _ = cc.get_next_arg()
+        name_argument = cc.get_next_arg()
+        value_argument = cc.get_next_arg()
+
+        name_pointers = Utils.get_values_from_cc_arg(name_argument, state, arch)
+        name_values = Utils.get_strings_from_pointers(name_pointers, state, stored_func.code_loc)
+
+        value_values = Utils.get_values_from_cc_arg(value_argument, state, arch)
+
+        addrs = set()
+        for pointer in Utils.get_values_from_multivalues(value_values):
+            strings = Utils.get_strings_from_pointer(pointer, state, stored_func.code_loc)
+            size = Utils.get_size_from_multivalue(strings)
+            addr = state.heap_allocator.allocate(size)
+            addrs.add(Utils.gen_heap_address(addr.value, state.arch))
+            atom = MemoryLocation(addr, size, endness=Endness.BE)
+            try:
+                source_atoms = {defn.atom for defn in LiveDefinitions.extract_defs_from_mv(pointer)}
+            except AttributeError:
+                source_atoms = {cc_to_rd(value_argument, state.arch, state)}
+
+            stored_func.depends(atom, *source_atoms, value=strings, apply_at_callsite=True)
+
+        for name in Utils.get_values_from_multivalues(name_values):
+            if not name.concrete:
+                continue
+
+            concrete_name = Utils.bytes_from_int(name).decode("utf-8")
+            state.environment.set(concrete_name, addrs)
+
+        return True, state, MultiValues(claripy.BVV(0, state.arch.bits))
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_setenv(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        return self._setenv(state, stored_func)
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_httpSetEnv(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        return self._setenv(state, stored_func)
diff --git a/package/argument_resolver/handlers/string.py b/package/argument_resolver/handlers/string.py
new file mode 100644
index 0000000..ec6505c
--- /dev/null
+++ b/package/argument_resolver/handlers/string.py
@@ -0,0 +1,875 @@
+import logging
+
+import claripy
+import itertools
+import functools
+import re
+
+from typing import Optional, Tuple
+
+from angr.calling_conventions import SimRegArg, SimStackArg
+from angr.knowledge_plugins.key_definitions.atoms import (
+    MemoryLocation,
+    SpOffset,
+    Register,
+)
+
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+from angr.knowledge_plugins.key_definitions.tag import (
+    SideEffectTag,
+    ReturnValueTag,
+)
+from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions, DerefSize
+from angr.knowledge_plugins.key_definitions.heap_address import HeapAddress
+from angr.errors import SimMemoryError
+
+from argument_resolver.formatters.log_formatter import make_logger
+from argument_resolver.handlers.base import HandlerBase
+from argument_resolver.utils.calling_convention import cc_to_rd
+from argument_resolver.utils.utils import Utils
+from argument_resolver.utils.stored_function import StoredFunction
+
+from archinfo import Endness
+
+
+class StringHandlers(HandlerBase):
+    """
+    Handlers for `string.h` functions:
+    strcmp, strncmp, strcasecmp, strncasecmp, strcoll, strcpy, strncpy, strcat, strncat
+    *NOTE*: Please think twice before adding class attributes:
+ It should happen *ONLY* for `string.h` functions that have an internal state that the analysis needs to model. + """ + + # As per the documentation: + # "On the first call to strtok() the string to be parsed should be specified in str. + # In each subsequent call that should parse the same string, str should be NULL." + _strtok_remaining_string_pointers: Optional[MultiValues] = None + + def _handle_strcat( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ): + """ + :param LiveDefinitions state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + :param str handler_name: Name of the handler + """ + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc(stored_func.name) + # Get sim function arguments + arg_dst = cc.get_next_arg() + arg_src = cc.get_next_arg() + + # Extract pointers for arguments + dst_ptrs = Utils.get_values_from_cc_arg(arg_dst, state, state.arch) + src_ptrs = Utils.get_values_from_cc_arg(arg_src, state, state.arch) + + # Evaluate all pointers + + dst_values = [] + for dst_atom in state.deref(dst_ptrs, DerefSize.NULL_TERMINATE): + dst_strings = state.get_values(dst_atom) + if dst_strings is None: + dst_values.append((MultiValues(state.top(state.arch.bits)), dst_atom)) + continue + + new_dst_strings = MultiValues() + for dst_string in dst_strings[0]: + if dst_string.size() >= 8: + last_byte = dst_string.get_byte(dst_string.size() // 8 - 1) + if last_byte.concrete and last_byte.concrete_value == 0: + dst_string = dst_string.get_bytes(0, dst_string.size() // 8 - 1) + if dst_string.size() == 0: + dst_string = state.top(state.arch.bits) + new_dst_strings.add_value(0, dst_string) + dst_values.append((new_dst_strings, dst_atom)) + + src_values = [] + for src_atom in state.deref(src_ptrs, DerefSize.NULL_TERMINATE): + if src_atom.size == 4096: + src_strings = MultiValues(state.top(state.arch.bits)) + else: + src_strings = state.get_values(src_atom) + if src_strings is None: + src_strings = MultiValues(state.top(state.arch.bits)) + + src_values.append((src_strings, src_atom)) + + if len(dst_values) == 0: + dst_values.append((MultiValues(state.top(state.arch.bits)), None)) + elif len(src_values) == 0: + src_values.append((MultiValues(state.top(state.arch.bits)), None)) + + for d, s in itertools.product(dst_values, src_values): + d_val, d_atom = d + s_val, s_atom = s + if d_atom is None: + continue + + concat_value = d_val.concat(s_val) + dst_memloc = MemoryLocation(d_atom.addr, Utils.get_size_from_multivalue(concat_value), endness=Endness.BE) + atoms = [d_atom] + if s_atom is not None: + atoms.append(s_atom) + stored_func.depends(dst_memloc, *atoms, value=concat_value) + + return True, state, dst_ptrs + + def _handle_strcmp( + self, + state: "ReachingDefinitionsState", + stored_func: StoredFunction, + ignore_case: bool = True, + ): + """ + :param LiveDefinitions state:: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + :param str handler_name: Name of the handler + :param bool ignore_case: Case sensitivity + """ + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + res = MultiValues( + offset_to_values={ + 0: { + claripy.BVV(-1, state.arch.bits), + claripy.BVV(0, state.arch.bits), + claripy.BVV(1, state.arch.bits), + } + } + ) + + return True, state, res + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def 
handle_strcmp(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + int strcmp ( const char * str1, const char * str2 ); + """ + return self._handle_strcmp(state, stored_func, ignore_case=True) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strncmp( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + int strncmp(const char *s1, const char *s2, size_t n); + """ + self.log.debug("RDA: strncmp(): Using strcmp(). Size n is ignored.") + return self._handle_strcmp(state, stored_func, ignore_case=True) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strcasecmp( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + int strcasecmp(const char *s1, const char *s2); + """ + return self._handle_strcmp(state, stored_func, ignore_case=False) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strncasecmp( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + int strncasecmp(const char *s1, const char *s2, size_t n); + """ + self.log.debug( + "RDA: strncasecmp(): Using strcmp() case sensitivity. Size n is ignored." + ) + return self._handle_strcmp(state, stored_func, ignore_case=False) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strcoll( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + int strcoll(const char *s1, const char *s2); + """ + self.log.debug("RDA: strcoll(): Using strcmp(). Locales are ignored.") + return self._handle_strcmp(state, stored_func, ignore_case=True) + + def _handle_strcpy(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + cc = self._calling_convention_resolver.get_cc("strcpy") + # Get sim function arguments + arg_dst = cc.get_next_arg() + arg_src = cc.get_next_arg() + + # Extract values for arguments + dst_ptrs = Utils.get_values_from_cc_arg(arg_dst, state, state.arch) + src_ptrs = Utils.get_values_from_cc_arg(arg_src, state, state.arch) + + # Evaluate all pointers + for dst_ptr in Utils.get_values_from_multivalues(dst_ptrs): + if state.is_top(dst_ptr): + self.log.debug("RDA: strcpy(): Destination pointer undefined") + elif state.is_stack_address(dst_ptr) or state.is_heap_address(dst_ptr): + src_values = Utils.get_strings_from_pointers( + src_ptrs, state, stored_func.code_loc + ) + max_size = Utils.get_size_from_multivalue(src_values) + if max_size == 0: + max_size = state.arch.bytes + src_values = MultiValues(state.top(state.arch.bits)) + if state.is_heap_address(dst_ptr): + heap_offset = state.get_heap_offset(dst_ptr) + sub_atom = HeapAddress(heap_offset) + else: + sub_atom = dst_ptr + #stack_offset = state.get_stack_offset(dst_ptr) + #sub_atom = SpOffset(state.arch.bits, stack_offset) + memloc = MemoryLocation(sub_atom, max_size, endness=Endness.BE) + src_memlocs = {MemoryLocation(src_ptr, max_size) for src_ptr in + Utils.get_values_from_multivalues(src_ptrs)} + stored_func.depends(memloc, *src_memlocs, value=src_values) + else: + self.log.debug( + "RDA: strcpy(): Expected TOP or stack offset, got %s", + type(dst_ptr).__name__, + ) + + return True, state, dst_ptrs + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strcpy(self, state: "ReachingDefinitionsState", 
stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + char *strcpy (char * dst, const char * src); + """ + self.log.debug("RDA: strcpy(), ins_addr=%#x", stored_func.code_loc.ins_addr) + return self._handle_strcpy(state, stored_func) + + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strncpy( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + char *strncpy(char *dst, const char *src, size_t n); + """ + self.log.debug("RDA: strncpy(), ins_addr=%#x", stored_func.code_loc.ins_addr) + # GDI just use regular strcpy + return self._handle_strcpy(state, stored_func) + + #cc = self._calling_convention_resolver.get_cc("strncpy") + + #dst_argument = cc.get_next_arg() + #src_argument = cc.get_next_arg() + #n_argument = cc.get_next_arg() + + #dst_pointers = Utils.get_values_from_cc_arg(dst_argument, state, state.arch) + #src_pointers = Utils.get_values_from_cc_arg(src_argument, state, state.arch) + #n_values = Utils.get_values_from_cc_arg(n_argument, state, state.arch) + + #src_values = Utils.get_strings_from_pointers(src_pointers, state, stored_func.code_loc) + #_new_dst_values = MultiValues() + + #for n in Utils.get_values_from_multivalues(n_values): + # _src_values = set() + # for value in Utils.get_values_from_multivalues(src_values): + # n_val = Utils.get_signed_value(n.concrete_value, state.arch.bits) if n.concrete else -1 + # if n.symbolic: + # value_to_add = Utils.value_of_unknown_size(value, state, cc_to_rd(src_argument, state.arch, state), stored_func.code_loc).one_value() + + # elif state.is_top(value): + # value_to_add = state.top(n.concrete_value * 8) + + # elif n_val < 0: + # value_to_add = Utils.value_of_unknown_size(value, state, cc_to_rd(src_argument, state.arch, state), stored_func.code_loc).one_value() + + # elif n_val * 8 < value.size(): + # # truncate the string if needed + # value_to_add = value[value.size() - 1: value.size() - n_val * 8] + # else: + # # we don't have enough bits (or have just enough bits). don't truncate. 
+ # value_to_add = value + + # _src_values.add(value_to_add) + + # _new_dst_values = _new_dst_values.merge( + # MultiValues(offset_to_values={0: _src_values}) + # ) + + ## Evaluate all pointers + #for dst_pointer in Utils.get_values_from_multivalues(dst_pointers): + # if state.is_top(dst_pointer): + # self.log.debug("RDA: strncpy(): Destination pointer undefined") + # elif isinstance(dst_pointer, (SpOffset, claripy.ast.Base)): + # size = max(x.size() for x in Utils.get_values_from_multivalues(_new_dst_values)) + # memory_location = None + # if state.is_stack_address(dst_pointer): + # memory_location = MemoryLocation(dst_pointer, size // 8, endness=Endness.BE) + # elif state.is_heap_address(dst_pointer): + # heap_offset = state.get_heap_offset(dst_pointer) + # heap_addr = HeapAddress(heap_offset) + # memory_location = MemoryLocation(heap_addr, size // 8, endness=Endness.BE) + # elif dst_pointer.concrete and self._project.loader.find_segment_containing(dst_pointer._model_concrete.value) and self._project.loader.find_segment_containing(dst_pointer._model_concrete.value).is_writable: + # memory_location = MemoryLocation(dst_pointer._model_concrete.value, size // 8, endness=Endness.BE) + # else: + # self.log.debug( + # "RDA: strncpy(): Invalid destination pointer %#x, sp=%s", + # dst_pointer + # if isinstance(dst_pointer, int) + # else dst_pointer._model_concrete.value, + # hex(Utils.get_sp(state)) + # ) + # if memory_location: + # src_memlocs = {MemoryLocation(src_ptr, size // 8) for src_ptr in + # Utils.get_values_from_multivalues(src_pointers)} + # stored_func.depends(memory_location, *src_memlocs, value=_new_dst_values) + # else: + # self.log.debug( + # "RDA: strncpy(): Expected Undefined, integer or Parameter, got %s", + # type(dst_pointer).__name__, + # ) + + #return True, state, dst_pointers + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strcat(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + char * strcat ( char * destination, const char * source ); + """ + return self._handle_strcat(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strncat( + self, state: "ReachingDefinitionsState", stored_func: StoredFunction + ): + """ + .. sourcecode:: c + char *strncat(char *dest, const char *src, size_t n); + """ + self.log.debug("RDA: strncat(): Using strcat(). Size n is ignored.") + return self._handle_strcat(state, stored_func) + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strlen(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + .. 
sourcecode:: c + size_t strlen(const char *s); + :param LiveDefinitions state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + self.log.debug("RDA: strlen(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("strlen") + + arg_str = cc.get_next_arg() + str_ptrs = Utils.get_values_from_cc_arg(arg_str, state, state.arch) + + str_values = Utils.get_strings_from_pointers(str_ptrs, state, stored_func.code_loc) + + res = MultiValues() + for str_ in Utils.get_values_from_multivalues(str_values): + if state.is_top(str_): + res.add_value(0, state.top(state.arch.bits)) + self.log.debug("RDA: strlen(): Could not resolve str") + elif isinstance(str_, claripy.ast.Base): + res.add_value(0, claripy.BVV(str_.size() // 8, state.arch.bits)) + else: + self.log.debug( + "RDA: strlen(): Expected BVV, got %s", + type(str_).__name__, + ) + if res._values is None: + res.add_value(0, state.top(state.arch.bits)) + + return True, state, res + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_atoi(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + int atoi (const char * str); + :param LiveDefinitions state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + self.log.debug("RDA: atoi(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("atoi") + + arg_str = cc.get_next_arg() + str_ptrs = Utils.get_values_from_cc_arg(arg_str, state, state.arch) + + str_values = Utils.get_strings_from_pointers(str_ptrs, state, stored_func.code_loc) + + res = MultiValues() + for str_ in Utils.get_values_from_multivalues(str_values): + if state.is_top(str_) or ( + isinstance(str_, claripy.ast.Base) and not str_.concrete + ): + res.add_value(0, state.top(state.arch.bits)) + self.log.debug("RDA: atoi(): Could not resolve str") + elif isinstance(str_, claripy.ast.Base): + str_ = Utils.bytes_from_int(str_).decode("latin-1") + + match = re.match(r"^[\t\n\v\f\r]*([+-]?\d+).*$", str_) + if match is None or match.group(1) == "": + res.add_value(0, claripy.BVV(0, state.arch.bits)) + self.log.debug( + "RDA: atoi(): claripy.ast.Base could not be simplified to a string" + ) + else: + res.add_value(0, claripy.BVV(int(match.group(1)), state.arch.bits)) + else: + self.log.debug( + "RDA: atoi(): Expected claripy.ast.Base, got %s", + type(str_).__name__, + ) + + return True, state, res + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_memcpy(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. 
sourcecode:: c + void *memcpy(void *dest, const void *src, size_t n); + :param LiveDefinitions state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + self.log.debug("RDA: memcpy(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("memcpy") + + _get_values_from_cc_arg = lambda argument: Utils.get_values_from_cc_arg( + argument, state, state.arch + ) + dest_values = _get_values_from_cc_arg(cc.get_next_arg()) + src_values = _get_values_from_cc_arg(cc.get_next_arg()) + n_values = _get_values_from_cc_arg(cc.get_next_arg()) + + # Recover the content to "copy" from memory. + src_content = Utils.get_strings_from_pointers(src_values, state, stored_func.code_loc) + + # Restrict the content to the number of characters retrieved earlier. + + truncated_mv = MultiValues() + for length in Utils.get_values_from_multivalues(n_values): + for offset in src_content.keys(): + for string in src_content[offset]: + if ( + length.concrete + and length._model_concrete.value < string.size() // 8 + ): + if length._model_concrete.value != 0: + truncated_mv.add_value( + offset, + string[ + : string.size() + - length._model_concrete.value * 8 + ], + ) + else: + truncated_mv.add_value(offset, claripy.BVV(0x0, 8)) + else: + truncated_mv.add_value(offset, string) + + # Set the data of the destination's definitions. + for value in Utils.get_values_from_multivalues(dest_values): + size = Utils.get_size_from_multivalue(truncated_mv) + memory_location = MemoryLocation(value, size) + src_locations = {MemoryLocation(src_ptr, size) for src_ptr in Utils.get_values_from_multivalues(src_values)} + stored_func.depends(memory_location, *src_locations, value=truncated_mv) + + return True, state, dest_values + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_memset(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. 
sourcecode:: c + void *memset(void *s, int c, size_t n); + :param LiveDefinitions state: Register and memory definitions and uses + :param Codeloc codeloc: Code location of the call + """ + self.log.debug("RDA: memset(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("memset") + + arg_s = cc.get_next_arg() + arg_c = cc.get_next_arg() + arg_n = cc.get_next_arg() + + s_pointer_values = Utils.get_values_from_cc_arg(arg_s, state, state.arch) + + c_values = Utils.get_values_from_cc_arg(arg_c, state, state.arch) + n_values = Utils.get_values_from_cc_arg(arg_n, state, state.arch) + + s_values = MultiValues() + for (c, n) in itertools.product( + Utils.get_values_from_multivalues(c_values), + Utils.get_values_from_multivalues(n_values), + ): + if isinstance(n, claripy.ast.Base) and n.concrete: + if isinstance(c, claripy.ast.Base) and c.concrete: + if n.concrete_value == 0: + continue + size = ( + state.arch.bytes + if n.concrete_value < 0 + or (n.concrete_value >> n.size() - 1) == 1 + else n.concrete_value + ) + value = MultiValues(claripy.BVV(bytes([c.concrete_value & 0xFF] * size), size * 8)) + elif state.is_top(c) or state.is_stack_address(c): + definitions = list(state.extract_defs(c)) + if definitions: + value = Utils.unknown_value_of_unknown_size( + state, definitions[0].atom, stored_func.code_loc + ) + else: + value = MultiValues(state.top(state.arch.bits)) + else: + raise ValueError( + f"RDA: memset(): Expected Undefined or int for parameter c, got {type(c).__name__}" + ) + elif isinstance(n, claripy.ast.Base) and state.is_top(n): + if isinstance(c, claripy.ast.Base) and c.concrete: + value = MultiValues(claripy.BVV( + bytes([c.concrete_value & 0xff]) * state.arch.bytes, + state.arch.bytes * 8, + )) + elif isinstance(c, claripy.ast.Base) and state.is_top(c): + definitions = list(state.extract_defs(c)) + if definitions: + value = Utils.unknown_value_of_unknown_size( + state, definitions[0].atom, stored_func.code_loc + ) + else: + value = MultiValues(state.top(state.arch.bits)) + else: + value = MultiValues(state.top(state.arch.bits)) + self.log.debug( + f"RDA: memset(): Expected TOP or concrete for parameter c, got {type(c).__name__}" + ) + self.log.debug("RDA: memset(): Could not resolve n") + elif isinstance(n, claripy.ast.Base): + value = MultiValues(state.top(state.arch.bits)) + else: + raise ValueError( + f"RDA: memset(): Expected TOP or concrete for parameter n, got {type(n).__name__}", + ) + + s_values = s_values.merge(value) + + if s_values.count() > 0: + for destination_pointer in Utils.get_values_from_multivalues(s_pointer_values): + memory_location = MemoryLocation( + destination_pointer, Utils.get_size_from_multivalue(s_values) + ) + + stored_func.depends(memory_location, value=s_values) + + return True, state, s_pointer_values + + @HandlerBase.returns + @HandlerBase.tag_parameter_definitions + def handle_strdup(self, state: "ReachingDefinitionsState", stored_func: StoredFunction): + """ + Process the impact of the function's execution on register and memory definitions and uses. + .. sourcecode:: c + char *strdup(const char *s); + """ + self.log.debug("RDA: strdup(), ins_addr=%#x", stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("strdup") + + s_argument = cc.get_next_arg() + s_pointers = Utils.get_values_from_cc_arg(s_argument, state, state.arch) + + s_values = Utils.get_strings_from_pointers(s_pointers, state, stored_func.code_loc) + + # Count the trailing '\0' in the length. 
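+        # e.g. if s_values holds the 4-byte string b"abc\x00", the computed
+        # length is 4, so the chunk allocated below exactly fits the duplicate.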
+        length = Utils.get_size_from_multivalue(s_values)
+
+        # As per `strdup` manual: "Memory for the new string is obtained with malloc [...]"
+        new_string_address = state.heap_allocator.allocate(length)
+
+        # Add values the string can take to the new memory location
+        memory_location = MemoryLocation(new_string_address, length, endness=Endness.BE)
+        src_locations = {MemoryLocation(ptr, length) for ptr in Utils.get_values_from_multivalues(s_pointers)}
+        stored_func.depends(memory_location, *src_locations, value=s_values)
+
+        destination_pointers = MultiValues(claripy.BVV(new_string_address.value, state.arch.bits))
+
+        return True, state, destination_pointers
+
+    def _handle_strstr(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        arch = state.arch
+        cc = self._calling_convention_resolver.get_cc("strstr")
+
+        haystack_arg = cc.get_next_arg()
+        # needle_arg = cc.get_next_arg()
+
+        return_locations = Utils.get_values_from_cc_arg(haystack_arg, state, arch)
+        # needle_ptrs = Utils.get_values_from_cc_arg(needle_arg, state, arch)
+
+        # return_locations = MultiValues()
+        # for haystack_ptr in Utils.get_values_from_multivalues(haystack_ptrs):
+        #     for haystack_string in Utils.get_values_from_multivalues(Utils.get_strings_from_pointer(haystack_ptr, state, stored_func.code_loc)):
+        #         # sp = reach_def.get_sp()
+        #         if haystack_string.symbolic:
+        #             return_locations.add_value(0, state.top(state.arch.bits))
+        #             continue
+
+        #         haystack_str = Utils.bytes_from_int(haystack_string)
+        #         for needle in Utils.get_values_from_multivalues(Utils.get_strings_from_pointers(needle_ptrs, state, stored_func.code_loc)):
+        #             if needle.symbolic:
+        #                 return_locations.add_value(0, state.top(state.arch.bits))
+        #                 continue
+        #             needle_str = Utils.bytes_from_int(needle)
+        #             if needle_str not in haystack_str:
+        #                 return_locations.add_value(0, claripy.BVV(0, state.arch.bits))
+        #             else:
+        #                 idx = haystack_str.find(needle_str)
+        #                 return_locations.add_value(0, haystack_ptr + idx)
+
+        return True, state, return_locations
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_strstr(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        """
+        Process the impact of the function's execution on register and memory definitions and uses.
+        .. sourcecode:: c
+            char *strstr(char *haystack, char *needle);
+        :param state: Register and memory definitions and uses
+        :param stored_func: Stored Function data
+        """
+        # Instead of doing a full implementation of strstr, we just return the haystack pointer
+        self.log.debug("RDA: strstr(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        return self._handle_strstr(state, stored_func)
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_stristr(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        """
+        Instead of doing a full implementation of stristr, we just return the haystack pointer
+        .. sourcecode:: c
+            char *stristr(const char *haystack, const char *needle);
+        """
+        self.log.debug("RDA: stristr(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        return self._handle_strstr(state, stored_func)
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_strchr(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        """
+        Instead of doing a full implementation of strchr, we just return the string pointer
+        ..
sourcecode:: c
+            char *strchr(const char *s, int c);
+        """
+        self.log.debug("RDA: strchr(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        return self._handle_strstr(state, stored_func)
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_strtok(self, state: "ReachingDefinitionsState", stored_func: StoredFunction):
+        """
+        Instead of doing a full implementation of strtok, we just return the haystack pointer
+        .. sourcecode:: c
+            char *strtok(char *str, const char *delim);
+        """
+        self.log.debug("RDA: strtok(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        return self._handle_strstr(state, stored_func)
+
+        #cc = self._calling_convention_resolver.get_cc("strtok")
+
+        #str_argument = cc.get_next_arg()
+        #str_pointers = Utils.get_values_from_cc_arg(str_argument, state, state.arch)
+
+        ## `strtok` is *NOT* a pure function:
+        ## - it modifies the input `str`
+        ## - it has an internal state keeping the `str` passed in its latest call
+        #if 0x0 in [
+        #    x._model_concrete.value
+        #    for x in Utils.get_values_from_multivalues(str_pointers)
+        #    if x.concrete
+        #]:
+        #    if self._strtok_remaining_string_pointers is not None:
+        #        self.log.debug("RDA: strtok(): Subsequent calls on the same input string detected, but no pointers have been recorded on a previous call!")
+        #        return True, state, MultiValues(claripy.BVV(0x0, 32))
+
+        #    if self._strtok_remaining_string_pointers is None or all(
+        #        x.concrete and 0x0 == x._model_concrete.value
+        #        for x in Utils.get_values_from_multivalues(
+        #            self._strtok_remaining_string_pointers
+        #        )
+        #    ):
+        #        self.log.info("RDA: strtok(): End of string reached")
+        #        return True, state, MultiValues(claripy.BVV(0x0, 32))
+        #    else:
+        #        str_pointers = self._strtok_remaining_string_pointers
+
+        #delim_argument = cc.get_next_arg()
+        #delim_pointers = Utils.get_values_from_cc_arg(delim_argument, state, state.arch)
+        #delim_values = Utils.get_strings_from_pointers(delim_pointers, state, stored_func.code_loc)
+
+        #def _strtok(string, delimiter) -> Tuple[MultiValues, Optional[MultiValues]]:
+        #    """
+        #    :return: The token, and the "leftover" string, past the delimiter.
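+        #    e.g. splitting b"a,b" on the delimiter b"," yields the token b"a"
+        #    and the leftover remainder of the input past that first delimiter.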
+ # """ + # if state.is_top(string): + # defs = list(state.extract_defs(string)) + # if defs: + # definition = defs[0] + # return Utils.unknown_value_of_unknown_size( + # state, definition.atom, stored_func.code_loc + # ), Utils.unknown_value_of_unknown_size( + # state, definition.atom, stored_func.code_loc + # ) + # else: + # self.log.debug("No definition exists for string %s.", string) + # atom = Register( + # str_argument.check_offset(state.arch), str_argument.size + # ) + # return Utils.unknown_value_of_unknown_size( + # state, atom, stored_func.code_loc + # ), Utils.unknown_value_of_unknown_size(state, atom, stored_func.code_loc) + # elif string.concrete: + # if delimiter.concrete: + # concrete_string = string._model_concrete.value + # concrete_delim = delimiter._model_concrete.value + # if isinstance(concrete_string, int): + # concrete_string = concrete_string.to_bytes( + # string.size() // 8, "big" + # ) + + # if isinstance(concrete_delim, int): + # concrete_delim = concrete_delim.to_bytes( + # delimiter.size() // 8, "big" + # ) + + # index = concrete_string.find(concrete_delim) + # if index == -1: + # return MultiValues(offset_to_values={0: {string}}), None + # return MultiValues( + # offset_to_values={0: {string[: index * 8]}} + # ), MultiValues(offset_to_values={0: {string[index * 8 :]}}) + # else: + # self.log.debug( + # "RDA: strtok(): Expected concrete for parameter delim, got %s", + # type(delimiter).__name__, + # ) + # definition = next(state.extract_defs(string)) + # return Utils.unknown_value_of_unknown_size( + # state, definition.atom, stored_func.code_loc + # ), Utils.unknown_value_of_unknown_size( + # state, definition.atom, stored_func.code_loc + # ) + # else: + # self.log.debug( + # "RDA: strtok(): Expected Undefined, or str for parameter string, got %s", + # type(string).__name__, + # ) + # definition = next(state.extract_defs(string)) + # return Utils.unknown_value_of_unknown_size( + # state, definition.atom, stored_func.code_loc + # ), Utils.unknown_value_of_unknown_size(state, definition.atom, stored_func.code_loc) + + ## Keep track of the pointers to the remaining strings (past the first token), to be able to handle subsequent + ## calls to `strtok` with NULL pointers. 
+ #remaining_string_pointers = set() + #return_values = set() + #for pointer, delimiter in itertools.product( + # Utils.get_values_from_multivalues(str_pointers), + # Utils.get_values_from_multivalues(delim_values), + #): + # # Try to be as precise as possible for each pointer: + # # - get only the strings pointed to + # # - make the size as small as possible (size of the biggest pointed element) + # max_length = 0 + + # for strings in itertools.product( + # Utils.get_values_from_multivalues( + # Utils.get_strings_from_pointer(pointer, state, stored_func.code_loc) + # ) + # ): + # string = functools.reduce(lambda a, b: a.concat(b), strings) + # token, leftover = _strtok(string, delimiter) + # if leftover: + # leftover = leftover.one_value() + + # if 0 not in token: + # continue + + # for value in token[0]: + # return_values |= {value} + + # # Put the relevant data in the model: + # # - create the memory location corresponding to the start of the remaining string + # # - set the truncated string to the corresponding data + # length = value.size() // 8 + # if Utils.has_unknown_size(value): + # pointer_to_leftover = state.top(state.arch.bits) + # leftover_length = state.arch.bytes + # elif leftover is None or ( + # isinstance(leftover, claripy.ast.Base) + # and leftover.concrete + # and leftover._model_concrete.value == 0x0 + # ): + # pointer_to_leftover = claripy.BVV(0x0, state.arch.bits) + # leftover_length = 0 + # else: + # pointer_to_leftover = pointer + length + # leftover_length = string.size() // 8 - length + + # if leftover is not None: + # memory_location = MemoryLocation( + # pointer_to_leftover, leftover_length + # ) + # try: + # stored_func.depends(memory_location, cc_to_rd(str_argument, state.arch, state)) + # except SimMemoryError: + # pass + + # remaining_string_pointers |= {pointer_to_leftover} + + # max_length = length if length > max_length else max_length + + # memory_location = MemoryLocation(pointer, max_length) + # stored_func.depends(memory_location, cc_to_rd(str_argument, state.arch, state)) + + #if not remaining_string_pointers: + # empty_pointer = {claripy.BVV(0x0, state.arch.bits)} + # self._strtok_remaining_string_pointers = MultiValues( + # offset_to_values={0: empty_pointer} + # ) + #else: + # self._strtok_remaining_string_pointers = MultiValues( + # offset_to_values={0: remaining_string_pointers} + # ) + + #return True, state, self._strtok_remaining_string_pointers if remaining_string_pointers else None diff --git a/package/argument_resolver/handlers/unistd.py b/package/argument_resolver/handlers/unistd.py new file mode 100644 index 0000000..4443ab7 --- /dev/null +++ b/package/argument_resolver/handlers/unistd.py @@ -0,0 +1,152 @@ +import os + +from angr.code_location import ExternalCodeLocation +from angr.calling_conventions import SimRegArg, SimStackArg +from angr.engines.light import SpOffset +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState +from angr.knowledge_plugins.key_definitions.tag import ( + ReturnValueTag, + SideEffectTag, + InitialValueTag, +) +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues + +from argument_resolver.formatters.log_formatter import make_logger +from argument_resolver.handlers.base import HandlerBase +from 
argument_resolver.utils.calling_convention import cc_to_rd
+from argument_resolver.utils.utils import Utils
+from argument_resolver.utils.stored_function import StoredFunction
+
+from archinfo import Endness
+
+import claripy
+
+
+class UnistdHandlers(HandlerBase):
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_open(self, state: ReachingDefinitionsState, stored_func: StoredFunction):
+        """
+        Process open and mark the returned file descriptor as tainted
+        .. sourcecode:: c
+            int open(const char *path, int flags);
+        :param stored_func: Stored Function data
+        :param state: Register and memory definitions and uses
+        """
+        self.log.debug("RDA: open(), ins_addr=%#x", stored_func.code_loc.ins_addr)
+
+        arch = state.arch
+        cc = self._calling_convention_resolver.get_cc("open")
+
+        arg_path = cc.get_next_arg()
+        arg_mode = cc.get_next_arg()
+
+        path_ptrs = Utils.get_values_from_cc_arg(arg_path, state, arch)
+        mode = Utils.get_values_from_cc_arg(arg_mode, state, arch)
+        known_modes = {
+            os.O_RDONLY: "r",
+            os.O_WRONLY: "w",
+            os.O_RDWR: "rw",
+        }
+
+        path = Utils.get_strings_from_pointers(path_ptrs, state, stored_func.code_loc)
+        paths = []
+        for p in Utils.get_values_from_multivalues(path):
+            if p.concrete:
+                paths.append(f'"{Utils.bytes_from_int(p).decode("latin-1")}"')
+            else:
+                paths.append(f'"{p}"')
+
+        modes = []
+        for m in Utils.get_values_from_multivalues(mode):
+            if m.concrete:
+                modes.append(f'"{known_modes.get(m.concrete_value, m.concrete_value)}"')
+            else:
+                modes.append(f'"{m}"')
+
+        fd = self.gen_fd()
+        buf_bvs = claripy.BVS(f"{stored_func.name}({' | '.join(paths)}, {' | '.join(modes)})@0x{stored_func.code_loc.ins_addr:x}",
+                              state.arch.bits)
+        buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"})
+        self.fd_tracker[fd] = {"val": buf_bvs, "parent": None, "ins_addr": None}
+
+        return True, state, MultiValues(claripy.BVV(fd, state.arch.bits))
+
+    @HandlerBase.returns
+    @HandlerBase.tag_parameter_definitions
+    def handle_read(
+        self,
+        state: "ReachingDefinitionsState",
+        stored_func: StoredFunction,
+    ):
+        """
+        Process read and mark the read buffer as tainted
+        ..
sourcecode:: c + size_t read(int fd, void *buf, size_t count); + :param ReachingDefinitionsState state: reaching definitions state + :param Codeloc codeloc: Code location of the call + :param handler_name: Name of function to handle + """ + self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr) + + cc = self._calling_convention_resolver.get_cc("read") + + # get args + fd = cc.get_next_arg() # fd + arg_buf = cc.get_next_arg() + arg_size = cc.get_next_arg() + + # get buf + fd_vals = Utils.get_values_from_cc_arg(fd, state, state.arch) + buf_ptrs = Utils.get_values_from_cc_arg(arg_buf, state, state.arch) + + # get count/size + size_values = Utils.get_values_from_cc_arg(arg_size, state, state.arch) + + + parent = None + parent_fds = [] + for val in Utils.get_values_from_multivalues(fd_vals): + if val.concrete and val.concrete_value in self.fd_tracker: + parent_fds.append(val.concrete_value) + if parent is None: + parent = self.fd_tracker[val.concrete_value]["val"] + else: + parent = parent.concat(self.fd_tracker[val.concrete_value]["val"]) + + if stored_func.name not in self.fd_tracker: + self.fd_tracker[stored_func.name] = [] + + for ptr in Utils.get_values_from_multivalues(buf_ptrs): + for count_val in Utils.get_values_from_multivalues(size_values): + sp_offset = SpOffset(state.arch.bits, state.get_stack_offset(ptr)) + if sp_offset.offset is None: + continue + + if count_val.concrete: + size = min(count_val.concrete_value, self.MAX_READ_SIZE) + memloc = MemoryLocation(sp_offset, size, endness=Endness.BE) + else: + memloc = MemoryLocation(sp_offset, state.arch.bytes, endness=Endness.BE) + + if parent is not None: + parent_name = next(iter(x for x in parent.variables if x != "TOP")) + else: + parent_name = "?" + + buf_bvs = claripy.BVS( + f"{stored_func.name}({parent_name})@0x{stored_func.code_loc.ins_addr:x}", + memloc.size * 8) + buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"}) + self.fd_tracker[stored_func.name].append( + {"val": buf_bvs, "parent": parent_fds, "ins_addr": stored_func.code_loc.ins_addr}) + mv = MultiValues(buf_bvs) + + stored_func.depends(memloc, *stored_func.atoms, value=mv) + + return True, state, size_values diff --git a/package/argument_resolver/handlers/url_param.py b/package/argument_resolver/handlers/url_param.py new file mode 100644 index 0000000..6457017 --- /dev/null +++ b/package/argument_resolver/handlers/url_param.py @@ -0,0 +1,76 @@ +import os + +from angr.code_location import ExternalCodeLocation +from angr.calling_conventions import SimRegArg, SimStackArg +from angr.engines.light import SpOffset +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState +from angr.knowledge_plugins.key_definitions.tag import ( + ReturnValueTag, + SideEffectTag, + InitialValueTag, +) +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues + +from argument_resolver.formatters.log_formatter import make_logger +from argument_resolver.handlers.base import HandlerBase +from argument_resolver.utils.calling_convention import cc_to_rd +from argument_resolver.utils.utils import Utils +from argument_resolver.utils.stored_function import StoredFunction + +import claripy + + +class URLParamHandlers(HandlerBase): + + @HandlerBase.returns + 
@HandlerBase.tag_parameter_definitions
+    def handle_custom_param_parser(self, state: ReachingDefinitionsState, stored_func: StoredFunction):
+        """
+        Process a firmware-specific URL/query parameter parser and mark the parsed value as tainted
+        .. sourcecode:: c
+            char *query_param_parser(char *src, char *param, char *dst); /* representative prototype; actual parsers vary */
+        :param stored_func: Stored Function data
+        :param state: Register and memory definitions and uses
+        """
+        self.log.debug("RDA: %s(), ins_addr=%#x", stored_func.name, stored_func.code_loc.ins_addr)
+
+        arch = state.arch
+        cc = self._calling_convention_resolver.get_cc("query_param_parser")
+
+        if len(stored_func.function.prototype.args) == 1:
+            src = None
+            param = cc.get_next_arg()
+            dst = None
+        else:
+            src = cc.get_next_arg()
+            param = cc.get_next_arg()
+            dst = cc.get_next_arg()
+
+        param_ptr = Utils.get_values_from_cc_arg(param, state, arch)
+        params = Utils.get_strings_from_pointers(param_ptr, state, stored_func.code_loc)
+
+        param_list = []
+        for p in Utils.get_values_from_multivalues(params):
+            if p.concrete:
+                found_param = Utils.bytes_from_int(p).decode("latin-1")
+                param_list.append(found_param)
+
+        buf_bvs = claripy.BVS(f'frontend_param("{", ".join(param_list)}")@0x{stored_func.code_loc.ins_addr:x}',
+                              self.MAX_READ_SIZE * 8)
+        buf_bvs.variables = frozenset(set(buf_bvs.variables) | {"TOP"})
+        self.keyword_access[buf_bvs] = param_list
+
+        mv = MultiValues(buf_bvs)
+        if dst:
+            dst_ptr = Utils.get_values_from_cc_arg(dst, state, arch)
+            if len(stored_func.function.prototype.args) > 2:
+                for dst in Utils.get_values_from_multivalues(dst_ptr):
+                    if not state.is_top(dst):
+                        memloc = MemoryLocation(Utils.get_store_method_from_ptr(dst, state), Utils.get_size_from_multivalue(mv))
+                        sources = {cc_to_rd(src, state.arch, state)} if src else set()
+                        stored_func.depends(memloc, *sources, value=mv)
+
+        return True, state, mv
diff --git a/package/argument_resolver/utils/__init__.py b/package/argument_resolver/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/package/argument_resolver/utils/call_trace.py b/package/argument_resolver/utils/call_trace.py
new file mode 100644
index 0000000..1027b85
--- /dev/null
+++ b/package/argument_resolver/utils/call_trace.py
@@ -0,0 +1,75 @@
+from typing import List, Set, Tuple
+
+import networkx
+
+from angr.analyses.reaching_definitions.call_trace import CallTrace
+from angr.knowledge_plugins.functions import Function
+from angr import Project
+
+
+def _trace_contains_child(parent, child):
+    if child[1] is None:
+        return parent.includes_function(child[0])
+    else:
+        parent_set = {
+            (x.caller_func_addr, x.callee_func_addr) for x in parent.callsites
+        }
+        return parent.includes_function(child[0]) and child[1].issubset(parent_set)
+
+
+# TODO This logic can probably be simplified
+def traces_to_sink(
+    sink: Function,
+    callgraph,
+    max_depth: int,
+    excluded_functions: Set[Tuple],
+) -> Set[CallTrace]:
+    """
+    Peek into the callgraph and discover all functions reaching the sink within `max_depth` layers of calls.
+
+    :param sink: The function to be reached.
+    :param callgraph: The callgraph to obtain the sink's transitive callers from.
+    :param max_depth: A bound within which to look for transitive predecessors of the sink.
+    :param excluded_functions: A set of functions to ignore, and stop the discovery from.
+
+    :return: The set of CallTraces leading to the given sink.
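+
+    Example (illustrative): with max_depth=2 and a callgraph
+    main -> handle_request -> system, calling traces_to_sink on system yields a
+    single trace rooted at main; discovery stops early at any entry listed in
+    excluded_functions.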
+ """ + queue: List[Tuple[CallTrace, int]] = [(CallTrace(sink.addr), 0)] + starts: Set[CallTrace] = set() + + while queue: + trace, curr_depth = queue.pop(0) + + if trace.current_function_address() in starts: + continue + + caller_func_addr = trace.current_function_address() + callers: Set[int] = set(callgraph.predecessors(caller_func_addr)) + + if len(callers) == 0: + starts |= {trace} + + # remove the functions that we already came across - essentially bypassing recursive function calls - and excluded functions + if any(_trace_contains_child(trace, ex) for ex in excluded_functions): + callers = set() + + for caller in callers.copy(): + if trace.includes_function(caller): + callers.remove(caller) + if any(caller == ex[0] for ex in excluded_functions if ex[1] is None): + callers.remove(caller) + + caller_depth = curr_depth + 1 + if caller_depth >= max_depth: + # reached the depth limit. add them to potential analysis starts + starts |= { + trace.step_back(caller_addr, None, caller_func_addr) + for caller_addr in callers + } + else: + # add them to the queue + queue.extend([(trace.step_back(caller_addr, None, caller_func_addr), caller_depth) + for caller_addr in callers + ]) + + return starts diff --git a/package/argument_resolver/utils/call_trace_visitor.py b/package/argument_resolver/utils/call_trace_visitor.py new file mode 100644 index 0000000..3cda6d4 --- /dev/null +++ b/package/argument_resolver/utils/call_trace_visitor.py @@ -0,0 +1,96 @@ +from typing import List, Tuple, Optional, Set + +from angr.block import BlockNode +from angr.utils.graph import GraphUtils +from angr.analyses.forward_analysis.visitors.graph import GraphVisitor, NodeType +from angr.analyses.reaching_definitions.call_trace import CallTrace +from angr.analyses.reaching_definitions.subject import Subject, SubjectType +from angr.utils.graph import dfs_back_edges + + +class CallTraceSubject(Subject): + def __init__(self, trace: CallTrace, func): + self._content = trace + self._visitor = FunctionGraphVisitor(func) + self._type = SubjectType.CallTrace + self._cc = func.calling_convention + + @property + def visitor(self) -> "FunctionGraphVisitor": + return self._visitor + + def copy(self): + clone = CallTraceSubject(self._content, self._visitor.function) + clone._visitor._sorted_nodes = self._visitor._sorted_nodes.copy() + clone._visitor._worklist = self._visitor._worklist.copy() + clone._visitor._nodes_set = self._visitor._nodes_set.copy() + clone._visitor._node_to_index = self._visitor._node_to_index.copy() + clone._visitor._reached_fixedpoint = self._visitor._reached_fixedpoint.copy() + clone._visitor._back_edges_by_src = self._visitor._back_edges_by_src.copy() + clone._visitor._back_edges_by_dst = self._visitor._back_edges_by_dst.copy() + clone._visitor._pending_nodes = self._visitor._pending_nodes.copy() + + return clone + + +class FunctionGraphVisitor(GraphVisitor): + """ + :param knowledge.Function func: + """ + + def __init__(self, func, graph=None): + super().__init__() + self.function = func + + if graph is None: + self.graph = self.function.graph + else: + self.graph = graph + + self.reset() + + def mark_nodes_for_revisit(self, blocks: Set[BlockNode]): + for block in blocks: + self.revisit_node(block) + + def mark_nodes_as_visited(self, nodes: Set[BlockNode]): + valid_nodes = {x for x in nodes if x in self._nodes_set} + for node in valid_nodes: + self._worklist.remove(node) + + def revisit_successors(self, node: NodeType, include_self=True) -> None: + super().revisit_successors(node, include_self=include_self) 
+ #print("WORKLIST", node, self._worklist) + + def successors(self, node): + return list(self.graph.successors(node)) + + def predecessors(self, node): + return list(self.graph.predecessors(node)) + + def next_node(self) -> Optional[NodeType]: + node = super().next_node() + while node is not None and node in self._reached_fixedpoint: + node = super().next_node() + return node + + def sort_nodes(self, nodes=None): + sorted_nodes = GraphUtils.quasi_topological_sort_nodes(self.graph) + + if nodes is not None: + sorted_nodes = [n for n in sorted_nodes if n in set(nodes)] + + return sorted_nodes + + def back_edges(self) -> List[Tuple[NodeType, NodeType]]: + start_nodes = [node for node in self.graph if node.addr == self.function.addr] + if not start_nodes: + start_nodes = [ + node for node in self.graph if self.graph.in_degree(node) == 0 + ] + + if not start_nodes: + raise NotImplementedError() + + start_node = start_nodes[0] + return list(dfs_back_edges(self.graph, start_node)) \ No newline at end of file diff --git a/package/argument_resolver/utils/calling_convention.py b/package/argument_resolver/utils/calling_convention.py new file mode 100644 index 0000000..cfaa31d --- /dev/null +++ b/package/argument_resolver/utils/calling_convention.py @@ -0,0 +1,229 @@ +import logging + +from typing import Dict, List, Union, Optional, Tuple, TYPE_CHECKING + +from angr.calling_conventions import ( + SimFunctionArgument, + SimRegArg, + SimStackArg, + DEFAULT_CC, + SimCC, +) +from angr.sim_type import SimTypePointer, SimTypeChar + +from angr.engines.light import SpOffset +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, Register +from angr.knowledge_plugins.functions.function_manager import FunctionManager +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.procedures.definitions.glibc import _libc_decls +from archinfo.arch import Arch + +from argument_resolver.utils.utils import Utils +from argument_resolver.external_function.function_declarations import CUSTOM_DECLS + +if TYPE_CHECKING: + from angr.sim_type import SimTypeFunction + + +LOGGER = logging.getLogger("FastFRUIT") + +LIBRARY_DECLS = {**_libc_decls, **CUSTOM_DECLS} + + +def get_default_cc_with_args(num_args: int, arch: Arch, is_win=False) -> SimCC: + """ + Get the default calling convention, containing where the arguments are located when the function is called, and + where the return value will be placed. + + Query angr.calling_convention.DEFAULT_CC to recover the calling convention corresponding to the given arch, and + compute the argument positions whenever they appear on the stack. + + :param num_args: The number of arguments the function takes. + :param arch: The architecture of the binary where the studied function is. + :return: The calling convention. 
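+
+    Example (illustrative): on 32-bit ARM, num_args=6 resolves to r0-r3 as
+    register arguments plus two stack slots; on MIPS, compute_offset() below
+    additionally shifts stack offsets past the argument-register save area.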
+ """ + + def compute_offset(arch, default_cc, offset): + def is_mips(arch): + return arch.name.lower().find("mips") > -1 + + initial_offset = 1 if arch.call_pushes_ret else 0 + mips_offset = ( + len(default_cc.ARG_REGS + default_cc.FP_ARG_REGS) if is_mips(arch) else 0 + ) + return arch.bytes * (offset + mips_offset + initial_offset) + + platform_name = "Linux" if not is_win else "Win32" + default_cc = DEFAULT_CC[arch.name][platform_name] + + reg_args = [SimRegArg(x, arch.bytes) for x in default_cc.ARG_REGS[:num_args]] + stack_args: List[SimFunctionArgument] = list( + map( + lambda offset: SimStackArg( + compute_offset(arch, default_cc, offset), arch.bits + ), + range(num_args - len(reg_args)), + ) + ) + cc = SimCC.find_cc(arch, reg_args + stack_args, default_cc.STACKARG_SP_DIFF) + if cc is None: + cc = default_cc(arch) + return cc + + +def cc_to_rd( + sim: SimFunctionArgument, arch: Arch, state=None +) -> Union[Register, MemoryLocation]: + """ + Conversion to Register and SpOffset from respectively angr/calling_conventions' SimRegArg and SimStackArg. + + The arch parameter is necessary to create the Register, as its constructor needs an offset and a size. + + :param sim: Input register or stack offset + :param arch: Architecture + :return: Output register or stack offset + """ + if isinstance(sim, SimRegArg): + offset, size = arch.registers[sim.reg_name] + return Register(offset, size, arch) + if isinstance(sim, SimStackArg): + if state is not None: + initial_sp = ( + LiveDefinitions.INITIAL_SP_64BIT + if arch.bits == 64 + else LiveDefinitions.INITIAL_SP_32BIT + ) + return MemoryLocation( + SpOffset( + sim.size, (Utils.get_sp(state) - initial_sp) + sim.stack_offset + ), + sim.size, + endness=arch.memory_endness, + ) + else: + return MemoryLocation( + SpOffset(sim.size, sim.stack_offset), + sim.size, + endness=arch.memory_endness, + ) + else: + raise TypeError(f"Expected SimRegArg or SimStackArg, got {type(sim).__name__}") + + +def get_next_arg(self): + if hasattr(self, "sim_func"): + if self.arg_counter >= len(self.sim_func.args): + if len(self.sim_func.args) > 0: + arg = self.next_arg( + self.session, self.sim_func.args[-1].with_arch(self.ARCH) + ) + else: + arg = self.next_arg( + self.session, SimTypePointer(SimTypeChar).with_arch(self.ARCH) + ) + else: + arg = self.next_arg( + self.session, self.sim_func.args[self.arg_counter].with_arch(self.ARCH) + ) + self.arg_counter += 1 + else: + arg = self.next_arg( + self.session, SimTypePointer(SimTypeChar).with_arch(self.ARCH) + ) + return arg + + +class CallingConventionResolver: + """ + Query calling conventions for the functions we are interested in. + """ + + def __init__( + self, + project, + arch: Arch, + functions: FunctionManager, + ): + """ + :param arch: The architecture targeted by the analysed binary. + :param functions: Function manager that includes all functions of the binary. + :param variable_recovery_fast: The analysis from the ongoing . 
+ """ + self._project = project + self._arch = arch + self._functions = functions + + self._cc: Dict[str, SimCC] = {} + self._prototypes: Dict[str, Optional["SimTypeFunction"]] = {} + + def _get_cc_and_proto( + self, function_name + ) -> Tuple[Optional[SimCC], Optional["SimTypeFunction"]]: + cc, proto = None, None + + if function_name in LIBRARY_DECLS: + number_of_parameters = len(LIBRARY_DECLS[function_name].args) + cc = get_default_cc_with_args( + number_of_parameters, + self._arch, + is_win=len(self._project.loader.all_pe_objects) > 0, + ) + cc.sim_func = LIBRARY_DECLS[function_name] + + # attempt to use CallingConventionAnalysis to get its prototype + func = self._functions.function(name=function_name) + if func is not None: + cc_analysis = self._project.analyses.CallingConvention(func) + proto = cc_analysis.prototype + else: + proto = cc.sim_func + elif function_name in self._functions: + func = self._functions[function_name] + self._project.analyses.VariableRecoveryFast(func) + cc_analysis = self._project.analyses.CallingConvention(func) + if cc_analysis.cc is None: + LOGGER.error("CCA: Failed for %s()", function_name) + else: + cc = cc_analysis.cc + proto = cc_analysis.prototype + cc.sim_func = LIBRARY_DECLS[function_name] + # LOGGER.debug("CCA: %s() with arguments %s", function_name, cc.args) + else: + LOGGER.error( + "CCA: Failed for %s(), function neither an external function nor have its name in CFG", + function_name, + ) + setattr(SimCC, "get_next_arg", get_next_arg) + + return cc, proto + + def get_cc(self, function_name: str) -> Optional[SimCC]: + """ + Return calling convention given the name of a function. + + :param function_name: The function's name + :return: The Calling convention (from angr) + """ + if function_name not in self._cc: + ( + self._cc[function_name], + self._prototypes[function_name], + ) = self._get_cc_and_proto(function_name) + if self._cc[function_name] is not None: + self._cc[function_name].session = self._cc[function_name].arg_session(None) + self._cc[function_name].arg_counter = 0 + return self._cc[function_name] + + def get_prototype(self, function_name: str) -> Optional["SimTypeFunction"]: + """ + Return the function prototype given the name of a function. 
+ + :param function_name: Function name + :return: The function prototype + """ + if function_name not in self._cc: + ( + self._cc[function_name], + self._prototypes[function_name], + ) = self._get_cc_and_proto(function_name) + return self._prototypes[function_name] diff --git a/package/argument_resolver/utils/closure.py b/package/argument_resolver/utils/closure.py new file mode 100644 index 0000000..32a9d39 --- /dev/null +++ b/package/argument_resolver/utils/closure.py @@ -0,0 +1,130 @@ +from typing import NamedTuple + +from angr.analyses import ReachingDefinitionsAnalysis + +from .stored_function import StoredFunction + +class SkeletonClosure: + + def __init__(self, closure): + self.callsites = Closure._get_callsites(closure) + self.code_loc = closure.sink_trace.code_loc + self.sink_addr = closure.sink_trace.function.addr + self.call_stack_len = len(closure.sink_trace.call_stack) + self.hash = hash(closure) + + def __lt__(self, other): + if isinstance(other, Closure): + if not self.sink_addr == other.sink_trace.function.addr: + return False + closure_callsites = Closure._get_callsites(other) + call_stack_len = len(other.sink_trace.call_stack) + + elif isinstance(other, SkeletonClosure): + closure_callsites = other.callsites + call_stack_len = other.call_stack_len + + else: + raise ValueError(f"Cannot compare SkeletonClosure and {other.__class__.__name__}") + + if not (self.callsites < closure_callsites): + return False + + return self.call_stack_len < call_stack_len + + def __gt__(self, other): + if isinstance(other, Closure): + if not self.sink_addr == other.sink_trace.function.addr: + return False + closure_callsites = Closure._get_callsites(other) + call_stack_len = len(other.sink_trace.call_stack) + + elif isinstance(other, SkeletonClosure): + closure_callsites = other.callsites + call_stack_len = other.call_stack_len + + else: + raise ValueError(f"Cannot compare SkeletonClosure and {other.__class__.__name__}") + + if not (self.callsites > closure_callsites): + return False + + return self.call_stack_len > call_stack_len + + def __hash__(self): + return self.hash + + def __eq__(self, other): + return hash(self) == hash(other) + + +class Closure(NamedTuple): + sink_trace: StoredFunction + rda: ReachingDefinitionsAnalysis + handler: "HandlerBase" + + def __lt__(self, other): + self._type_check(other) + if not self.sink_trace.function == other.sink_trace.function: + return False + + if self.compare_callsites(other) != -1: + return False + + return self.sink_trace.call_stack < other.sink_trace.call_stack + + def __gt__(self, other): + self._type_check(other) + if not self.sink_trace.function == other.sink_trace.function: + return False + + if self.compare_callsites(other) != 1: + return False + + return self.sink_trace.call_stack > other.sink_trace.call_stack + + def __eq__(self, other): + self._type_check(other) + return hash(self) == hash(other) + + def compare_callsites(self, other): + callsites = self._get_callsites(self) + other_callsites = self._get_callsites(other) + + if callsites == other_callsites: + return 0 + + if callsites < other_callsites: + return -1 + + if callsites > other_callsites: + return 1 + + return None + + @staticmethod + def _get_callsites(closure): + callsites = set() + for callsite in closure.sink_trace.subject.content.callsites: + callsites.add(callsite.caller_func_addr) + callsites.add(callsite.callee_func_addr) + + return callsites + + def get_call_locations(self): + callsites = { + x.caller_func_addr for x in self.sink_trace.subject.content.callsites + } 
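+        # Addresses of every function that appears as a caller on the path to
+        # the sink; traces recorded before the sink are filtered against this
+        # set below.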
+ + trace_1_idx = self.handler.analyzed_list.index(self.sink_trace) + call_locs = { + x.code_loc.ins_addr or x.code_loc.block_addr + for x in self.handler.analyzed_list[:trace_1_idx] + if x.function.addr in callsites + } + return call_locs + + @staticmethod + def _type_check(other): + if not isinstance(other, Closure): + raise ValueError(f"Cannot compare Closure and {other.__class__.__name__}") diff --git a/package/argument_resolver/utils/format_prototype.py b/package/argument_resolver/utils/format_prototype.py new file mode 100644 index 0000000..7afed44 --- /dev/null +++ b/package/argument_resolver/utils/format_prototype.py @@ -0,0 +1,25 @@ +class FormatPrototype: + def __init__(self, prototype, specifier, position): + self._prototype = prototype + if isinstance(specifier, bytes): + specifier = specifier.decode("latin-1") + self._specifier = specifier + self._position = position + + @property + def prototype(self): + return self._prototype + + @property + def specifier(self): + return self._specifier + + @property + def position(self): + return self._position + + def __str__(self): + return f"FormatPrototype<'{self.prototype}', position: {self.position}>" + + def __repr__(self): + return self.__str__() diff --git a/package/argument_resolver/utils/graph_helper.py b/package/argument_resolver/utils/graph_helper.py new file mode 100644 index 0000000..3c574d0 --- /dev/null +++ b/package/argument_resolver/utils/graph_helper.py @@ -0,0 +1,89 @@ +import networkx as nx + + +class GraphHelper: + """ + Helpful class for displaying Dependency Graph + """ + + @staticmethod + def find_subgraph_from_defn(graph, sink): + source_nodes = [x for x in graph if graph.in_degree(x) == 0] + final_graph = None + for source_node in source_nodes: + g = nx.bfs_tree(graph, source_node) + if sink not in g: + continue + + if final_graph is None: + final_graph = g + else: + final_graph = nx.compose(final_graph, g) + return final_graph + + + @staticmethod + def calculate_node_depth(graph, depth=0, node_depth=None): + if node_depth is None: + node_depth = {x: 0 for x in graph if graph.in_degree(x) == 0} + + for node in [k for k, v in node_depth.items() if v == depth]: + for pred in graph.predecessors(node): + if pred not in node_depth: + node_depth[pred] = depth - 1 + GraphHelper.calculate_node_depth( + graph, depth=depth - 1, node_depth=node_depth + ) + for succ in graph.successors(node): + if succ not in node_depth: + node_depth[succ] = depth + 1 + GraphHelper.calculate_node_depth( + graph, depth=depth + 1, node_depth=node_depth + ) + return node_depth + + + @staticmethod + def show_graph(graph, defns=None): + import matplotlib.pyplot as plt + + # g = Utils._find_subgraph_from_defn(graph, sink) + if defns is None: + defns = set() + depth_by_node = GraphHelper.calculate_node_depth(graph) + node_depth = {} + for n in depth_by_node: + depth = depth_by_node[n] + if depth not in node_depth: + node_depth[depth] = [] + node_depth[depth].append(n) + + max_width = 1 + max_height = 1 + height_step = max_height / len(node_depth) + pos = {} + for d in sorted(node_depth): + width_step = max_width / (len(node_depth[d]) + 1) + cur_step = width_step + for node in sorted( + node_depth[d], key=lambda x: len(nx.descendants(graph, x)), reverse=True + ): + pos[node] = (cur_step, max_height - (d * height_step)) + cur_step += width_step + labels = {} + colors = [] + for node in graph.nodes(): + if node not in pos: + pos[node] = (1, 1) + labels[node] = node + if node in defns: + colors.append("orange") + elif graph.in_degree(node) == 0: + 
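+                # roots of the dependency graph (no incoming edges) are drawn
+                # green; pure sinks red, interior nodes blue (below)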
+                colors.append("green")
+            elif graph.out_degree(node) == 0:
+                colors.append("red")
+            else:
+                colors.append("blue")
+
+        nx.draw(graph, pos=pos, labels=labels, font_size=8, node_color=colors)
+        plt.show()
diff --git a/package/argument_resolver/utils/rank.py b/package/argument_resolver/utils/rank.py
new file mode 100644
index 0000000..b50ba68
--- /dev/null
+++ b/package/argument_resolver/utils/rank.py
@@ -0,0 +1,48 @@
+from itertools import combinations
+from functools import reduce
+
+tag_values = {
+    "env": 0.7,
+    "file": 0.5,
+    "argv": 0.4,
+    "network": 0.6,
+    "unknown": 0,
+}
+
+categories = {
+    "env": ["env", "getenv", "nvram", "frontend_param", "getvalue"],
+    "file": ["fopen", "read", "open", "fread", "fgets", "stdin"],
+    "argv": ["argv"],
+    "network": ["socket", "accept", "recv", "nflog_get_payload"],
+    "unknown": ["unknown"],
+}
+
+
+def calc_probability(tags):
+    valid_tags = set()
+    for tag in tags:
+        for category, funcs in categories.items():
+            if tag in funcs:
+                valid_tags.add(category.lower())
+
+    # Guard on the categorized tags: a tag that maps to no category must fall
+    # back to "unknown" instead of crashing the lookups below.
+    if len(valid_tags) == 0:
+        return tag_values["unknown"]
+    elif len(valid_tags) == 1:
+        return tag_values[next(iter(valid_tags))]
+
+    probability = max(tag_values[x] for x in valid_tags)
+    return probability
+
+
+def get_value_from_source(tag):
+    func = tag.split("(")[0].lower()
+    func = "nvram" if "nvram" in func else func
+    func = "recv" if "recv" in func else func
+    for category, funcs in categories.items():
+        if func in funcs:
+            return tag_values[category]
+    return 0
+
+
+def get_rank(sources):
+    return {source: get_value_from_source(source) for source in sources}
diff --git a/package/argument_resolver/utils/rda.py b/package/argument_resolver/utils/rda.py
new file mode 100644
index 0000000..922e446
--- /dev/null
+++ b/package/argument_resolver/utils/rda.py
@@ -0,0 +1,307 @@
+import claripy
+import logging
+from pathlib import Path
+
+from typing import Iterable
+
+import pyvex
+
+from angr.analyses.reaching_definitions.reaching_definitions import (
+    ReachingDefinitionsAnalysis,
+    ReachingDefinitionsState,
+)
+from angr.analyses.reaching_definitions.engine_vex import SimEngineRDVEX
+from angr.knowledge_plugins.key_definitions.atoms import (
+    Register,
+    MemoryLocation,
+    SpOffset,
+)
+from angr.calling_conventions import SimRegArg
+from angr.code_location import CodeLocation
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+
+l = logging.getLogger(name=__name__)
+
+from argument_resolver.utils.utils import Utils
+
+
+class CustomRDA(ReachingDefinitionsAnalysis):
+    timeout_set = False
+
+    def __init__(
+        self,
+        *args,
+        is_reanalysis=False,
+        start_time=None,
+        rda_timeout=None,
+        prev_observed=None,
+        **kwargs,
+    ):
+        self.prev_observed = prev_observed
+        self.is_reanalysis = is_reanalysis
+        self.start_time = start_time
+        self.rda_timeout = rda_timeout
+        super().__init__(*args, **kwargs)
+
+    def _run_on_node(self, node, state: ReachingDefinitionsState):
+        """
+
+        :param node:  The current node.
+        :param state: The analysis state.
+ :return: A tuple: (reached fix-point, successor state) + """ + if not isinstance(self._engine_vex, CustomVexEngine): + # This is the first instance of the analysis + self._engine_vex = CustomVexEngine( + self.project, + functions=self.kb.functions, + function_handler=self._function_handler, + ) + self.model.func_addr = ( + self.project.kb.cfgs.get_most_accurate() + .get_any_node(node.addr) + .function_address + ) + + if self.prev_observed: + self.model.observed_results = self.prev_observed + + if node.addr in self.project.kb.functions and not self.is_reanalysis: + if self._function_handler.current_parent is None: + if ( + "main" in self.project.kb.functions + and node.addr == self.project.kb.functions["main"].addr + ): + state = self.taint_main_args(state) + state.codeloc = CodeLocation(block_addr=node.addr, stmt_idx=None) + self._engine_vex.state = state + self._engine_vex._handle_function( + MultiValues(claripy.BVV(node.addr, state.arch.bits)) + ) + + stored_func = self._function_handler.call_trace[-1] + self._function_handler.analyzed_list = [stored_func] + + stored_func = self._function_handler.call_trace[-1] + self._function_handler.call_stack.append(stored_func) + self._function_handler.current_parent = stored_func + res = super()._run_on_node(node, state) + return res + + def taint_main_args(self, state): + argc = claripy.BVS("ARGC", state.arch.bits, explicit_name=True) + argv = claripy.BVS("ARGV", state.arch.bits, explicit_name=True) + envp = claripy.BVS("ENVP", state.arch.bits, explicit_name=True) + taints = [argc, argv, envp] + main = self.project.kb.functions["main"] + + for idx, taint in enumerate(taints): + if idx >= len(main.arguments): + return state + if not isinstance(main.arguments[idx], SimRegArg): + raise ValueError("Expected Register Argument") + reg_tup = state.arch.registers[main.arguments[idx].reg_name] + if "ARGV" in taint.variables: + argv_loc = 0xDEADC0DE + arg_size = 0x100 + state.registers.store( + reg_tup[0], + state.stack_address(argv_loc), + endness=state.arch.memory_endness, + ) + for loc in range(10): + arg_pointer = state.stack_address(argv_loc) + loc * state.arch.bytes + pointer_dst = state.stack_address(argv_loc) + (1 + loc) * arg_size + state.stack.store( + state.get_stack_address(arg_pointer), + pointer_dst, + endness=state.arch.memory_endness, + ) + if loc == 0: + name = Path(self.project.filename).name.encode() + b"\x00" + memloc = MemoryLocation( + SpOffset(state.arch.bits, argv_loc + (1 + loc) * arg_size), + len(name), + ) + state.kill_and_add_definition( + memloc, MultiValues(claripy.BVV(name, len(name) * 8)) + ) + else: + arg_str = f"ARGV_{loc}" + argv_val = claripy.BVS( + arg_str, state.arch.bits, explicit_name=True + ) + argv_val.variables = frozenset( + set(argv_val.variables) | {"TOP"} + ) + memloc = MemoryLocation( + SpOffset(state.arch.bits, argv_loc + (1 + loc) * arg_size), + argv_val.size() // 8, + ) + state.kill_and_add_definition(memloc, MultiValues(argv_val)) + else: + old_val = state.registers.load(*reg_tup) + new_mv = MultiValues() + for offset, val_set in old_val.items(): + for val in val_set: + if taint.variables <= val.variables: + new_mv.add_value(offset, val) + else: + new_mv.add_value(offset, val + taint) + state.registers.store( + reg_tup[0], new_mv, endness=state.arch.memory_endness + ) + return state + + +class CustomVexEngine(SimEngineRDVEX): + + # Having a context for Codelocations makes the hash difficult to resolve for our purposes + @property + def _context(self) -> None: + return None + + # Normal Guarded Load also 
loads the alt value, we don't care + def _handle_LoadG(self, stmt): + guard = self._expr(stmt.guard) + guard_v = guard.one_value() + + if claripy.is_true(guard_v): + # FIXME: full conversion support + if stmt.cvt.find("Ident") < 0: + l.warning("Unsupported conversion %s in LoadG.", stmt.cvt) + load_expr = pyvex.expr.Load(stmt.end, stmt.cvt_types[1], stmt.addr) + wr_tmp_stmt = pyvex.stmt.WrTmp(stmt.dst, load_expr) + self._handle_WrTmp(wr_tmp_stmt) + elif claripy.is_false(guard_v): + wr_tmp_stmt = pyvex.stmt.WrTmp(stmt.dst, stmt.alt) + self._handle_WrTmp(wr_tmp_stmt) + else: + if stmt.cvt.find("Ident") < 0: + l.warning("Unsupported conversion %s in LoadG.", stmt.cvt) + load_expr = pyvex.expr.Load(stmt.end, stmt.cvt_types[1], stmt.addr) + + load_expr_v = self._expr(load_expr) + # alt_v = self._expr(stmt.alt) + + # data = load_expr_v.merge(alt_v) + self._handle_WrTmpData(stmt.dst, load_expr_v) + + # Normal Guarded Store also stores the alt value + # this sometimes messes up if alt-values are not intended for use so just ignore it to be safe. + def _handle_StoreG(self, stmt: pyvex.IRStmt.StoreG): + guard = self._expr(stmt.guard) + guard_v = guard.one_value() + + if claripy.is_false(guard_v): + return + + else: + addr = self._expr(stmt.addr) + if addr.count() == 1: + addrs = next(iter(addr.values())) + size = stmt.data.result_size(self.tyenv) // 8 + data = self._expr(stmt.data) + self._store_core(addrs, size, data) + + # Merging these values often causes conflicting issues down the line so always take the true if unresolvable + def _handle_ITE(self, expr: pyvex.IRExpr.ITE): + cond = self._expr(expr.cond) + cond_v = cond.one_value() + iftrue = self._expr(expr.iftrue) + iffalse = self._expr(expr.iffalse) + + if claripy.is_true(cond_v): + return iftrue + elif claripy.is_false(cond_v): + return iffalse + else: + data = iftrue + return data + + def _handle_Put(self, stmt): + reg_offset: int = stmt.offset + size: int = stmt.data.result_size(self.tyenv) // 8 + reg = Register(reg_offset, size, self.arch) + data = self._expr(stmt.data) + + if self.arch.sp_offset == reg_offset and any( + self.state.is_top(x) for x in next(iter(data.values())) + ): + old_sp = self.state.registers.load(self.arch.sp_offset, self.arch.bytes) + if old_sp.one_value() is None: + stripped_values_set = { + v._apply_to_annotations(lambda alist: None) + for v in next(iter(old_sp.values())) + } + + annotations = [] + for v in stripped_values_set: + annotations += list(v.annotations) + + if len(stripped_values_set) > 1: + new_sp = next(iter(stripped_values_set)) + if annotations: + new_sp = new_sp.annotate(*annotations) + data = MultiValues(new_sp) + + else: + offsets = {} + new_data = MultiValues() + for value in next(iter(data.values())): + stack_offset = self.state.get_stack_offset(value) + if stack_offset not in offsets: + offsets[stack_offset] = value - 0x30 + else: + offsets[stack_offset].annotate( + *( + list(offsets[stack_offset].annotations) + + list(value.annotations) + ) + ) + + for val in offsets.values(): + new_data.add_value(0, val) + data = new_data + else: + data = MultiValues(old_sp.one_value() - 0x30) + + # special handling for references to heap or stack variables + if data.count() == 1: + for d in next(iter(data.values())): + if self.state.is_heap_address(d): + heap_offset = self.state.get_heap_offset(d) + if heap_offset is not None: + self.state.add_heap_use(heap_offset, 1, "Iend_BE") + elif self.state.is_stack_address(d): + stack_offset = self.state.get_stack_offset(d) + if stack_offset is not None: + 
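+                        # record a use of the referenced stack slot so the
+                        # pointed-to definition stays live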
self.state.add_stack_use(stack_offset, 1, "Iend_BE") + + if self.state.exit_observed and reg_offset == self.arch.sp_offset: + return + self.state.kill_and_add_definition(reg, data) + + # This is an attempt to preserve ARGV and ENVP context for readability + def _load_core( + self, addrs: Iterable[claripy.ast.Base], size: int, endness: str + ) -> MultiValues: + argv_list = [] + addrs = list(addrs) + for addr in addrs: + if "ARGV" in addr.variables or "ENVP" in addr.variables: + argv_list.append(addr) + if self.state.is_heap_address(addr): + pass + result = super()._load_core(addrs, size, endness) + if argv_list: + new_mv = MultiValues() + for offset, values in result.items(): + for value in values: + if "TOP" in value.variables and argv_list: + new_mv.add_value(offset, argv_list.pop()) + else: + new_mv.add_value(offset, value) + return new_mv + else: + return result diff --git a/package/argument_resolver/utils/stored_function.py b/package/argument_resolver/utils/stored_function.py new file mode 100644 index 0000000..fb637ac --- /dev/null +++ b/package/argument_resolver/utils/stored_function.py @@ -0,0 +1,382 @@ +from typing import Dict, Set, Optional + +from angr.sim_type import SimTypePointer, SimTypeChar +from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState +from angr.analyses.reaching_definitions.function_handler import FunctionCallData +from angr.knowledge_plugins.key_definitions.atoms import ( + Atom, + Register, + MemoryLocation, + SpOffset, +) +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues + +from angr.code_location import CodeLocation + +from .transitive_closure import transitive_closures_from_defs, get_constant_data +from .utils import Utils +from .calling_convention import CallingConventionResolver, cc_to_rd + +from archinfo import Endness + + +class StoredFunction: + def __init__( + self, + state: ReachingDefinitionsState, + data: FunctionCallData, + call_stack, + depth: int, + ): + + self._data = data + self.depth = depth + 1 + self.target_defns: Set[Definition] = set() + self.state = state + self.definitions: Set[Definition] = set() + self.return_definitions: Set[Definition] = set() + self.call_stack = {x.code_loc.ins_addr for x in call_stack[1:]} + self._constant_data = {} + self._return_data = {} + self._arg_vals = {} + self._ret_val = None + self._closures = None + self._function = None + self._hash = None + + @property + def constant_data(self): + return self._constant_data + + @property + def subject(self): + return self.state._subject + + @property + def return_data(self): + if not self._return_data: + pass + return self._return_data + + @property + def name(self): + return self._data.name + + @property + def arg_vals(self): + if not self._arg_vals: + self._arg_vals = self.get_arg_vals() + return self._arg_vals + + @property + def args_atoms(self): + return self._data.args_atoms + + @property + def code_loc(self): + return self._data.callsite_codeloc + + @property + def atoms(self): + atoms = {y for x in self.args_atoms for y in x} + for atom in atoms: + if atom not in self.constant_data: + self._constant_data[atom] = None + if hasattr(atom, "endness") and atom.endness is None: + atom.endness = self.state.arch.memory_endness + return atoms + + @property + def visited_blocks(self): + return self._data.visited_blocks + + @property + def function(self): + 
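+        # Lazily resolve the angr Function behind this call: when the handler
+        # did not record one, fall back to the function that owns the callee's
+        # code location.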
addr = self._data.function_codeloc.block_addr + if ( + self._data.function is None + and addr + and addr in self.state.analysis.project.kb.functions + ): + self._function = self.state.analysis.project.kb.functions[addr] + else: + self._function = self._data.function + + return self._function + + @property + def ret_atoms(self): + return self._data.ret_atoms + + @property + def closures(self): + self.save_closures() + + return self._closures + + @property + def ret_val(self): + return self._ret_val + + @property + def all_definitions(self): + return self.definitions | self.return_definitions + + def depends( + self, + dest: Optional[Atom], + *sources: Atom, + value: Optional[MultiValues] = None, + apply_at_callsite: bool = False, + ): + self._data.depends( + dest, *sources, value=value, apply_at_callsite=apply_at_callsite + ) + + def save_closures(self): + if self._closures is None: + self._closures = self.get_closures() + + @property + def cc(self): + if self._data.cc is None: + try: + self._data.cc = CallingConventionResolver( + self.state.analysis.project, + self.state.arch, + self.state.analysis.project.kb.functions, + ).get_cc(self._data.name) + except KeyError: + pass + + return self._data.cc + + def get_closures(self) -> Dict[Atom, Set[Definition]]: + closures: Dict[Atom, Set[Definition]] = {} + for atom in self.atoms: + defs = {defn for defn in self.definitions if defn.atom == atom} + closures[atom] = { + defn + for graph in transitive_closures_from_defs( + defs, self.state.dep_graph + ).values() + for defn in graph.nodes() + } + return closures + + def get_arg_vals(self): + vals = {} + for atom in [y for x in self.args_atoms for y in x]: + value = self.state.live_definitions.get_value_from_atom(atom) + if value is not None: + vals[atom] = value + else: + vals[atom] = Utils.unknown_value_of_unknown_size( + self.state, atom, self._data.callsite_codeloc + ) + return vals + + def tag_params(self, first_run=False): + if self.state.arch.name.startswith("MIPS"): + t9_reg = Register(*self.state.arch.registers["t9"], self.state.arch) + t9_val = self.state.live_definitions.get_value_from_atom(t9_reg) + self.depends(t9_reg, value=t9_val) + + if self.name.startswith("execl"): + self._data.args_atoms = self._get_execl_vararg_atoms(self.state) + elif any( + x in self.function.name + for x in ["printf", "scanf", "twsystem", "doSystemCmd", "execFormatCmd"] + ): + self._data.args_atoms = self._get_printf_vararg_atoms(self.state) + + for atom in self.atoms: + self.state.add_use(atom) + self.definitions |= set(self.state.get_definitions(atom)) + mv = self.state.live_definitions.get_value_from_atom(atom) + if not self.ret_atoms or atom not in self.ret_atoms: + self._data.depends(atom, value=mv, apply_at_callsite=True) + self._arg_vals[atom] = mv or MultiValues( + self.state.top(self.state.arch.bits) + ) + if mv is None: + self.constant_data[atom] = None + continue + + def save_constant_arg_data(self, state): + for defn in self.definitions: + mv = self.state.live_definitions.get_value_from_definition(defn) + atom = defn.atom + if isinstance(atom, MemoryLocation) and isinstance(atom.addr, SpOffset): + if -1 * atom.addr.offset >> 31 == 1: + real_offset = atom.addr.offset + 2**state.arch.bits + for x in self.atoms: + if isinstance(x, MemoryLocation) and isinstance( + x.addr, SpOffset + ): + if x.addr.offset == real_offset: + atom = x + break + if atom not in self.atoms: + for a in self.atoms: + if not isinstance(a, type(atom)): + continue + + if isinstance(a, Register): + if a.reg_offset == 
atom.reg_offset: + atom = a + break + elif isinstance(a, MemoryLocation): + if a.addr == atom.addr: + atom = a + break + + if hasattr(atom, "endness") and atom.endness is None: + atom.endness = self.state.arch.memory_endness + try: + self.constant_data[atom] = get_constant_data(defn, mv, state) + except (AssertionError, AttributeError): + self.constant_data[atom] = None + + def _get_function_code_loc(self): + return CodeLocation(self.function.addr, 0, ins_addr=self.function.addr) + + def _get_execl_vararg_atoms(self, state: ReachingDefinitionsState): + if self.function.calling_convention is None or self.function.prototype is None: + return [] + + atoms = [] + arg_session = self.function.calling_convention.arg_session(None) + ty = self.function.prototype.args[0] + for _ in range(10): + arg = self.function.calling_convention.next_arg( + arg_session, ty.with_arch(state.arch) + ) + atom = cc_to_rd(arg, state.arch, state) + val = Utils.get_values_from_cc_arg(arg, state, state.arch) + one_val = val.one_value() + if one_val is not None and one_val.concrete and one_val.concrete_value == 0: + break + atoms.append({atom}) + + return atoms + + def _get_printf_vararg_atoms(self, state: ReachingDefinitionsState): + if self.function.calling_convention is None or self.function.prototype is None: + return [] + + atoms = [] + arg_session = self.function.calling_convention.arg_session(None) + for ty in self.function.prototype.args: + arg = self.function.calling_convention.next_arg( + arg_session, ty.with_arch(state.arch) + ) + atoms.append({cc_to_rd(arg, state.arch, state)}) + + fmt_ptrs = Utils.get_values_from_cc_arg(arg, state, state.arch) + fmt_strs = Utils.get_strings_from_pointers(fmt_ptrs, state, self.code_loc) + for fmt_str in [ + x for x in Utils.get_values_from_multivalues(fmt_strs) if x.concrete + ]: + for _ in Utils.get_prototypes_from_format_string( + Utils.bytes_from_int(fmt_str) + ): + arg = self.function.calling_convention.next_arg( + arg_session, SimTypePointer(SimTypeChar()).with_arch(state.arch) + ) + atoms.append({cc_to_rd(arg, state.arch, state)}) + + return atoms + + def has_definition(self, definition: Definition) -> bool: + return any(definition == defn for defn in self.definitions) + + def save_ret_value(self, value=None): + for atom in self.ret_atoms: + self._ret_val = value + self.depends(atom, *self.atoms, value=value) + if value is not None: + self.return_definitions = set( + LiveDefinitions.extract_defs_from_mv(self._ret_val) + ) + + def save_constant_ret_data(self, new_state=None, value=None): + state = new_state or self.state + for atom in self.ret_atoms: + for defn in LiveDefinitions.extract_defs_from_mv(value): + if atom not in self.return_data: + self.return_data[atom] = [] + try: + if value is not None: + self.return_data[atom].extend( + get_constant_data(defn, value, state) + ) + else: + self.return_data[atom].extend([None]) + except AssertionError: + self.return_data[atom].extend([None]) + + def handle_ret(self, new_state=None, value=None): + if not self.ret_atoms: + return + if ( + value is None + and not self.function.is_simprocedure + and not self.function.is_plt + ): + merged_values = None + for atom in self.ret_atoms: + if new_state: + state = new_state + else: + state = self.state + v = state.get_values(atom) + if merged_values is None: + merged_values = v + else: + merged_values = merged_values.merge(v) + value = merged_values + + self.save_ret_value(value=value) + + @property + def failed_tuple(self): + return False, self.state, self.visited_blocks, 
self.state.dep_graph

+    @property
+    def success_tuple(self):
+        return True, self.state, self.visited_blocks, self.state.dep_graph
+
+    @property
+    def func_tuple(self):
+        return Utils.get_func_tuple(
+            self.function, self.subject, self.state.arch, self.code_loc
+        )
+
+    @property
+    def exit_site_addresses(self):
+        return [f.addr for f in self.function.ret_sites + self.function.jumpout_sites]
+
+    def __str__(self):
+        return f"{self.function.name}({', '.join(str(self.arg_vals[y]) for x in self.args_atoms for y in x)}) @ {hex(self.code_loc.ins_addr or self.code_loc.block_addr)}"
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __hash__(self):
+        if self._hash is None:
+            # hash_args = []
+            # for arg in {y for x in self.args_atoms for y in x}:
+            #     for vals in self.arg_vals[arg].values():
+            #         hash_args.extend(list(vals))
+            # hash_args.extend(list(self.visited_blocks))
+            # hash_args.append(self.code_loc.block_addr)
+            # self._hash = hash(tuple(hash_args))
+            self._hash = hash(tuple([str(self)] + list(self.call_stack)))
+        return self._hash
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
diff --git a/package/argument_resolver/utils/transitive_closure.py b/package/argument_resolver/utils/transitive_closure.py
new file mode 100644
index 0000000..e3f5441
--- /dev/null
+++ b/package/argument_resolver/utils/transitive_closure.py
@@ -0,0 +1,176 @@
+from typing import Dict, Iterable, Set, TYPE_CHECKING
+
+import networkx
+import claripy
+
+from angr.knowledge_plugins.key_definitions.live_definitions import LiveDefinitions
+from angr.code_location import ExternalCodeLocation
+from angr.analyses.reaching_definitions.reaching_definitions import ReachingDefinitionsAnalysis
+from angr.analyses.reaching_definitions.reaching_definitions import ReachingDefinitionsState
+from angr.engines.light import SpOffset
+from angr.knowledge_plugins.cfg import CFGNode
+from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, Register
+
+from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues
+from angr.knowledge_plugins.key_definitions.definition import Definition
+from angr.errors import SimMemoryMissingError
+
+from .utils import Utils
+
+
+def transitive_closures_from_defs(
+    vulnerable_defs: Iterable[Definition], dep_graph: "DepGraph"
+) -> Dict[Definition, networkx.DiGraph]:
+    """
+    Determine the transitive closure of each vulnerable definition in the dependency
+    graph of a given ReachingDefinitionsAnalysis (computed for a given sink).
+
+    :param vulnerable_defs: The definitions to compute the transitive closures of in the dependency graph.
+    :param dep_graph: The dependency graph to get the transitive closures from.
+    :return: A dictionary mapping each definition to its (non-empty) transitive closure subgraph.
+    """
+
+    closures = {}
+    for defn in vulnerable_defs:
+        closure = dep_graph.transitive_closure(defn)
+        if len(closure) > 0:
+            closures[defn] = closure
+
+    return closures
+
+
+def contains_an_external_definition(
+    transitive_closures: Dict[Definition, Set["Closure"]]
+) -> bool:
+    """
+    Determine whether any value in the closures is marked as coming from an external
+    code location; such values are not resolved.
+
+    *Note* This is lazily evaluated.
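+
+    Illustrative usage (``closures_by_defn`` and ``report_unresolved`` are
+    hypothetical names, not part of this module)::
+
+        # closures_by_defn: Dict[Definition, Set[Closure]]
+        if contains_an_external_definition(closures_by_defn):
+            # at least one reaching value was defined at an
+            # ExternalCodeLocation, i.e. it could not be resolved
+            report_unresolved(closures_by_defn)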
+    """
+    any_node_is_external = lambda nodes: any(
+        isinstance(node.codeloc, ExternalCodeLocation) for node in nodes
+    )
+    return any(
+        any_node_is_external(closure.rda.dep_graph.nodes())
+        for s in transitive_closures.values()
+        for closure in s
+    )
+
+
+def represents_constant_data(
+    definition: Definition,
+    values: MultiValues,
+    livedef: "LiveDefinitions",
+) -> bool:
+    """
+    Tell if a definition is completely resolved to constant values, or not (it might
+    be influenced by external factors).
+
+    :param definition: The definition to consider.
+    :param values: The values associated with the definition.
+    :param livedef: The LiveDefinitions pertaining to the original definition.
+    :return: `True` if the definition represents constant data, `False` otherwise.
+    """
+    data = get_constant_data(definition, values, livedef)
+    if data is None or any(d is None for d in data):
+        return False
+    else:
+        return True
+
+
+def get_constant_data(
+    definition: Definition,
+    values: MultiValues,
+    livedef: ReachingDefinitionsState,
+) -> list:
+    """
+    Resolve the constant values behind a definition, if it is not influenced by
+    external factors.
+
+    :param definition: The definition to consider.
+    :param values: The values associated with the definition.
+    :param livedef: The state (or LiveDefinitions) pertaining to the original definition.
+    :return: The list of resolved constant values, or a list containing `None` when
+             the definition cannot be fully resolved.
+    """
+    if isinstance(definition.atom, MemoryLocation):
+
+        def _is_concrete(datum):
+            if datum.concrete:
+                addr = datum._model_concrete.value
+                try:
+                    mv = livedef.heap.load(addr, definition.atom.size, endness=livedef.arch.memory_endness)
+                except SimMemoryMissingError:
+                    try:
+                        mv = livedef.memory.load(addr, definition.atom.size, endness=livedef.arch.memory_endness)
+                    except SimMemoryMissingError:
+                        mv = MultiValues(offset_to_values={0: {datum}})
+            elif livedef.is_stack_address(datum):
+                try:
+                    if datum.op == "Reverse":
+                        datum = datum.args[0]
+                    if livedef.get_stack_address(datum) is None:
+                        return [None]
+                    endness = livedef.arch.memory_endness
+                    mv = livedef.stack.load(
+                        livedef.get_stack_address(datum),
+                        definition.atom.size,
+                        endness=endness,
+                    )
+                except SimMemoryMissingError:
+                    return [None]
+            else:
+                return [None]
+
+            loaded_values = Utils.get_values_from_multivalues(mv)
+            if not all(isinstance(val, claripy.ast.Base) and val.concrete for val in loaded_values):
+                return [None]
+
+            return loaded_values
+
+        resolved = [_is_concrete(v) for val_set in values.values() for v in val_set]
+        return [y for x in resolved for y in x]
+
+    elif isinstance(definition.atom, Register):
+        pointed_addresses = Utils.get_values_from_multivalues(values)
+
+        data_all_concrete = all(
+            isinstance(v, (int, SpOffset))
+            or (
+                isinstance(v, claripy.ast.Base)
+                and (Utils.is_stack_address(v) or v.concrete or Utils.is_heap_address(v))
+            )
+            for v in pointed_addresses
+        )
+        if not data_all_concrete:
+            return [None]
+
+        new_mv = MultiValues()
+        concrete_vals = []
+        for offset, vals in values.items():
+            for val in vals:
+                try:
+                    sp = livedef.get_sp()
+                except AssertionError:
+                    sp = livedef.arch.initial_sp
+                if isinstance(livedef, LiveDefinitions):
+                    proj = livedef.project
+                else:
+                    proj = livedef.analysis.project
+                if val.concrete and proj is not None and not Utils.is_pointer(val, sp, proj):
+                    concrete_vals.append(val)
+                else:
+                    new_mv.add_value(offset, val)
+
+        strings = Utils.get_strings_from_pointers(
+            new_mv, livedef, definition.codeloc
+        )
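+        # `strings` now holds the dereferenced contents of every pointer-like
+        # value; the register resolves to constant data only if all of them
+        # (plus the plain concrete values collected above) are concrete.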
values = [y for x in strings.values() for y in x] + if all(x.concrete for x in values): + return values + concrete_vals + return [None] + + else: + message = f"The case where the given definition's atom is of type {definition.atom} has not been handled!" + raise NotImplementedError(message) diff --git a/package/argument_resolver/utils/utils.py b/package/argument_resolver/utils/utils.py new file mode 100644 index 0000000..0c89026 --- /dev/null +++ b/package/argument_resolver/utils/utils.py @@ -0,0 +1,821 @@ +import string + +import claripy +import logging +import re +import networkx as nx +import itertools +import pprint + +from typing import Iterable, Optional, Tuple, Union, List, Set +from functools import lru_cache + +from archinfo.arch import Arch +from cle import ELF, PE + +import angr +from angr.calling_conventions import SimRegArg, SimStackArg, SimFunctionArgument +from angr.code_location import CodeLocation +from angr.knowledge_plugins.key_definitions.tag import ( + UnknownSizeTag, +) + +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues +from angr.sim_type import SimTypePointer, SimTypeChar +from angr.knowledge_plugins.functions.function import Function +from angr.knowledge_plugins.key_definitions import LiveDefinitions +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, Atom +from angr.knowledge_plugins.key_definitions.constants import OP_AFTER, OP_BEFORE +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.heap_address import HeapAddress +from angr.errors import SimMemoryMissingError +from angr.engines.light import SpOffset +from angr.analyses.reaching_definitions.rd_state import ReachingDefinitionsState +from angr.analyses.reaching_definitions.dep_graph import DepGraph + +from argument_resolver.utils.format_prototype import FormatPrototype + +from archinfo import Endness + + +def _is_stack_pointer( + ptr: Union[int, SpOffset], sp: Union[int, SpOffset], initial_sp: int +): + if isinstance(ptr, SpOffset): + return True + if isinstance(sp, int): + return sp <= ptr <= initial_sp + + return False + + +def _get_strings_from_concrete_memory_area( + string_pointer: int, state: "ReachingDefinitionsState" +) -> Optional[MultiValues]: + """ + :param string_pointer: The concrete pointer to a memory region. + :param state: + :return: A set of possible string values that can be pointed by string_pointer, or None. + """ + + memory_values = _get_strings_from_pointer(string_pointer, state, state.memory.load) + return memory_values + + +# is_pointer = False +# for idx, arg in enumerate(trace.function.calling_convention.arg_locs(trace.function.prototype)): +# if arg.reg_name == self.project.arch.register_names[defn.atom.reg_offset]: +# if hasattr(trace.function.prototype.args[idx], "pts_to"): +# is_pointer = True +# break + + +def _get_strings_from_concrete_pointer( + string_pointer: int, state: "ReachingDefinitionsState", codeloc: CodeLocation +) -> MultiValues: + """ + :param string_pointer: + :param state: + :return: A . 
+ """ + + # Read data from memory definition, check static memory region if no definition was found + # or if stack and heap pointers with string stored with load we will load all possible strings + memory_values = _get_strings_from_concrete_memory_area(string_pointer, state) + all_undefined = all( + state.is_top(v) for v in Utils.get_values_from_multivalues(memory_values) + ) + + if memory_values and not all_undefined: + return memory_values + + project = None + if isinstance(state, ReachingDefinitionsState): + project = state.analysis.project + elif isinstance(state, LiveDefinitions): + project = state.project + + if project is None: + return Utils.unknown_value_of_unknown_size( + state, MemoryLocation(string_pointer, state.arch.bytes * 8), codeloc + ) + size = 0 + is_null = False + try: + while not is_null: + memory_content = project.loader.memory.load(string_pointer + size, 1) + is_null = memory_content == b"\x00" + size += 1 - is_null + size = size if size != 0 else 1 + memory_content = project.loader.memory.load(string_pointer, size) + if memory_content == b"\x00": + memloc = MemoryLocation(string_pointer, state.arch.bytes) + return Utils.unknown_value_of_unknown_size(state, memloc, codeloc) + else: + memloc = MemoryLocation(string_pointer, size) + values = MultiValues(claripy.BVV(memory_content, size * 8)) + state.kill_and_add_definition(memloc, values, endness=Endness.BE) + return values + except KeyError: + pass + + heap_values = _get_strings_from_pointer(string_pointer, state, state.heap.load) + return heap_values + + +def _get_strings_from_heap_offset( + string_pointer: claripy.ast.Base, + state: "ReachingDefinitionsState", + codeloc: CodeLocation, +) -> MultiValues: + """ + Get string values pointed by string_pointer. + """ + heap_offset = Utils.get_heap_offset(string_pointer) + heap_values = _get_strings_from_pointer(heap_offset, state, state.heap.load) + + return heap_values + + +def _get_strings_from_stack_offset( + string_pointer: claripy.ast.Base, + state: "ReachingDefinitionsState", + codeloc: CodeLocation, +) -> MultiValues: + """ + Get string values pointed by string_pointer. 
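+    The pointer is first concretized with ``state.get_stack_address``
+    (undoing a ``Reverse`` if present); if no concrete stack address can be
+    recovered, an unknown value of unknown size is returned instead.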
+ """ + if string_pointer.op == "Reverse": + stack_pointer = state.get_stack_address(string_pointer.reversed) + else: + stack_pointer = state.get_stack_address(string_pointer) + + if stack_pointer is None: + memloc = MemoryLocation(string_pointer, state.arch.bytes) + return Utils.unknown_value_of_unknown_size(state, memloc, codeloc) + + stack_values = _get_strings_from_pointer(stack_pointer, state, state.stack.load) + + return stack_values + + +def _get_strings_from_pointer(pointer, state, load_func): + new_mv = None + if pointer is None: + new_mv = MultiValues() + new_mv.add_value(0, state.top(state.arch.bits)) + return new_mv + try: + mv = load_func(pointer, 1) + except SimMemoryMissingError: + return MultiValues(state.top(state.arch.bits)) + + for bvv in mv[0]: + defns = list(state.extract_defs(bvv)) + if len(defns) == 0: + continue + + max_defn = max(defns, key=lambda x: x.size) + endness = Endness.BE + + try: + tmp_mv = load_func(pointer, max_defn.size, endness=endness) + except SimMemoryMissingError as e: + if e.missing_size < max_defn.size: + tmp_mv = load_func( + pointer, abs(pointer - e.missing_addr), endness=endness + ) + else: + if new_mv is None: + new_mv = MultiValues(state.top(max_defn.size * 8)) + else: + new_mv = new_mv.merge(MultiValues(state.top(max_defn.size * 8))) + continue + + mv_dict = {} + for offset, sub_vals in tmp_mv.items(): + for sub_val in sub_vals: + if max_defn not in state.extract_defs(sub_val): + continue + + for idx in range(sub_val.size() // 8): + is_zero = sub_val.get_byte(idx) == 0 + if is_zero.args[0] is True and idx != 0: + if offset not in mv_dict: + mv_dict[offset] = set() + mv_dict[offset].add(sub_val.get_bytes(0, idx + 1)) + break + else: + if offset not in mv_dict: + mv_dict[offset] = set() + mv_dict[offset].add(sub_val) + if new_mv is None: + new_mv = MultiValues(mv_dict) + else: + new_mv = new_mv.merge(MultiValues(mv_dict)) + + if new_mv is None or new_mv.count() == 0: + new_mv = MultiValues({0: {state.top(state.arch.bits)}}) + return new_mv + + +class Utils: + # + # RDA: Definitions + # + log = logging.getLogger("FastFRUIT") + arch = None + + @staticmethod + def get_values_from_cc_arg( + arg: Union[SimStackArg, SimRegArg], + state: ReachingDefinitionsState, + arch: Arch, + ) -> MultiValues: + """ + Return all definitions for an argument (represented by a calling_conventions' SimRegArg or SimStackArg) from a + LiveDefinitions object. 
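+    Register arguments are read straight from ``state.registers``; stack
+    arguments are read from ``state.stack`` relative to the current stack
+    pointer, and a TOP value is returned if the load fails.
+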
+ :param arg: Argument + :param state: Register and memory definitions + :param arch: Architecture + :return: Definition(s) of the argument + """ + try: + if isinstance(arg, SimRegArg): + reg_offset = arch.registers[arg.reg_name][0] + mv = state.registers.load(reg_offset, size=arch.bytes) + elif isinstance(arg, SimStackArg): + sp = Utils.get_sp(state) + if sp is None: + Utils.log.warning("Failed to get stack value, returning TOP") + return MultiValues(state.top(arch.bits)) + addr = sp + arg.stack_offset + if isinstance(addr, SpOffset): + mv = state.stack.load( + addr.offset, size=arch.bytes, endness=state.arch.memory_endness + ) + elif isinstance(addr, int): + mv = state.stack.load( + addr, size=arch.bytes, endness=state.arch.memory_endness + ) + else: + raise TypeError( + f"Unsupported stack address type {type(addr).__name__}" + ) + else: + raise TypeError( + f"Expected SimRegArg or SimStackArg, got {type(arg).__name__}" + ) + return mv + except SimMemoryMissingError: + return MultiValues(state.top(arch.bits)) + + @staticmethod + def get_memory_location_from_bv(ptr_bv: claripy.ast.BV, state, size: int): + method = Utils.get_store_method_from_ptr(ptr_bv, state) + if state.is_top(method): + return None + + return MemoryLocation(method, size) + + # + # Format strings + # + + @staticmethod + def get_prototypes_from_format_string(fmt_string): + # http://www.cplusplus.com/reference/cstdio/printf + try: + fmt_string = fmt_string.decode() + except (UnicodeDecodeError, AttributeError): + pass + + flags = r"[-+ #0]" + width = r"\d+|\*" + precision = r"\.(?:\d+|\*)" + length = r"hh|h|l|ll|j|z|t|L" + # noinspection SpellCheckingInspection + # specifier = r"[diuoxXfFeEgGaAcspn]" + specifier = r"[diuoxXcsp\[]" + + # group(0): match + # group(1) or group(2): specifier + pattern = rf"(?:%(?:{flags}{{0,5}})(?:{width})?(?:{precision})?(?:{length})?({specifier})|%(%%))" + + # https://docs.python.org/2/library/re.html#re.finditer + # The string is scanned left-to-right, and matches are returned in the order found. + if isinstance(fmt_string, bytes): + pattern = pattern.encode() + return [ + FormatPrototype(m.group(0), m.group(1) or m.group(2), m.start()) + for m in re.finditer(pattern, fmt_string, re.M) + ] + + # + # Strings from pointers + # + @staticmethod + def is_stack_address(addr: claripy.ast.Base) -> bool: + return "stack_base" in addr.variables + + @staticmethod + def is_heap_address(addr: claripy.ast.Base) -> bool: + return "heap_base" in addr.variables + + @staticmethod + def get_heap_offset(addr: claripy.ast.Base) -> Optional[int]: + if "heap_base" in addr.variables: + if addr.op == "BVS": + return 0 + elif ( + addr.op == "__add__" + and len(addr.args) == 2 + and addr.args[1].op == "BVV" + ): + return addr.args[1]._model_concrete.value + return None + + @staticmethod + def gen_heap_address(offset: int, arch: Arch): + base = claripy.BVS("heap_base", arch.bits, explicit_name=True) + return base + offset + + @staticmethod + def gen_stack_address(offset: int, arch: Arch): + base = claripy.BVS("stack_base", arch.bits, explicit_name=True) + return base + offset + + @staticmethod + def get_strings_from_pointer( + string_pointer: Union[SpOffset, claripy.ast.Base], + state: "ReachingDefinitionsState", + codeloc: CodeLocation, + ) -> MultiValues: + """ + Retrieve all the potential strings pointed by string_pointer. + :param string_pointer: + :param state: + :param codeloc: + :return: The potential values of the string pointed by string_pointer in memory. 
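+
+    Illustrative dispatch (mirrors the body below)::
+
+        TOP pointer                  -> unknown value (tainted TOPs pass through)
+        concrete pointer             -> string bytes read from mapped memory
+        "stack_base"-tagged pointer  -> string loaded through state.stack
+        "heap_base"-tagged pointer   -> string loaded through state.heap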
+ """ + if state.is_top(string_pointer): + # Checking for top values that are also tainted + if any("@" in x for x in string_pointer.variables): + return MultiValues(string_pointer) + memloc = MemoryLocation(string_pointer, state.arch.bytes) + return Utils.unknown_value_of_unknown_size(state, memloc, codeloc) + if not string_pointer.symbolic: + return _get_strings_from_concrete_pointer( + string_pointer._model_concrete.value, state, codeloc + ) + elif state.is_stack_address(string_pointer): + return _get_strings_from_stack_offset(string_pointer, state, codeloc) + elif Utils.is_heap_address(string_pointer): + return _get_strings_from_heap_offset(string_pointer, state, codeloc) + else: + Utils.log.warning( + "Strings: Expected int or claripy.ast.Base, got %s", + type(string_pointer).__name__, + ) + memloc = MemoryLocation(string_pointer, state.arch.bytes) + return Utils.unknown_value_of_unknown_size(state, memloc, codeloc) + + @staticmethod + def get_strings_from_pointers( + string_pointers: MultiValues, + state: "ReachingDefinitionsState", + codeloc: CodeLocation, + ) -> MultiValues: + """ + :param string_pointers: + A MultiValues representing pointers to strings. + Data content can be of type: . + :param state: + :param codeloc: + :return: The values of the string pointed by the string_pointers in memory. + """ + strings_mv = MultiValues() + for pointer in Utils.get_values_from_multivalues(string_pointers): + res = Utils.get_strings_from_pointer(pointer, state, codeloc) + strings_mv = strings_mv.merge(res) + + return strings_mv + + # + # Pointers + # + + @staticmethod + def is_pointer(ptr, sp, project): + arch = project.arch + loader = project.loader + + if ( + isinstance(ptr, (SpOffset, HeapAddress)) + or Utils.is_heap_address(ptr) + or Utils.is_stack_address(ptr) + ): + return True + + if isinstance(ptr, claripy.ast.BV): + if ptr.concrete: + ptr = ptr.concrete_value + + if not isinstance(ptr, int): + return False + + # Check for global variables and static strings + if isinstance(loader.main_object, ELF): # ELF + if len(loader.main_object.sections) > 0: + section = loader.find_section_containing(ptr) + if section is not None and section.is_executable is False: + return True + elif loader.find_section_containing(ptr) is not None: + return True + elif loader.main_object.min_addr < ptr < loader.main_object.max_addr: + return True + + elif isinstance(loader.main_object, PE): # PE + section = loader.find_section_containing(ptr) + if section is not None and section.is_executable is False: + return True + + else: # Others + if loader.main_object.min_addr <= ptr <= loader.main_object.max_addr: + return True + + # Stack + if isinstance(sp, int): + if sp <= ptr <= arch.initial_sp: + return True + + return False + + @staticmethod + def get_store_method_from_ptr(ptr: claripy.ast.BV, state: ReachingDefinitionsState): + if ptr.concrete: + return ptr._model_concrete.value + + if state.is_heap_address(ptr): + return HeapAddress(state.get_heap_offset(ptr)) + + if state.is_stack_address(ptr): + return SpOffset(state.arch.bits, state.get_stack_offset(ptr)) + + return state.top(state.arch.bits) + + @staticmethod + def get_values_from_multivalues( + values: MultiValues, pretty=False + ) -> List[claripy.ast.Base]: + out_values = {} + for offset, value_set in sorted(values.items(), key=lambda x: x[0]): + concat_vals = {} + known_vals = {} + for value in value_set: + defns = list(LiveDefinitions.extract_defs(value)) or [None] + for new_defn in defns: + try: + if str(value) in known_vals[new_defn]: + continue + 
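+                        # values are deduplicated per definition by their
+                        # string rendering; only unseen ones are kept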
+                        else:
+                            known_vals[new_defn].add(str(value))
+                            concat_vals[new_defn].append(value)
+                    except KeyError:
+                        if offset != 0 and new_defn not in out_values:
+                            continue
+                        known_vals[new_defn] = {str(value)}
+                        concat_vals[new_defn] = [value]
+
+            for defn, vals in concat_vals.items():
+                try:
+                    out_values[defn].append(vals)
+                except KeyError:
+                    out_values[defn] = [vals]
+
+        # if pretty is True:
+        #     final_vals = []
+        #     for x in out_values.values():
+        #         for prod in itertools.product(*x):
+        #             new_val = None
+        #             for y in prod:
+        #                 if new_val is None:
+        #                     new_val = y
+        #                 else:
+        #                     if len(new_val.args) == 3 and len(y.args) == 3 and isinstance(new_val.args[1], int) and isinstance(y.args[1], int):
+        #                         if new_val.args[1] - 1 == y.args[0]:
+        #                             if y.args[1] == 0:
+        #                                 new_val = new_val.args[-1]
+        #                             else:
+        #                                 new_val = new_val.args[-1][new_val.args[1]:y.args[1]]
+        #                             continue
+
+        #                     new_val = new_val.concat(y)
+
+        #             if new_val is not None:
+        #                 for idx, final_val in enumerate(final_vals):
+        #                     if str(final_val) == str(new_val):
+        #                         annotated_val = final_val.annotate(*new_val.annotations)
+        #                         final_vals[idx] = annotated_val
+        #                         break
+        #                 else:
+        #                     final_vals.append(new_val)
+
+        #     return final_vals
+
+        return [
+            y[0].concat(*y[1:]) if len(y) > 1 else y[0]
+            for x in out_values.values()
+            for y in itertools.product(*x)
+        ]
+
+    @staticmethod
+    def get_strings_from_multivalues(mv: MultiValues) -> Iterable[claripy.ast.Base]:
+        """
+        :param mv: The MultiValues object to extract strings from
+        :return: A list of all possible string value combinations
+        """
+        values = []
+        for combo in itertools.product(*mv.values()):
+            combined_bvv = claripy.BVV(b"")
+            for c in combo:
+                combined_bvv = combined_bvv.concat(c)
+            # keep the concatenation of the whole combination, not just its
+            # last element
+            values.append(combined_bvv)
+        return values
+
+    @staticmethod
+    def bytes_from_int(data: claripy.ast.Base) -> bytes:
+        if data.symbolic:
+            return data
+        output = data.concrete_value.to_bytes(data.size() // 8, "big")
+        if output == b"":
+            return b"\x00"
+        return output
+
+    @staticmethod
+    def get_size_from_multivalue(value: MultiValues) -> int:
+        max_offset = max(value.keys())
+        max_size = max([x.size() for x in value[max_offset]]) // 8
+        return max_size + max_offset
+
+    @staticmethod
+    def strip_null_from_string(string: claripy.ast.Base) -> claripy.ast.Base:
+        if string.symbolic:
+            return string
+
+        new_string = Utils.bytes_from_int(string)
+        while new_string.endswith(b"\x00"):
+            new_string = new_string[:-1]
+        # claripy ASTs are immutable, so re-attach the annotations via
+        # annotate() instead of assigning to `.annotations`
+        result = claripy.BVV(new_string)
+        if string.annotations:
+            result = result.annotate(*string.annotations)
+
+        return result
+
+    @staticmethod
+    def unknown_value_of_unknown_size(
+        state: "ReachingDefinitionsState", atom: Atom, codeloc: CodeLocation
+    ) -> MultiValues:
+        return Utils.value_of_unknown_size(
+            state.top(state.arch.bytes * 8), state, atom, codeloc
+        )
+
+    @staticmethod
+    def value_of_unknown_size(
+        value, state: "ReachingDefinitionsState", atom: Atom, codeloc: CodeLocation
+    ) -> MultiValues:
+        atom._size = state.arch.bytes
+        tag = UnknownSizeTag(metadata={"tagged_by": "Utils"})
+        definition: Definition = Definition(atom, codeloc, dummy=False, tags={tag})
+        value = state.annotate_with_def(value, definition)
+        mv = MultiValues(offset_to_values={0: {value}})
+        return mv
+
+    @staticmethod
+    def has_unknown_size(value: claripy.ast.Base) -> bool:
+        for annotation in value.annotations:
+            if any(
+                map(
+                    lambda tag: isinstance(tag, UnknownSizeTag),
+                    annotation.definition.tags,
+                )
+            ):
+                return True
+
+        return False
+
+    @staticmethod
+    def get_signed_value(value: int, size: int):
+        unsigned = value % 2**size
+        signed = unsigned - 2**size if unsigned
>= 2 ** (size - 1) else unsigned + return signed + + @staticmethod + def get_sp(state: ReachingDefinitionsState) -> int: + try: + sp = state.get_sp() + except AssertionError: + sp_values: MultiValues = state.registers.load( + state.arch.sp_offset, size=state.arch.bytes + ) + next_vals = next(iter(sp_values.values())) + if len(next_vals) == 0: + raise AssertionError + else: + sp = max( + state.get_stack_address(x) + for x in next_vals + if state.get_stack_address(x) is not None + ) + return sp + + @staticmethod + def get_definition_dependencies( + graph: DepGraph, target_defns: Set[Definition], is_root=False + ) -> Set[Definition]: + """ + Recursively get all definitions that our target depends on + :param stored_func: + :param target_atoms: + :return: + """ + + # Get all root nodes of the dependency tree based on the target definitions + if not is_root: + graph = graph.graph.reverse(True) + else: + graph = graph.graph + + # Get all nodes reachable from the root nodes + dependent_defns: Set[Definition] = set() + for defn in {x for x in target_defns if x in graph}: + dependent_defns |= set(nx.dfs_preorder_nodes(graph, source=defn)) + return dependent_defns + + @staticmethod + def get_all_dependant_functions( + func_list: List["StoredFunction"], + graph: DepGraph, + target_defns: Set[Definition], + is_root=False, + ): + defns = Utils.get_definition_dependencies(graph, target_defns, is_root=is_root) + dependant_funcs = [] + for func in func_list: + if any(x in defns for x in func.all_definitions): + dependant_funcs.append(func) + return dependant_funcs + + @staticmethod + @lru_cache + def get_all_callsites(project: angr.Project): + """ + Retrieve all function callsites + :return: + A list of tuples, for each sink present in the binary, containing: the representation of the itself, the representation, + and the list of addresses in the binary the sink is called from. + """ + + def _call_statement_in_node(node) -> Tuple[str, int, OP_AFTER]: + """ + Assuming the node is the predecessor of a function start. + Returns the statement address of the `call` instruction. 
+ """ + if len(node.block.disassembly.insns) < 2: + return None + addrs = [x.address for x in node.block.disassembly.insns] + addr = addrs[-1] + if project.arch.branch_delay_slot: + if node.block.disassembly.insns[-1].mnemonic == "nop": + addr = addrs[-2] + + return "insn", addr, OP_AFTER + + cfg = project.kb.cfgs.get_most_accurate() + final_callsites = [] + for func in project.kb.functions.values(): + if cfg.get_any_node(func.addr) is None: + continue + + calling_nodes = [ + x + for x in cfg.get_any_node(func.addr).predecessors + if x.block is not None and not x.has_return + ] + if calling_nodes: + calling_insns = list( + filter( + lambda x: x is not None, + map(_call_statement_in_node, calling_nodes), + ) + ) + calling_insns.append(("node", func.addr, OP_BEFORE)) + pre_nodes = [("node", x.addr, OP_BEFORE) for x in calling_nodes] + pre_nodes += [("node", x.addr, OP_AFTER) for x in calling_nodes] + final_callsites.extend(calling_insns + pre_nodes) + for x in func.ret_sites + func.jumpout_sites: + final_callsites.append(("node", x.addr, OP_AFTER)) + return final_callsites + + @staticmethod + def value_from_simarg( + simarg: SimFunctionArgument, livedef: LiveDefinitions, arch: Arch + ): + if isinstance(simarg, SimRegArg): + mv = livedef.registers.load(*arch.registers[simarg.reg_name]) + elif isinstance(simarg, SimStackArg): + mv = livedef.stack.load( + livedef.get_sp() + simarg.stack_offset, + arch.bytes, + endness=arch.memory_endness, + ) + else: + raise Exception(f"SimArg Value {simarg} not Handled") + return mv + + @staticmethod + def is_in_text_section(definition, project) -> bool: + res = project.loader.find_section_containing(definition.codeloc.ins_addr) + if res is None: + return False + if res.name != ".text": + return False + return True + + @staticmethod + def arguments_from_function(function: Function): + arguments = function.arguments + cc = function.calling_convention + if cc and any( + not isinstance(x, SimRegArg) and not isinstance(x, SimStackArg) + for x in arguments + ): + session = cc.arg_session(None) + new_args = [] + prototype_args = function.prototype.args + for idx, arg in enumerate(arguments): + if not isinstance(arg, SimRegArg) or not isinstance(arg, SimStackArg): + new_arg = cc.next_arg( + session, + SimTypePointer(SimTypeChar().with_arch(cc.ARCH)).with_arch( + cc.ARCH + ), + ) + else: + new_arg = cc.next_arg(session, prototype_args[idx]) + new_args.append(new_arg) + + arguments = new_args + if not arguments: + # Handler.function_cache[tuple([function])] = True + return [] + return arguments + + @staticmethod + def get_atoms_from_function(function: Function, registers: list): + function_arguments = Utils.arguments_from_function(function) + atoms = [Atom.from_argument(x, registers) for x in function_arguments] + return atoms + + @staticmethod + def get_callstring_for_function( + function: Function, callsites: List["CallSite"], codeloc: CodeLocation + ) -> List[int]: + chain = [codeloc.ins_addr] + all_callsites = set() + for callsite in reversed(callsites): + all_callsites.add(callsite.caller_func_addr) + all_callsites.add(callsite.callee_func_addr) + chain.extend(sorted(list(all_callsites))) + return chain + + @staticmethod + def get_func_tuple(function: Function, subject, registers, codeloc): + atoms = Utils.get_atoms_from_function(function, registers) + if isinstance(subject, Function) or isinstance(subject.content, Function): + call_string = [codeloc.ins_addr] + else: + call_string = Utils.get_callstring_for_function( + function, subject.content.callsites, codeloc 
+ ) + + func_tuple = tuple([function] + atoms + call_string) + return func_tuple + + @staticmethod + def get_concrete_value_from_int(mv: MultiValues) -> Union[List[int], None]: + out = None + vals = Utils.get_values_from_multivalues(mv) + if all(x.concrete for x in vals): + out = [x.concrete_value for x in vals] + + return out + + @staticmethod + def get_bv_from_atom(atom: Atom, arch: Arch): + if isinstance(atom.addr, SpOffset): + return Utils.gen_stack_address(atom.addr.offset, arch) + elif isinstance(atom.addr, HeapAddress): + return Utils.gen_heap_address(atom.addr.value, arch) + elif isinstance(atom.addr, int): + return claripy.BVV(atom.addr, arch.bits) + return None diff --git a/package/assumptions.md b/package/assumptions.md new file mode 100644 index 0000000..a560e2c --- /dev/null +++ b/package/assumptions.md @@ -0,0 +1,19 @@ +# RDA +### Assumed Execution +We assume that only functions which touch arguments that flow to our sink, either in parameters or in return values, are relevant. +All other functions are skipped during analysis. + +#### Cons: +- Pointer Aliasing is excluded from this assumption + +#### Finding Callees for Assumed Execution +Find the transitive closure of the sink node in the dep graph. +Take all root nodes and DFS them in the graph. +Analyze only the calles that have parameters included in the resulting nodes. + +If the final sink has a TOP value for the pointer then we ignore that index + +# Engine Vex +### Guarded Load +Guarded loads can contain conditions that will always result in a true result. +Thus, we assume that if a guarded load occurs and one of the values is a top, collapse into the other value if it is resolvable. \ No newline at end of file diff --git a/package/py.typed b/package/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/package/requirements.txt b/package/requirements.txt new file mode 100644 index 0000000..86a352b --- /dev/null +++ b/package/requirements.txt @@ -0,0 +1,7 @@ +angr==9.2.94 +ipdb +nose2 +python-magic +pydot +claripy +networkx diff --git a/package/requirements_dev.txt b/package/requirements_dev.txt new file mode 100644 index 0000000..5c8910d --- /dev/null +++ b/package/requirements_dev.txt @@ -0,0 +1,5 @@ +ipdb +flake8 +mypy +pytest +pytest-cov \ No newline at end of file diff --git a/package/tests/binaries/Makefile b/package/tests/binaries/Makefile new file mode 100644 index 0000000..916713a --- /dev/null +++ b/package/tests/binaries/Makefile @@ -0,0 +1,22 @@ +CC=gcc + +PROGRAMS := $(wildcard */program.c) +DEBUG_FLAGS := -gdwarf -fvar-tracking-assignments + +all: $(PROGRAMS) + +clean: + rm $(wildcard */program) + +$(PROGRAMS): $@ + @echo + @echo $@: + $(CC) -o $(subst .c,,$@) $(DEBUG_FLAGS) $@ + +# Override `memcpy_resolved_and_unresolved/program.c`'s recipe to add specific flags. 
diff --git a/package/py.typed b/package/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/package/requirements.txt b/package/requirements.txt
new file mode 100644
index 0000000..86a352b
--- /dev/null
+++ b/package/requirements.txt
@@ -0,0 +1,7 @@
+angr==9.2.94
+ipdb
+nose2
+python-magic
+pydot
+claripy
+networkx
diff --git a/package/requirements_dev.txt b/package/requirements_dev.txt
new file mode 100644
index 0000000..5c8910d
--- /dev/null
+++ b/package/requirements_dev.txt
@@ -0,0 +1,5 @@
+ipdb
+flake8
+mypy
+pytest
+pytest-cov
\ No newline at end of file
diff --git a/package/tests/binaries/Makefile b/package/tests/binaries/Makefile
new file mode 100644
index 0000000..916713a
--- /dev/null
+++ b/package/tests/binaries/Makefile
@@ -0,0 +1,22 @@
+CC=gcc
+
+PROGRAMS := $(wildcard */program.c)
+DEBUG_FLAGS := -gdwarf -fvar-tracking-assignments
+
+all: $(PROGRAMS)
+
+clean:
+	rm $(wildcard */program)
+
+$(PROGRAMS): $@
+	@echo
+	@echo $@:
+	$(CC) -o $(subst .c,,$@) $(DEBUG_FLAGS) $@
+
+# Override `memcpy_resolved_and_unresolved/program.c`'s recipe to add specific flags.
+memcpy_resolved_and_unresolved/program.c:
+	@echo
+	@echo $@:
+	$(CC) -o $(subst .c,,$@) $(DEBUG_FLAGS) -fno-builtin-memcpy $@
+
+.PHONY: all clean $(PROGRAMS)
diff --git a/package/tests/binaries/after_values/program b/package/tests/binaries/after_values/program
new file mode 100755
index 0000000..783b31a
Binary files /dev/null and b/package/tests/binaries/after_values/program differ
diff --git a/package/tests/binaries/after_values/program.c b/package/tests/binaries/after_values/program.c
new file mode 100644
index 0000000..fdc5d2e
--- /dev/null
+++ b/package/tests/binaries/after_values/program.c
@@ -0,0 +1,39 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+
+void* id(void *parameter) {
+    return parameter;
+}
+
+void a(void *parameter) {
+    id(parameter);
+    system(parameter);
+    printf(parameter);
+    system(parameter);
+    printf(parameter);
+    system(parameter);
+    printf(parameter);
+}
+
+void b(void *parameter) {
+    id(parameter);
+    execve(parameter, NULL, NULL);
+    printf(parameter);
+    execve(parameter, NULL, NULL);
+    printf(parameter);
+    execve(parameter, NULL, NULL);
+    printf(parameter);
+}
+
+
+int main(int argc, char *argv[]) {
+    puts("***** a *****");
+    a(argv[1]);
+
+    puts("***** b *****");
+    b(argv[1]);
+
+    return 0;
+}
diff --git a/package/tests/binaries/early_resolve/program b/package/tests/binaries/early_resolve/program
new file mode 100755
index 0000000..b4853ea
Binary files /dev/null and b/package/tests/binaries/early_resolve/program differ
diff --git a/package/tests/binaries/early_resolve/program.c b/package/tests/binaries/early_resolve/program.c
new file mode 100644
index 0000000..359ff4f
--- /dev/null
+++ b/package/tests/binaries/early_resolve/program.c
@@ -0,0 +1,41 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+
+void wrapper(char *command) {
+    system(command);
+}
+
+void child() {
+    char buf[0x40];
+    char command[0x100];
+
+    read(0, buf, 0x40);
+    sprintf(command, "echo %s", buf);
+    wrapper(command);
+}
+
+void parent_1() {
+    child();
+}
+
+void parent_2() {
+    child();
+}
+
+void nested_1(char *log) {
+    printf("Log: %s\n", log);
+    char *val = getenv("UNKNOWN");
+    system(val);
+}
+
+void constant_1(char *buf) {
+    system(buf);
+}
+
+void main(int argc, char **argv) {
+    parent_1();
+    parent_2();
+    nested_1(argv[1]);
+    constant_1("echo 'HELLO WORLD'");
+}
\ No newline at end of file
diff --git a/package/tests/binaries/execlp/program b/package/tests/binaries/execlp/program
new file mode 100755
index 0000000..9076b9c
Binary files /dev/null and b/package/tests/binaries/execlp/program differ
diff --git a/package/tests/binaries/execlp/program.c b/package/tests/binaries/execlp/program.c
new file mode 100644
index 0000000..f752938
--- /dev/null
+++ b/package/tests/binaries/execlp/program.c
@@ -0,0 +1,10 @@
+#include <stdio.h>
+#include <unistd.h>
+
+
+void main(int argc, char** argv) {
+    char *args[4];
+    args[0] = "echo";
+
+    execlp(args[0], args[0], "Hello", argv[1], 0);
+}
diff --git a/package/tests/binaries/execve/other_prog b/package/tests/binaries/execve/other_prog
new file mode 100755
index 0000000..c03d16c
Binary files /dev/null and b/package/tests/binaries/execve/other_prog differ
diff --git a/package/tests/binaries/execve/other_prog.c b/package/tests/binaries/execve/other_prog.c
new file mode 100644
index 0000000..046840f
--- /dev/null
+++ b/package/tests/binaries/execve/other_prog.c
@@ -0,0 +1,9 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+void main(int argc, char **argv) {
+    char cmd[0x40];
+    printf("%s Running: %s %s\n", argv[0], argv[1], argv[2]);
+    snprintf(cmd, 0x40, "echo '%s %s'", argv[1], argv[2]);
+    system(cmd);
+}
diff --git a/package/tests/binaries/execve/program b/package/tests/binaries/execve/program
new file mode 100755
index 0000000..d0bd256
Binary files /dev/null and b/package/tests/binaries/execve/program differ
diff --git a/package/tests/binaries/execve/program.c b/package/tests/binaries/execve/program.c
new file mode 100644
index 0000000..07d33ce
--- /dev/null
+++ b/package/tests/binaries/execve/program.c
@@ -0,0 +1,14 @@
+#include <stdio.h>
+#include <unistd.h>
+
+
+void main(int argc, char** argv) {
+    char *args[4];
+    args[0] = "./other_prog";
+    args[1] = "Hello!";
+    args[2] = argv[1];
+    args[3] = 0;
+
+    printf("Running: %s %s %s\n", args[0], args[1], args[2]);
+    execve(args[0], args, 0);
+}
diff --git a/package/tests/binaries/heap/program b/package/tests/binaries/heap/program
new file mode 100755
index 0000000..e106a2f
Binary files /dev/null and b/package/tests/binaries/heap/program differ
diff --git a/package/tests/binaries/heap/program.c b/package/tests/binaries/heap/program.c
new file mode 100644
index 0000000..da385a6
--- /dev/null
+++ b/package/tests/binaries/heap/program.c
@@ -0,0 +1,16 @@
+#include <stdlib.h>
+#include <string.h>
+
+
+void main(int argc, char** argv) {
+    char *command1 = malloc(0x20);
+    strncpy(command1, "ls -la", 0x1f);
+    system(command1);
+
+    char *command2 = malloc(0x20);
+    strncpy(command2, argv[1], 0x1f);
+    system(command2);
+
+    free(command1);
+    free(command2);
+}
\ No newline at end of file
diff --git a/package/tests/binaries/layered/program b/package/tests/binaries/layered/program
new file mode 100755
index 0000000..55f8c9c
Binary files /dev/null and b/package/tests/binaries/layered/program differ
diff --git a/package/tests/binaries/layered/program.c b/package/tests/binaries/layered/program.c
new file mode 100644
index 0000000..01c82a7
--- /dev/null
+++ b/package/tests/binaries/layered/program.c
@@ -0,0 +1,42 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+void layer_1(char *arg1, char *arg2) {
+    layer_2(arg1, arg2);
+}
+
+void layer_2(char *arg1, char *arg2) {
+    layer_3(arg1, arg2);
+}
+
+void layer_3(char *arg1, char *arg2) {
+    layer_4(arg1, arg2);
+}
+
+void layer_4(char *arg1, char *arg2) {
+    layer_5(arg1, arg2);
+}
+
+void layer_5(char *arg1, char *arg2) {
+    layer_6(arg1, arg2);
+}
+
+void layer_6(char *arg1, char *arg2) {
+    layer_7a(arg1);
+    layer_7b(arg2);
+}
+
+void layer_7a(char *arg1) {
+    system(arg1);
+}
+
+void layer_7b(char *arg1) {
+    system(arg1);
+}
+
+int main(int argc, char **argv, char **envp) {
+
+    char *buf = "ls -la";
+    layer_1(buf, argv[1]);
+
+}
diff --git a/package/tests/binaries/looper/program b/package/tests/binaries/looper/program
new file mode 100755
index 0000000..7b92791
Binary files /dev/null and b/package/tests/binaries/looper/program differ
diff --git a/package/tests/binaries/looper/program.c b/package/tests/binaries/looper/program.c
new file mode 100644
index 0000000..c4e7837
--- /dev/null
+++ b/package/tests/binaries/looper/program.c
@@ -0,0 +1,15 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+
+int main(int argc, char *argv[]) {
+    puts("***** main *****");
+
+    for(int i=0; i<42; i++) {
+        puts("***** looping *****");
+    }
+
+    system(argv[1]);
+
+    return 0;
+}
diff --git a/package/tests/binaries/multi_input/program b/package/tests/binaries/multi_input/program
new file mode 100755
index 0000000..d339cc1
Binary files /dev/null and b/package/tests/binaries/multi_input/program differ
diff --git a/package/tests/binaries/multi_input/program.c b/package/tests/binaries/multi_input/program.c
new file mode 100644
index 0000000..6d77853
--- /dev/null
+++ b/package/tests/binaries/multi_input/program.c
@@ -0,0 +1,77 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <fcntl.h>
+#define PORT 8080
+
+
+void vuln(char *cmdline, int sock_fd, int file_fd) {
+    char buf1[0x40];
+    char buf2[0x40];
+    char buf3[0x40];
+    char buf4[0x40];
+    char buf5[0x40];
+    char cmd[1024];
+
+    read(0, buf1, 0x40);
+    read(sock_fd, buf2, 0x40);
+    read(file_fd, buf3, 0x40);
+    recv(sock_fd, buf4, 0x40, 0);
+    FILE *f = fopen("/etc/passwd", "r");
+    fgets(buf5, 0x40, f);
+
+    sprintf(cmd, "%s%s%s%s%s%s", buf1, buf2, buf3, buf4, buf5, cmdline);
+    system(cmd);
+}
+
+int main(int argc, char *argv[]) {
+    int server_fd, new_socket, valread;
+    struct sockaddr_in address;
+    int opt = 1;
+    int addrlen = sizeof(address);
+    char buffer[1024] = { 0 };
+    char* hello = "Hello from server";
+
+    // Creating socket file descriptor
+    if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
+        perror("socket failed");
+        exit(EXIT_FAILURE);
+    }
+
+    // Forcefully attaching socket to the port 8080
+    if (setsockopt(server_fd, SOL_SOCKET,
+                   SO_REUSEADDR | SO_REUSEPORT, &opt,
+                   sizeof(opt))) {
+        perror("setsockopt");
+        exit(EXIT_FAILURE);
+    }
+    address.sin_family = AF_INET;
+    address.sin_addr.s_addr = INADDR_ANY;
+    address.sin_port = htons(PORT);
+
+    // Forcefully attaching socket to the port 8080
+    if (bind(server_fd, (struct sockaddr*)&address,
+             sizeof(address))
+        < 0) {
+        perror("bind failed");
+        exit(EXIT_FAILURE);
+    }
+    if (listen(server_fd, 3) < 0) {
+        perror("listen");
+        exit(EXIT_FAILURE);
+    }
+    if ((new_socket
+         = accept(server_fd, (struct sockaddr*)&address,
+                  (socklen_t*)&addrlen))
+        < 0) {
+        perror("accept");
+        exit(EXIT_FAILURE);
+    }
+
+    int file_fd = open("/etc/passwd", O_RDONLY);
+    vuln(argv[1], new_socket, file_fd);
+    return 0;
+}
diff --git a/package/tests/binaries/nested/program b/package/tests/binaries/nested/program
new file mode 100755
index 0000000..19e3d63
Binary files /dev/null and b/package/tests/binaries/nested/program differ
diff --git a/package/tests/binaries/nested/program.c b/package/tests/binaries/nested/program.c
new file mode 100644
index 0000000..88233cc
--- /dev/null
+++ b/package/tests/binaries/nested/program.c
@@ -0,0 +1,21 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+
+void b() { }
+
+void a(int i) {
+    if (i < 10) {
+        b();
+    }
+
+    int j = 42;
+}
+
+int main(int argc, char *argv[]) {
+    a(argc);
+
+    system(argv[1]);
+
+    return 0;
+}
diff --git a/package/tests/binaries/nvram/keys b/package/tests/binaries/nvram/keys
new file mode 100755
index 0000000..bdd2060
Binary files /dev/null and b/package/tests/binaries/nvram/keys differ
diff --git a/package/tests/binaries/nvram/keys.c b/package/tests/binaries/nvram/keys.c
new file mode 100644
index 0000000..ffdfb88
--- /dev/null
+++ b/package/tests/binaries/nvram/keys.c
@@ -0,0 +1,11 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "nvram_lib.h"
+
+
+int main(int argc, char *argv[]) {
+    acosNvramConfig_set("command1", "ls -la");
+    acosNvramConfig_set("command2", argv[1]);
+    return 0;
+}
diff --git a/package/tests/binaries/nvram/nvram_lib.c b/package/tests/binaries/nvram/nvram_lib.c
new file mode 100644
index 0000000..6e27c61
--- /dev/null
+++ b/package/tests/binaries/nvram/nvram_lib.c
@@ -0,0 +1,6 @@
+
+void acosNvramConfig_set(char *key, char *value) {}
+
+char* acosNvramConfig_get(char *key) {
+    return "echo 'HELLO WORLD'";
+}
diff --git a/package/tests/binaries/nvram/nvram_lib.h b/package/tests/binaries/nvram/nvram_lib.h
new file mode 100644
index 0000000..b559df4
--- /dev/null
+++ b/package/tests/binaries/nvram/nvram_lib.h
@@ -0,0 +1,3 @@
+/* nvram_lib.h */
+extern char* acosNvramConfig_get(char *);
+extern void acosNvramConfig_set(char *, char *);
diff --git a/package/tests/binaries/nvram/program b/package/tests/binaries/nvram/program
new file mode 100644
index 0000000..2363092
Binary files /dev/null and b/package/tests/binaries/nvram/program differ
diff --git a/package/tests/binaries/nvram/program.c b/package/tests/binaries/nvram/program.c
new file mode 100644
index 0000000..0306430
--- /dev/null
+++ b/package/tests/binaries/nvram/program.c
@@ -0,0 +1,16 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "nvram_lib.h"
+
+
+int main(int argc, char *argv[]) {
+    char* command = acosNvramConfig_get("command1");
+    system(command);
+
+    command = acosNvramConfig_get("command2");
+    system(command);
+
+    acosNvramConfig_set("command1", "ls -la");
+    return 0;
+}
diff --git a/package/tests/binaries/off_shoot/program b/package/tests/binaries/off_shoot/program
new file mode 100755
index 0000000..fc221ac
Binary files /dev/null and b/package/tests/binaries/off_shoot/program differ
diff --git a/package/tests/binaries/off_shoot/program.c b/package/tests/binaries/off_shoot/program.c
new file mode 100644
index 0000000..95a19ef
--- /dev/null
+++ b/package/tests/binaries/off_shoot/program.c
@@ -0,0 +1,47 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+
+void log(char *buf) {
+    printf("YOU SAID: %s\n", buf);
+}
+
+void sub_func(char *buf) {
+    sprintf(buf, "%s; echo 'DONE'", buf);
+}
+
+void alter_command(char *buf1, char *buf2) {
+    strcat(buf1, buf2);
+    sub_func(buf1);
+}
+
+void off_shoot_resolved(char *command) {
+    log("OFFSHOOT RESOLVED");
+    alter_command(command, " 'Hello World'");
+    log(command);
+    system(command);
+}
+
+void off_shoot_unresolved(char *command) {
+    log("OFFSHOOT UNRESOLVED");
+    char extras[0x40];
+    memset(extras, 0, 0x40);
+    read(0, extras, 0x40);
+    int len = strlen(extras);
+    if (extras[len-1] == '\n') {
+        extras[len-1] = '\0';
+    }
+    alter_command(command, extras);
+    log(command);
+    system(command);
+}
+
+void main(int argc, char **argv) {
+    char buf[0x64] = {"echo"};
+    off_shoot_resolved(buf);
+    memset(buf, 0, 0x64);
+    memcpy(buf, "ls ", 4);
+    off_shoot_unresolved(buf);
+
+}
\ No newline at end of file
diff --git a/package/tests/binaries/recursive/program b/package/tests/binaries/recursive/program
new file mode 100755
index 0000000..b993a29
Binary files /dev/null and b/package/tests/binaries/recursive/program differ
diff --git a/package/tests/binaries/recursive/program.c b/package/tests/binaries/recursive/program.c
new file mode 100644
index 0000000..ae4c80c
--- /dev/null
+++ b/package/tests/binaries/recursive/program.c
@@ -0,0 +1,29 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+void child_func(char *buf, int count) {
+    recursive_child_resolve(buf, --count);
+}
+
+void recursive_child_resolve(char *buf, int count) {
+    if (count > 0) {
+        child_func(buf, count);
+    } else {
+        system(buf);
+    }
+}
+
+void recursive_self_resolve(char *buf, int count) {
+    if (count > 0) {
+        recursive_self_resolve(buf, --count);
+    } else {
+        system(buf);
+    }
+}
+
+void main(int argc, char **argv) {
+    char buf[0x64] = {"ls -la"};
+    recursive_self_resolve(buf, 10);
+    recursive_child_resolve(argv[1], 10);
+}
\ No newline at end of file
diff --git a/package/tests/binaries/simple/program b/package/tests/binaries/simple/program
new file mode 100755
index 0000000..f54f681
Binary files /dev/null and b/package/tests/binaries/simple/program differ
diff --git a/package/tests/binaries/simple/program.c b/package/tests/binaries/simple/program.c
new file mode 100644
index 0000000..eca54f4
--- /dev/null
+++ b/package/tests/binaries/simple/program.c
@@ -0,0 +1,29 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+
+void* id(void *parameter) {
+    return parameter;
+}
+
+void a(void *parameter) {
+    id(parameter);
+    system(parameter);
+}
+
+void b(void *parameter) {
+    id(parameter);
+    execve(parameter, NULL, NULL);
+}
+
+
+int main(int argc, char *argv[]) {
+    puts("***** a *****");
+    a(argv[1]);
+
+    puts("***** b *****");
+    b(argv[1]);
+
+    return 0;
+}
diff --git a/package/tests/binaries/sprintf_resolved_and_unresolved/program b/package/tests/binaries/sprintf_resolved_and_unresolved/program
new file mode 100755
index 0000000..c6906a6
Binary files /dev/null and b/package/tests/binaries/sprintf_resolved_and_unresolved/program differ
diff --git a/package/tests/binaries/sprintf_resolved_and_unresolved/program.c b/package/tests/binaries/sprintf_resolved_and_unresolved/program.c
new file mode 100644
index 0000000..8dcee72
--- /dev/null
+++ b/package/tests/binaries/sprintf_resolved_and_unresolved/program.c
@@ -0,0 +1,28 @@
+
+#include <stdio.h>
+#include <stdlib.h>
+
+
+void resolved() {
+    char command[64];
+    char *format_str = "ls -alh %s";
+    char *dir = "~/";
+
+    sprintf(command, format_str, dir);
+    system(command);
+}
+
+void unresolved(char *augment) {
+    char command[64];
+    char *format_str = "ls %s";
+    sprintf(command, format_str, augment);
+    system(command);
+}
+
+int main(int argc, char **argv) {
+    puts("--- fixed ---");
+    resolved();
+    puts("--- user controlled ---");
+    unresolved(argv[1]);
+    return 0;
+}
\ No newline at end of file
diff --git a/package/tests/binaries/wrapper/program b/package/tests/binaries/wrapper/program
new file mode 100755
index 0000000..31a644f
Binary files /dev/null and b/package/tests/binaries/wrapper/program differ
diff --git a/package/tests/binaries/wrapper/program.c b/package/tests/binaries/wrapper/program.c
new file mode 100644
index 0000000..6436aba
--- /dev/null
+++ b/package/tests/binaries/wrapper/program.c
@@ -0,0 +1,15 @@
+#include <stdio.h>
+
+
+void system_wrapper(char *command) {
+    system(command);
+}
+
+
+void main(int argc, char** argv) {
+    char command[0x40];
+    snprintf(command, 0x40, "echo 'Executing: %s'", argv[1]);
+    system_wrapper(command);
+    snprintf(command, 0x40, "%s", argv[1]);
+    system_wrapper(command);
+}
diff --git a/package/tests/test_basic_facts_for_handcrafted_binaries.py b/package/tests/test_basic_facts_for_handcrafted_binaries.py
new file mode 100644
index 0000000..bdc6a09
--- /dev/null
+++ b/package/tests/test_basic_facts_for_handcrafted_binaries.py
@@ -0,0 +1,234 @@
+import json
+import os
+import pathlib
+import re
+import subprocess
+import tempfile
+
+import unittest
+
+
+class TestBasicFactsForHandcraftedBinaries(unittest.TestCase):
+    """
+    Sanity checks over handcrafted binaries.
+    Tests on real-world binaries are usually slower and more involved, so it's more practical to keep them separate.
+ """ + + PROJECT_ROOT = pathlib.Path(__file__).parent.parent.absolute() + + def _run_analysis(self, name, expected_results, *args): + binary_path = self.PROJECT_ROOT / "tests" / "binaries" / name / "program" + with tempfile.TemporaryDirectory() as results_folder: + p = subprocess.run( + [ + "mango", + binary_path, + "--disable-progress", + "--results", + results_folder, + *args, + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + results_file = pathlib.Path(results_folder) / "cmdi_results.json" + try: + self.assertTrue(p.returncode == 0) # Ran Successfully + if "mango" in expected_results: + results = json.loads(results_file.read_text()) + + self.assertEqual(len(results["closures"]), len(expected_results["mango"])) + + for expected_result in expected_results["mango"]: + matching_res = next( + ( + x + for x in results["closures"] + if int(x["sink"]["ins_addr"], 16) + == expected_result["call_addr"] + ), + None, + ) + self.assertIsNotNone(matching_res) + self.assertEqual( + matching_res["sink"]["function"], expected_result["sink"] + ) + self.assertEqual(matching_res["depth"], expected_result["depth"]) + + if "execv" in expected_results: + exec_file = pathlib.Path(results_folder) / "execv.json" + results = json.loads(exec_file.read_text()) + for expected_result in expected_results["execv"]: + print(results) + self.assertTrue(expected_result["bin"] in results["execv"]) + + name = expected_result["bin"] + num_args = len(results["execv"][name][0]["args"]) + vuln_args = results["execv"][name][0]["vulnerable_args"] + self.assertEqual(num_args, expected_result["num_args"]) + self.assertListEqual(vuln_args, expected_result["vuln_args"]) + + except AssertionError as e: + print("FAIL") + print(p.stdout.decode()) + raise e + + def _run_env_analysis(self, name, bin_name, expected_results): + binary_path = self.PROJECT_ROOT / "tests" / "binaries" / name / bin_name + with tempfile.TemporaryDirectory() as results_folder: + p = subprocess.run( + [ + "env_resolve", + binary_path, + "--disable-progress", + "--results", + results_folder, + ], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + results_file = pathlib.Path(results_folder) / "env.json" + try: + self.assertTrue(p.returncode == 0) # Ran Successfully + results = json.loads(results_file.read_text()) + + for expected_result in expected_results["env"]: + matching_res = next( + ( + val_dict + for sink, val_dict in results["results"].items() + if sink == expected_result["sink"] + ), + None, + ) + self.assertIsNotNone(matching_res) + for key_dict in expected_result["keys"]: + for key, args_dict in key_dict.items(): + self.assertTrue(key in matching_res) + for arg, arg_dict in args_dict.items(): + self.assertTrue(arg in matching_res[key]) + self.assertTrue(arg_dict["value"] in matching_res[key][arg]) + self.assertTrue(hex(arg_dict["loc"]) in matching_res[key][arg][arg_dict["value"]]) + except AssertionError as e: + print("FAIL") + print(p.stdout.decode()) + raise e + + subprocess.run(["env_resolve", results_folder, "--merge", "--results", "/tmp/env.json"]) + + + def test_simple_binary(self): + name = "simple" + expected_results = { + "mango": [ + {"depth": 2, "sink": "system", "call_addr": 0x40115F}, + {"depth": 2, "sink": "execve", "call_addr": 0x401190}, + ] + } + self._run_analysis(name, expected_results) + + def test_looper_binary(self): + name = "looper" + expected_results = { + "mango": [{"depth": 1, "sink": "system", "call_addr": 0x401166}] + } + self._run_analysis(name, expected_results) + + def test_nested_binary(self): + 
name = "nested" + expected_results = { + "mango": [{"depth": 1, "sink": "system", "call_addr": 0x401165}] + } + self._run_analysis(name, expected_results) + + def test_sprintf_resolved_and_unresolved_binary(self): + name = "sprintf_resolved_and_unresolved" + expected_results = { + "mango": [{"depth": 2, "sink": "system", "call_addr": 0x401266}] + } + self._run_analysis(name, expected_results) + + def test_layered_binary(self): + name = "layered" + expected_results = { + "mango": [{"depth": 8, "sink": "system", "call_addr": 0x4012A3}] + } + self._run_analysis(name, expected_results, "--max-depth", "10") + + def test_off_shoot_binary(self): + name = "off_shoot" + expected_results = { + "mango": [{"depth": 2, "sink": "system", "call_addr": 0x4013E6}] + } + self._run_analysis(name, expected_results) + + def test_recursion(self): + name = "recursive" + expected_results = { + "mango": [{"depth": 2, "sink": "system", "call_addr": 0x4011D1}] * 2 + } + self._run_analysis(name, expected_results) + + def test_nvram(self): + name = "nvram" + expected_results = { + "mango": [{"depth": 1, "sink": "system", "call_addr": 0x40004C}], + "env": [{"sink": "acosNvramConfig_set", + "keys": [{"command1": + {"1": + {"value": "ls -la", + "loc": 0x401150} + } + }, + {"command2": + {"1": + {"value": "TOP", + "loc": 0x40116d} + } + }] + }] + } + + self._run_env_analysis(name, "keys", expected_results) + + env_file = "/tmp/env.json" + self._run_analysis(name, expected_results, "--env-dict", env_file) + os.unlink(env_file) + + def test_heap(self): + name = "heap" + expected_results = { + "mango": [{"depth": 1, "sink": "system", "call_addr": 0x401225}] + } + self._run_analysis(name, expected_results) + + def test_wrapper_funcs(self): + name = "wrapper" + expected_results = { + "mango": [{"depth": 2, "sink": "system", "call_addr": 0x4011A5}] * 2 + } + self._run_analysis(name, expected_results) + + def test_early_resolve(self): + name = "early_resolve" + expected_results = { + "mango": [ + {"depth": 1, "sink": "system", "call_addr": 0x4012FE}, + {"depth": 2, "sink": "system", "call_addr": 0x401200}, + ], + } + self._run_analysis(name, expected_results) + + def test_execve_resolve(self): + name = "execve" + expected_results = { + "execv": [{"bin": "other_prog", "num_args": 3, "vuln_args": [2]}] + } + self._run_analysis(name, expected_results) + + def test_execlp_resolve(self): + name = "execlp" + expected_results = { + "execv": [{"bin": "echo", "num_args": 3, "vuln_args": [2]}] + } + self._run_analysis(name, expected_results) diff --git a/package/tests/test_call_trace.py b/package/tests/test_call_trace.py new file mode 100644 index 0000000..06c3446 --- /dev/null +++ b/package/tests/test_call_trace.py @@ -0,0 +1,98 @@ +from unittest import TestCase + +import networkx + +from angr.analyses.reaching_definitions.call_trace import CallSite + +from argument_resolver.utils.call_trace import traces_to_sink + + +class MockFunction: + def __init__(self, _addr): + self.addr = _addr + + +class TestCallTrace(TestCase): + def test_traces_to_sink(self): + f0, f1, f2 = [MockFunction(i) for i in range(0, 3)] + sink = MockFunction(0x42) + + # Represent the following callgraph: + # 0 -> 1 -> 2 -> 0x42 + callgraph = networkx.MultiDiGraph( + [ + (f0.addr, f1.addr), + (f1.addr, f2.addr), + (f2.addr, sink.addr), + ] + ) + + traces = list(traces_to_sink(sink, callgraph, 3, {})) + expected_first_callsites = [ + CallSite(2, None, 0x42), + CallSite(1, None, 2), + CallSite(0, None, 1), + ] + + self.assertEqual(len(traces), 1) + + 
self.assertEqual(traces[0].target, 0x42) + self.assertListEqual(traces[0].callsites, expected_first_callsites) + + def test_traces_to_sink_recovers_everything_when_given_a_super_big_depth(self): + f0, f1, f2 = [MockFunction(i) for i in range(0, 3)] + sink = MockFunction(0x42) + + # Represent the following callgraph: + # 0 -> 1 -> 2 -> 0x42 + callgraph = networkx.MultiDiGraph( + [ + (f0.addr, f1.addr), + (f1.addr, f2.addr), + (f2.addr, sink.addr), + ] + ) + + traces = list(traces_to_sink(sink, callgraph, 9999, {})) + + expected_first_callsites = [ + CallSite(2, None, 0x42), + CallSite(1, None, 2), + CallSite(0, None, 1), + ] + + self.assertEqual(len(traces), 1) + + self.assertEqual(traces[0].target, 0x42) + self.assertListEqual(traces[0].callsites, expected_first_callsites) + + def test_traces_to_sink_recover_all_the_traces_flowing_into_a_sink(self): + f0, f1, f2, f3 = [MockFunction(i) for i in range(0, 4)] + sink = MockFunction(0x42) + + # Represent the following callgraph: + # 0 -> 1 -> 2 -> 0x42, 3 -> 0x42 + callgraph = networkx.MultiDiGraph( + [ + (f0.addr, f1.addr), + (f1.addr, f2.addr), + (f2.addr, sink.addr), + (f3.addr, sink.addr), + ] + ) + + traces = list(traces_to_sink(sink, callgraph, 3, {})) + traces.sort(key=lambda x: x.callsites[0].caller_func_addr) + + expected_first_callsites = [ + CallSite(2, None, 0x42), + CallSite(1, None, 2), + CallSite(0, None, 1), + ] + expected_second_callsites = [CallSite(3, None, 0x42)] + + self.assertEqual(len(traces), 2) + self.assertEqual(traces[0].target, 0x42) + self.assertListEqual(traces[0].callsites, expected_first_callsites) + self.assertEqual(traces[1].target, 0x42) + self.assertListEqual(traces[1].callsites, expected_second_callsites) diff --git a/package/tests/test_calling_convention.py b/package/tests/test_calling_convention.py new file mode 100644 index 0000000..b7b3ecc --- /dev/null +++ b/package/tests/test_calling_convention.py @@ -0,0 +1,312 @@ +import logging +import os + +from unittest import TestCase + +from angr.knowledge_plugins.key_definitions.atoms import Register +from angr.calling_conventions import SimRegArg, SimStackArg, SimCC +from angr.engines.light import SpOffset +from archinfo import ( + ArchX86, + ArchAMD64, + ArchARM, + ArchAArch64, + ArchMIPS32, + ArchMIPS64, + ArchPPC32, + ArchPPC64, +) + +from argument_resolver.utils.calling_convention import ( + # CallingConventionResolver, + cc_to_rd, + get_default_cc_with_args, +) + + +BINARIES_DIR = os.path.realpath( + os.path.join( + os.path.realpath(__file__), "..", "..", "..", "..", "binaries", "tests" + ) +) +LOGGER = logging.getLogger("argument_resolver/test_calling_convention") + + +class TestCallingConvention(TestCase): + def test_get_default_cc_with_args(self): + def run_all_tests_for(arch): + list( + map( + lambda x: run_test( + arch["arch"], x[0], x[1], arch["expected_return_value"] + ), + zip( + arch["number_of_parameters"], + arch["expected_args"], + ), + ) + ) + + def run_test(arch, number_of_parameters, expected_args, expected_return_value): + cc_with_args = get_default_cc_with_args(number_of_parameters, arch) + + computed_args = [] + int_args = cc_with_args.int_args + mem_args = cc_with_args.memory_args + while (x := next(int_args, None)) is not None and len( + computed_args + ) < number_of_parameters: + computed_args.append(str(x)) + for _ in range(number_of_parameters - len(computed_args)): + computed_args.append(str(next(mem_args))) + computed_return_value = str(cc_with_args.RETURN_VAL) + + self.assertEqual(expected_args, computed_args) + 
self.assertEqual(expected_return_value, computed_return_value)
+            self.assertTrue(isinstance(cc_with_args, SimCC))
+
+        # Expected results below are based on the returns of the implementation
+        # at commit 2c195cb: SimRegArg renders as "<reg_name>" and SimStackArg
+        # as "[offset]".
+        data = [
+            {
+                "arch": ArchX86(),
+                "number_of_parameters": [1, 3],
+                "expected_args": [
+                    ["[0x4]"],
+                    ["[0x4]", "[0x8]", "[0xc]"],
+                ],
+                "expected_return_value": "<eax>",
+            },
+            {
+                "arch": ArchAMD64(),
+                "number_of_parameters": [1, 9],
+                "expected_args": [
+                    ["<rdi>"],
+                    [
+                        "<rdi>",
+                        "<rsi>",
+                        "<rdx>",
+                        "<rcx>",
+                        "<r8>",
+                        "<r9>",
+                        "[0x8]",
+                        "[0x10]",
+                        "[0x18]",
+                    ],
+                ],
+                "expected_return_value": "<rax>",
+            },
+            {
+                "arch": ArchARM(),
+                "number_of_parameters": [1, 6],
+                "expected_args": [
+                    ["<r0>"],
+                    ["<r0>", "<r1>", "<r2>", "<r3>", "[0x0]", "[0x4]"],
+                ],
+                "expected_return_value": "<r0>",
+            },
+            {
+                "arch": ArchAArch64(),
+                "number_of_parameters": [1, 6],
+                "expected_args": [
+                    ["<x0>"],
+                    ["<x0>", "<x1>", "<x2>", "<x3>", "<x4>", "<x5>"],
+                ],
+                "expected_return_value": "<x0>",
+            },
+            {
+                "arch": ArchMIPS32(),
+                "number_of_parameters": [1, 6],
+                "expected_args": [
+                    ["<a0>"],
+                    ["<a0>", "<a1>", "<a2>", "<a3>", "[0x10]", "[0x14]"],
+                ],
+                "expected_return_value": "<v0>",
+            },
+            {
+                "arch": ArchMIPS64(),
+                "number_of_parameters": [1, 6],
+                "expected_args": [
+                    ["<a0>"],
+                    ["<a0>", "<a1>", "<a2>", "<a3>", "<a4>", "<a5>"],
+                ],
+                "expected_return_value": "<v0>",
+            },
+            {
+                "arch": ArchPPC32(),
+                "number_of_parameters": [1, 10],
+                "expected_args": [
+                    ["<r3>"],
+                    [
+                        "<r3>",
+                        "<r4>",
+                        "<r5>",
+                        "<r6>",
+                        "<r7>",
+                        "<r8>",
+                        "<r9>",
+                        "<r10>",
+                        "[0x8]",
+                        "[0xc]",
+                    ],
+                ],
+                "expected_return_value": "<r3>",
+            },
+            {
+                "arch": ArchPPC64(),
+                "number_of_parameters": [1, 10],
+                "expected_args": [
+                    ["<r3>"],
+                    [
+                        "<r3>",
+                        "<r4>",
+                        "<r5>",
+                        "<r6>",
+                        "<r7>",
+                        "<r8>",
+                        "<r9>",
+                        "<r10>",
+                        "[0x70]",
+                        "[0x78]",
+                    ],
+                ],
+                "expected_return_value": "<r3>",
+            },
+        ]
+
+        list(map(run_all_tests_for, data))
+
+    # def test_get_cc_with_known_external_function(self):
+    #     @mock.patch("argument_resolver.calling_convention.get_default_cc_with_args")
+    #     def run_test_for_external_function(
+    #         project,
+    #         function,
+    #         expected_number_of_arguments,
+    #         mock_get_default_cc_with_args,
+    #     ):
+    #         calling_convention_resolver = CallingConventionResolver(
+    #             project, project.arch, None
+    #         )
+    #         _ = calling_convention_resolver.get_cc(function)
+
+    #         # Just test proper delegation to `get_default_cc_with_args()` as it has been thoroughly tested
+    #         mock_get_default_cc_with_args.assert_called_once_with(
+    #             expected_number_of_arguments, project.arch
+    #         )
+
+    #     def run_test_for_arch(arch_and_binary):
+    #         functions = [
+    #             {"name": "system", "expected_number_of_arguments": 1},
+    #             {"name": "popen", "expected_number_of_arguments": 2},
+    #             {"name": "printf", "expected_number_of_arguments": 1},
+    #             {"name": "strcmp", "expected_number_of_arguments": 2},
+    #             {"name": "strncmp", "expected_number_of_arguments": 3},
+    #             {"name": "strcasecmp", "expected_number_of_arguments": 2},
+    #             {"name": "strncasecmp", "expected_number_of_arguments": 3},
+    #             {"name": "strcoll", "expected_number_of_arguments": 2},
+    #             {"name": "strcpy", "expected_number_of_arguments": 2},
+    #             {"name": "strncpy", "expected_number_of_arguments": 3},
+    #             {"name": "strcat", "expected_number_of_arguments": 2},
+    #             {"name": "strncat", "expected_number_of_arguments": 3},
+    #             {"name": "sprintf", "expected_number_of_arguments": 2},
+    #             {"name": "snprintf", "expected_number_of_arguments": 3},
+    #             {"name": "atoi", "expected_number_of_arguments": 1},
+    #             {"name": "nvram_set", "expected_number_of_arguments": 2},
+    #             {"name": "acosNvramConfig_set", "expected_number_of_arguments": 2},
+    #             {"name": "nvram_get",
"expected_number_of_arguments": 1}, + # {"name": "nvram_safe_get", "expected_number_of_arguments": 1}, + # {"name": "acosNvramConfig_get", "expected_number_of_arguments": 1}, + # {"name": "malloc", "expected_number_of_arguments": 1}, + # {"name": "calloc", "expected_number_of_arguments": 2}, + # {"name": "read", "expected_number_of_arguments": 3}, + # {"name": "fgets", "expected_number_of_arguments": 3}, + # ] + + # arch = arch_and_binary[0] + # binary = arch_and_binary[1] + + # # If arch is one of the PPC ones, do not specify it in the `Project` constructor, to avoid + # # https://github.com/angr/angr/issues/1553 . + # def is_ppc(arch): + # return arch.name.find("ppc") > -1 + + # project = Project(binary, arch=arch) if is_ppc(arch) else Project(binary) + + # list( + # map( + # lambda x: run_test_for_external_function( # pylint: disable=[no-value-for-parameter] + # project, x["name"], x["expected_number_of_arguments"] + # ), + # functions, + # ) + # ) + + # arches = [ + # ArchX86(), + # ArchAMD64(), + # ArchARM(), + # ArchAArch64(), + # ArchMIPS32(), + # ArchMIPS64(), + # ArchPPC32(), + # ArchPPC64(), + # ] + # binaries = list( + # map( + # lambda binary: os.path.join(BINARIES_DIR, binary), + # [ + # "i386/fauxware", + # "x86_64/fauxware", + # "android/arm/fauxware", + # "android/aarch64/fauxware", + # "mips/fauxware", + # "mips64/ld.so.1", + # "ppc/fauxware", + # "ppc64/fauxware", + # ], + # ) + # ) + # list(map(run_test_for_arch, zip(arches, binaries))) + + # @mock.patch.object(logging.Logger, "error") + # def test_get_cc_with_a_function_not_in_CFG(self, mock_Logger_error): + # MockFunctions = {} + + # arch = ArchX86() + # project = Project(os.path.join(BINARIES_DIR, "i386/fauxware"), arch=arch) + + # calling_convention_resolver = CallingConventionResolver( + # project, arch, MockFunctions # pylint: disable=[undefined-variable] + # ) + + # function_name = "unknown" + # _ = calling_convention_resolver.get_cc(function_name) + + # mock_Logger_error.assert_called_once_with( + # "CCA: Failed for %s(), function neither an external function nor have its name in CFG", + # function_name, + # ) + + def test_cc_to_rd_return_a_stack_pointer_offset_when_given_a_SimStackArg(self): + arch = ArchX86() + sim = SimStackArg(0x42, 1) + result = cc_to_rd(sim, arch) + + # See angr/angr/engines/light/data.py for `SpOffset` formatting + self.assertEqual(str(result.addr), "SP+0x42") + self.assertEqual(result.addr.__class__, SpOffset) + + def test_cc_to_rd_return_a_register_when_given_a_SimRegArg(self): + arch = ArchX86() + sim = SimRegArg("esp", 1) + result = cc_to_rd(sim, arch) + + # See angr/angr/analyses/reaching_definitions/atoms.py for `Register` formatting + self.assertEqual(str(result), ">") + self.assertEqual(result.__class__, Register) + + def test_cc_to_rd_with_a_parameter_of_the_wrong_type(self): + arch = ArchX86() + param = "This is a string so that won't work." 
+ + self.assertRaises(TypeError, cc_to_rd, param, arch) diff --git a/package/tests/test_handlers/handler_tester.py b/package/tests/test_handlers/handler_tester.py new file mode 100644 index 0000000..9821381 --- /dev/null +++ b/package/tests/test_handlers/handler_tester.py @@ -0,0 +1,79 @@ +import os +import subprocess +import tempfile + +from unittest import TestCase +from typing import Tuple + +from angr.project import Project +from angr.analyses.analysis import AnalysisFactory + +from argument_resolver.utils.rda import CustomRDA +from argument_resolver.utils.call_trace_visitor import CallTraceSubject +from argument_resolver.utils.call_trace import traces_to_sink + + +class HandlerTester(TestCase): + """ + Helper to test handlers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.folder = "" + self.source_path = "" + self.binary_path = "" + self.RDA = None + + def subject_from_function(self, project, function, depth=1): + traces = traces_to_sink(function, project.kb.functions.callgraph, depth, []) + assert len(traces) == 1 + + trace = traces.pop() + function_address = trace.current_function_address() + init_function = project.kb.functions[function_address] + return CallTraceSubject(trace, init_function) + + + def project_and_cfg_analysis_from(self, program: str) -> Project: + """ + Build an `angr.Project` for a program corresponding to a given C source; + Then run the `CFGFast` analysis on it. + """ + binary = self._compile(program) + project = Project(binary, auto_load_libs=False) + cfg = project.analyses.CFGFast(normalize=True, data_references=True) + project.analyses.CompleteCallingConventions(recover_variables=True, cfg=cfg) + self.RDA = AnalysisFactory(project, CustomRDA) + + return project + + def _compile(self, program: str) -> str: + """ + Compile a binary given a source. + + :param program: The program source code in C. + :return: The absolute path of the generated binary on the filesystem. + """ + self.folder = tempfile.TemporaryDirectory() + + self.source_path = os.path.join(self.folder.name, "program.c") + with open(self.source_path, "w", encoding="ascii") as source_file: + source_file.write(program) + + self.binary_path = os.path.join(self.folder.name, "program") + + subprocess.call( + [ + "gcc", + "-O0", + "-w", + "-fno-builtin", + "-fno-stack-protector", + self.source_path, + "-o", + self.binary_path, + ] + ) + + return self.binary_path diff --git a/package/tests/test_handlers/test_stdio.py b/package/tests/test_handlers/test_stdio.py new file mode 100644 index 0000000..1c41663 --- /dev/null +++ b/package/tests/test_handlers/test_stdio.py @@ -0,0 +1,105 @@ +from handler_tester import HandlerTester + +from angr.knowledge_plugins.key_definitions.constants import OP_AFTER +from angr.analyses.reaching_definitions.dep_graph import DepGraph + +from argument_resolver.handlers import handler_factory, StdioHandlers +from argument_resolver.utils.utils import Utils + +from archinfo import Endness + + +class TestStdioHandlers(HandlerTester): + TESTED_HANDLER = handler_factory([StdioHandlers]) + + def test_handle_sprintf(self): + string = "Hello World!" 
+ program = f""" + #include + void main() {{ + char greeting[0x40]; + sprintf(greeting, "Greeting: %s", "{string}"); + }} + """ + final_output = "Greeting: " + string + project = self.project_and_cfg_analysis_from(program) + + handler = self.TESTED_HANDLER(project, False) + + sprintf = project.kb.functions.function(name="sprintf") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, sprintf) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + results = handler.analyzed_list[-1].state + cc = project.analyses.CallingConvention(sprintf).cc + args = cc.int_args + + arg_dst = next(args) + + dst_values = Utils.get_values_from_cc_arg( + arg_dst, + results, + rda.project.arch, + ) + + printed_str = Utils.get_strings_from_pointers(dst_values, results, None) + + self.assertEqual( + Utils.bytes_from_int(printed_str.one_value()).decode("utf-8"), + final_output, + ) + + def test_handle_sprintf_unknown_string(self): + string = "Greeting: " + format_string = "%s" + program = f""" + #include + void main(int argc, char **argv) {{ + char greeting[0x40]; + sprintf(greeting, "{string + format_string}", argv[1]); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + handler = self.TESTED_HANDLER(project, False) + + sprintf = project.kb.functions.function(name="sprintf") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, sprintf) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + results = handler.analyzed_list[-1].state + cc = project.analyses.CallingConvention(sprintf).cc + args = cc.int_args + arg_dst = next(args) + + dst_values = Utils.get_values_from_cc_arg( + arg_dst, + results, + rda.project.arch, + ) + + printed_str = Utils.get_strings_from_pointers(dst_values, results, None) + + self.assertEqual( + Utils.bytes_from_int( + printed_str.one_value()[: results.arch.bytes * 8] + ).decode("utf-8"), + string, + ) + self.assertTrue( + results.is_top(printed_str.one_value()[(results.arch.bytes * 8) - 1 :]) + ) diff --git a/package/tests/test_handlers/test_stdlib.py b/package/tests/test_handlers/test_stdlib.py new file mode 100644 index 0000000..d039d72 --- /dev/null +++ b/package/tests/test_handlers/test_stdlib.py @@ -0,0 +1,123 @@ +from handler_tester import HandlerTester + +from angr.knowledge_plugins.key_definitions.constants import OP_AFTER +from angr.knowledge_plugins.key_definitions.atoms import Register +from angr.analyses.reaching_definitions.dep_graph import DepGraph +from argument_resolver.handlers import handler_factory, StdlibHandlers, StringHandlers +from argument_resolver.utils.utils import Utils + + +class TestStdlibHandlers(HandlerTester): + TESTED_HANDLER = handler_factory([StdlibHandlers, StringHandlers]) + + def test_handle_malloc(self): + program = """ + #include + void main() { + char *buf = (char *) malloc(0x40); + char *buf2 = (char *) malloc(0x40); + } + """ + Utils.ALL_CALLSITES = [] + project = self.project_and_cfg_analysis_from(program) + + strcpy = project.kb.functions.function(name="malloc") + subject = self.subject_from_function(project, strcpy) + handler = self.TESTED_HANDLER(project, strcpy, [Register(*project.arch.registers["rdi"])]) + handler.assumed_execution = False + + rda = self.RDA( + subject=subject, + function_handler=handler, + dep_graph=DepGraph(), + 
+            observation_points=set()
+        )
+
+        malloc1, malloc2 = [x for x in handler.analyzed_list if x.function.name == 'malloc']
+
+        self.assertEqual(Utils.get_heap_offset(malloc1.ret_val.one_value()), 0x0)
+        self.assertEqual(Utils.get_heap_offset(malloc2.ret_val.one_value()), 0x40)
+
+    def test_handle_calloc(self):
+        nmemb = 0x4
+        size = 0x20
+        program = f"""
+        #include <stdlib.h>
+        void main() {{
+            char *buf = (char *) calloc({nmemb}, {size});
+            char *buf2 = (char *) calloc({nmemb}, {size});
+        }}
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        calloc = project.kb.functions.function(name="calloc")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, calloc)
+        handler = self.TESTED_HANDLER(project, calloc, [Register(*project.arch.registers["rdi"])])
+        handler.assumed_execution = False
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        calling_convention = project.analyses.CallingConvention(calloc).cc
+
+        return_values = Utils.get_values_from_cc_arg(
+            calling_convention.RETURN_VAL,
+            handler.analyzed_list[-1].state,
+            rda.project.arch,
+        )
+
+        pointer = return_values.one_value()._model_concrete.value
+        state = handler.analyzed_list[-1].state
+        zeroed_memory = state.heap.load(
+            pointer, nmemb * size, endness=state.arch.memory_endness
+        )
+
+        self.assertEqual(pointer, nmemb * size)
+        self.assertEqual(zeroed_memory.one_value().size(), nmemb * size * 8)
+        self.assertEqual(zeroed_memory.one_value()._model_concrete.value, 0x0)
+
+    def test_handle_env(self):
+        env_var = "greeting"
+        env_val = "Hello World!"
+        program = f"""
+        #include <stdlib.h>
+        void main() {{
+            setenv("{env_var}", "{env_val}", 0);
+            getenv("{env_var}");
+        }}
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        getenv = project.kb.functions["getenv"]
+
+        subject = self.subject_from_function(project, getenv)
+        observation_points = set(Utils.get_all_callsites(project))
+        handler = self.TESTED_HANDLER(project, getenv, [Register(*project.arch.registers["rdi"])])
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        calling_convention = project.analyses.CallingConvention(getenv).cc
+
+        state = handler.analyzed_list[-1].state
+        return_values = Utils.get_values_from_cc_arg(
+            calling_convention.RETURN_VAL,
+            state,
+            rda.project.arch,
+        )
+
+        env_string = state.heap.load(state.get_heap_offset(return_values.one_value()), len(env_val))
+
+        self.assertEqual(
+            Utils.bytes_from_int(env_string.one_value()).decode("utf-8"), env_val
+        )
diff --git a/package/tests/test_handlers/test_string.py b/package/tests/test_handlers/test_string.py
new file mode 100644
index 0000000..7162c12
--- /dev/null
+++ b/package/tests/test_handlers/test_string.py
@@ -0,0 +1,468 @@
+from handler_tester import HandlerTester
+
+from angr.analyses.reaching_definitions.dep_graph import DepGraph
+from angr.knowledge_plugins.key_definitions.atoms import Register
+from angr.knowledge_plugins.key_definitions.constants import OP_AFTER
+from angr.knowledge_plugins.key_definitions.live_definitions import DerefSize
+from argument_resolver.handlers import handler_factory, StringHandlers
+from argument_resolver.utils.utils import Utils
+
+
+class TestStringHandlers(HandlerTester):
+    TESTED_HANDLER = handler_factory([StringHandlers])
+
+    def test_handle_strlen(self):
+        string = "Hello World!"
+ program = f""" + #include + void main() {{ + int i = strlen("{string}"); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + strlen = project.kb.functions.function(name="strlen") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strlen) + handler = self.TESTED_HANDLER(project, strlen, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + calling_convention = project.analyses.CallingConvention(strlen).cc + + state = handler.analyzed_list[-1].state + return_values = Utils.get_values_from_cc_arg( + calling_convention.RETURN_VAL, + state, + rda.project.arch, + ) + self.assertEqual(return_values.one_value()._model_concrete.value, len(string)) + + def test_handle_strlen_unknown_size(self): + program = f""" + #include + void main(int argc, char **argv) {{ + int i = strlen(argv[1]); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + + strlen = project.kb.functions.function(name="strlen") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strlen) + handler = self.TESTED_HANDLER(project, strlen, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + calling_convention = project.analyses.CallingConvention(strlen).cc + + state = handler.analyzed_list[-1].state + + return_values = Utils.get_values_from_cc_arg( + calling_convention.RETURN_VAL, + state, + rda.project.arch, + ) + self.assertTrue(state.is_top(return_values.one_value())) + + def test_handle_strcat(self): + string1 = "Hello" + string2 = " World!" 
+ program = f""" + #include + void main() {{ + char buf1[0x40] = {{"{string1}"}}; + char buf2[0x40] = {{"{string2}"}}; + strcat(buf1, buf2); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + + strcat = project.kb.functions.function(name="strcat") + subject = self.subject_from_function(project, strcat) + observation_points = set(Utils.get_all_callsites(project)) + handler = self.TESTED_HANDLER(project, strcat, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + stored_func = handler.analyzed_list[-1] + + found_str = Utils.get_strings_from_pointers(stored_func.ret_val, stored_func.state, None) + + self.assertEqual( + Utils.bytes_from_int(found_str.one_value()).decode("utf-8").replace("\x00", ""), + string1 + string2, + ) + + def test_handle_strcat_unknown_value(self): + string = "Hello" + program = f""" + #include + void main(int argc, char **argv) {{ + char buf[40] = {{"{string}"}}; + strcat(buf, argv[1]); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + + strcat = project.kb.functions.function(name="strcat") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strcat) + handler = self.TESTED_HANDLER(project, strcat, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + stored_func = handler.analyzed_list[-1] + + atom = stored_func.state.deref(stored_func.ret_val.one_value(), DerefSize.NULL_TERMINATE) + concat_str = stored_func.state.get_one_value(atom) + + self.assertEqual(Utils.bytes_from_int(concat_str[:(1 + project.arch.bytes) * 8]).decode("utf-8"), string) + self.assertTrue(stored_func.state.is_top(concat_str[((1 + project.arch.bytes)*8) - 1:])) + + def test_handle_strcpy(self): + program = """ + #include + void main() { + char s[12]; + strcpy(s, "Hello World!"); + } + """ + project = self.project_and_cfg_analysis_from(program) + + + strcpy = project.kb.functions.function(name="strcpy") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strcpy) + handler = self.TESTED_HANDLER(project, strcpy, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + calling_convention = project.analyses.CallingConvention(strcpy).cc + + state = handler.analyzed_list[-1].state + + return_values = Utils.get_values_from_cc_arg( + calling_convention.RETURN_VAL, state, rda.project.arch + ) + resulting_string_bytes = Utils.get_strings_from_pointers(return_values, state, None) + data = Utils.bytes_from_int(resulting_string_bytes.one_value()) + self.assertEqual(data.decode("utf-8"), "Hello World!") + + def test_handle_strcpy_unknown_value(self): + program = """ + #include + void main(int argc, char **argv) { + char s[12]; + strcpy(s, argv[1]); + } + """ + project = self.project_and_cfg_analysis_from(program) + + + strcpy = project.kb.functions.function(name="strcpy") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strcpy) + handler = self.TESTED_HANDLER(project, strcpy, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + 
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        calling_convention = project.analyses.CallingConvention(strcpy).cc
+
+        state = handler.analyzed_list[-1].state
+
+        return_values = Utils.get_values_from_cc_arg(
+            calling_convention.RETURN_VAL, state, rda.project.arch
+        )
+
+        resulting_string_bytes = Utils.get_strings_from_pointers(return_values, state, None)
+
+        self.assertTrue(state.is_top(resulting_string_bytes.one_value()))
+
+    def test_handle_strncpy(self):
+        program = """
+        #include <string.h>
+        void main() {
+            char s[5];
+            strncpy(s, "Hello World!", 5);
+        }
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        strncpy = project.kb.functions.function(name="strncpy")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, strncpy)
+        handler = self.TESTED_HANDLER(project, strncpy, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        calling_convention = project.analyses.CallingConvention(strncpy).cc
+
+        state = handler.analyzed_list[-1].state
+
+        return_values = Utils.get_values_from_cc_arg(
+            calling_convention.RETURN_VAL, state, rda.project.arch
+        )
+
+        resulting_string = Utils.get_strings_from_pointers(return_values, state, None)
+
+        # We handle strncpy as strcpy
+        self.assertEqual(Utils.bytes_from_int(resulting_string.one_value()).decode("utf-8"), "Hello World!")
+
+    def test_handle_strncpy_unknown_value(self):
+        program = """
+        #include <string.h>
+        void main(int argc, char **argv) {
+            char s[5];
+            strncpy(s, argv[1], 5);
+        }
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        strncpy = project.kb.functions.function(name="strncpy")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, strncpy)
+        handler = self.TESTED_HANDLER(project, strncpy, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        calling_convention = project.analyses.CallingConvention(strncpy).cc
+
+        state = handler.analyzed_list[-1].state
+        return_values = Utils.get_values_from_cc_arg(
+            calling_convention.RETURN_VAL, state, rda.project.arch
+        )
+
+        resulting_string = Utils.get_strings_from_pointers(return_values, state, None)
+
+        self.assertTrue(state.is_top(resulting_string.one_value()))
+
+    def test_handle_atoi(self):
+        string = "42"
+        program = f"""
+        #include <stdlib.h>
+        void main() {{
+            char *s = "{string}";
+            int i = atoi(s);
+        }}
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        atoi = project.kb.functions.function(name="atoi")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, atoi)
+        handler = self.TESTED_HANDLER(project, atoi, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+        stored_func = handler.analyzed_list[-1]
+
+        self.assertEqual(stored_func.ret_val.one_value()._model_concrete.value, int(string))
+
+    def test_handle_atoi_unknown_value(self):
+        program = """
+        #include <stdlib.h>
+        void main(int argc, char **argv) {
+            int i = atoi(argv[1]);
+        }
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        atoi = project.kb.functions.function(name="atoi")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, atoi)
+        handler = self.TESTED_HANDLER(project, atoi, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        stored_func = handler.analyzed_list[-1]
+        self.assertTrue(stored_func.state.is_top(stored_func.ret_val.one_value()))
+
+    def test_handle_memcpy(self):
+        string = "Hello World!"
+        program = f"""
+        #include <string.h>
+        void main() {{
+            char s[12];
+            memcpy(s, "{string}", 12);
+        }}
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        memcpy = project.kb.functions.function(name="memcpy")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, memcpy)
+        handler = self.TESTED_HANDLER(project, memcpy, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        stored_func = handler.analyzed_list[-1]
+        memory_value = Utils.get_strings_from_pointers(stored_func.ret_val, stored_func.state, None)
+
+        self.assertEqual(
+            Utils.bytes_from_int(memory_value.one_value()).decode("utf-8"),
+            string,
+        )
+
+    def test_handle_memcpy_unknown_value(self):
+        program = """
+        #include <string.h>
+        void main(int argc, char **argv) {
+            char s[12];
+            memcpy(s, argv[1], 12);
+        }
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        memcpy = project.kb.functions.function(name="memcpy")
+        observation_point = ("node", memcpy.addr, OP_AFTER)
+        subject = self.subject_from_function(project, memcpy)
+        handler = self.TESTED_HANDLER(project, memcpy, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points={observation_point},
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        stored_func = handler.analyzed_list[-1]
+        memory_value = Utils.get_strings_from_pointers(stored_func.ret_val, stored_func.state, None)
+
+        self.assertTrue(stored_func.state.is_top(memory_value.one_value()))
+
+    def test_handle_memset(self):
+        byte = 0
+        size = 10
+        program = f"""
+        #include <string.h>
+        void main() {{
+            char s[10];
+            memset(s, {byte}, {size});
+        }}
+        """
+        project = self.project_and_cfg_analysis_from(program)
+
+
+        memset = project.kb.functions.function(name="memset")
+        observation_points = set(Utils.get_all_callsites(project))
+        subject = self.subject_from_function(project, memset)
+        handler = self.TESTED_HANDLER(project, memset, [Register(*project.arch.registers["rdi"])])
+
+        rda = self.RDA(
+            subject=subject,
+            observation_points=observation_points,
+            function_handler=handler,
+            dep_graph=DepGraph(),
+        )
+
+        stored_func = handler.analyzed_list[-1]
+        memory_value = stored_func.state.stack.load(
+            stored_func.state.get_stack_address(stored_func.ret_val.one_value()), size
+        )
+
+        self.assertEqual(
+            Utils.bytes_from_int(memory_value.one_value()).decode("utf-8"),
+            chr(byte) * size,
+        )
+
+    def test_handle_strdup(self):
+        test_string = "Hello World!"
+ program = f""" + #include + void main() {{ + char *s = strdup("{test_string}"); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + + strdup = project.kb.functions.function(name="strdup") + observation_points = set(Utils.get_all_callsites(project)) + subject = self.subject_from_function(project, strdup) + handler = self.TESTED_HANDLER(project, strdup, [Register(*project.arch.registers["rdi"])]) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + calling_convention = project.analyses.CallingConvention(strdup).cc + + state = handler.analyzed_list[-1].state + + return_values = Utils.get_values_from_cc_arg( + calling_convention.RETURN_VAL, state, rda.project.arch + ) + + duplicated_string = Utils.get_strings_from_pointers(return_values, state, None) + + self.assertEqual( + Utils.bytes_from_int(duplicated_string.one_value()).decode("utf-8"), + test_string, + ) diff --git a/package/tests/test_handlers/test_unistd.py b/package/tests/test_handlers/test_unistd.py new file mode 100644 index 0000000..ad20397 --- /dev/null +++ b/package/tests/test_handlers/test_unistd.py @@ -0,0 +1,55 @@ +from handler_tester import HandlerTester + +from angr.knowledge_plugins.key_definitions.atoms import Register +from angr.analyses.reaching_definitions.dep_graph import DepGraph + +from argument_resolver.handlers import handler_factory, UnistdHandlers +from argument_resolver.utils.utils import Utils + + +class TestStdioHandlers(HandlerTester): + TESTED_HANDLER = handler_factory([UnistdHandlers]) + + def test_handle_read(self): + read_size = 0x40 + program = f""" + #include + void main() {{ + char greeting[{read_size}]; + read(0, greeting, {read_size}); + }} + """ + project = self.project_and_cfg_analysis_from(program) + + read = project.kb.functions.function(name="read") + subject = self.subject_from_function(project, read) + observation_points = set(Utils.get_all_callsites(project)) + + handler = self.TESTED_HANDLER( + project, read, [Register(*project.arch.registers["rsi"])] + ) + + rda = self.RDA( + subject=subject, + observation_points=observation_points, + function_handler=handler, + dep_graph=DepGraph(), + ) + + cc = project.analyses.CallingConvention(read).cc + args = cc.int_args + next(args) + arg_dst = next(args) + + state = handler.analyzed_list[-1].state + dst_ptrs = Utils.get_values_from_cc_arg( + arg_dst, + state, + rda.project.arch, + ) + + printed_str = Utils.get_strings_from_pointers( + dst_ptrs, state, state.codeloc + ).one_value() + self.assertTrue(state.is_top(printed_str)) + self.assertTrue(printed_str.size() == 8 * self.TESTED_HANDLER.MAX_READ_SIZE) diff --git a/package/tests/test_sink.py b/package/tests/test_sink.py new file mode 100644 index 0000000..ca41437 --- /dev/null +++ b/package/tests/test_sink.py @@ -0,0 +1,31 @@ +from unittest import mock, TestCase + +from angr.sim_type import SimTypeFunction, SimTypeInt + +from argument_resolver.external_function import VULN_TYPES, Sink + + +class TestSink(TestCase): + MOCK_LIBRARIES = { + "a_sink": SimTypeFunction([SimTypeInt()], SimTypeInt(), arg_names=["key"]), + } + + @mock.patch( + "argument_resolver.external_function.CUSTOM_DECLS", + MOCK_LIBRARIES, + ) + def test_expose_list_of_command_injection_sinks(self): + for f in VULN_TYPES["cmdi"]: + self.assertEqual(type(f), Sink) + + @mock.patch( + "argument_resolver.external_function.CUSTOM_DECLS", + MOCK_LIBRARIES, + ) + def 
test_a_sink_has_a_dictionary_of_vulnerable_parameters_specifying_their_positions_and_type( + self, + ): + sink = Sink("a_sink", [1]) + vulnerable_parameters = [1] + + self.assertListEqual(sink.vulnerable_parameters, vulnerable_parameters) diff --git a/package/tests/test_transitive_closure.py b/package/tests/test_transitive_closure.py new file mode 100644 index 0000000..838e800 --- /dev/null +++ b/package/tests/test_transitive_closure.py @@ -0,0 +1,396 @@ +import claripy +import logging +import networkx + +from unittest import TestCase + +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues +from angr.knowledge_plugins.key_definitions.definition import Definition +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, Register +from angr.code_location import CodeLocation +from angr.analyses.reaching_definitions.reaching_definitions import LiveDefinitions +from angr.code_location import ExternalCodeLocation +from angr.analyses.reaching_definitions.dep_graph import DepGraph + +from archinfo import ArchAMD64 +from argument_resolver.utils.transitive_closure import ( + contains_an_external_definition, + represents_constant_data, +) +from argument_resolver.utils.closure import Closure + +LOGGER = logging.getLogger("argument_resolver/test_utils") + + +def _init_reach_def(): + arch = ArchAMD64() + reach_def = LiveDefinitions(arch=arch) + + sp = Register(arch.sp_offset, arch.bytes) + sp_offset = reach_def.stack_address(arch.sp_offset) + + reach_def.registers.store(sp.reg_offset, sp_offset, sp.size) + return reach_def + + +class TestTransitiveClosure(TestCase): + STRING_IN_MEMORY = "some string of data in memory" + STRING_IN_MEMORY_LENGTH = len(STRING_IN_MEMORY + "\x00") + + class ArchMock: + def __init__(self): + pass + + @property + def bits(self): + return 4 + + class CFGMock: + def __init__(self, memory_data): + self._memory_data = memory_data + + @property + def memory_data(self): + return self._memory_data + + class MemoryDataMock: + def __init__(self, address, content, size, sort): + self._address = address + self._content = content + self._size = size + self._sort = sort + + @property + def address(self): + return self._address + + @property + def content(self): + return self._content + + @property + def size(self): + return self._size + + @property + def sort(self): + return self._sort + + def test_contains_an_external_definition_return_false_when_all_definitions_are_local( + self, + ): + local_definitions = list( + map(lambda i: Definition(Register(i * 4, 4), CodeLocation(i, 0)), range(4)) + ) + + # Create the following dependency graph: + # R0 -> R1 -> R2 -> R3 + dependencies_graph = networkx.DiGraph( + [ + (local_definitions[0], local_definitions[1]), + (local_definitions[1], local_definitions[2]), + (local_definitions[2], local_definitions[3]), + ] + ) + + class A: + dep_graph = DepGraph(dependencies_graph) + + transitive_closures = {0: {Closure(None, A(), None)}} + + self.assertFalse(contains_an_external_definition(transitive_closures)) + + def test_contains_an_external_definition_return_true_when_at_least_one_definition_is_external( + self, + ): + external_definition = Definition(Register(0, 4), ExternalCodeLocation()) + local_definitions = list( + map( + lambda i: Definition(Register(i * 4, 4), CodeLocation(i, 0)), + range(1, 4), + ) + ) + + # Create the following dependency graph: + # R0 (external) -> R1 -> R2 -> R3 + dependencies_graph = networkx.DiGraph( + [ + (external_definition, local_definitions[0]), + (local_definitions[0], 
local_definitions[1]), + (local_definitions[1], local_definitions[2]), + ] + ) + + class A: + dep_graph = DepGraph(dependencies_graph) + transitive_closures = {0: {Closure(None, A(), None)}} + + self.assertTrue(contains_an_external_definition(transitive_closures)) + + #def test_represents_constant_data_fails_if_definition_is_not_in_dependency_graph( + # self, + #): + # reach_def = _init_reach_def() + # reg = Register(0, 4) + # codeloc = CodeLocation(0, 0) + # definition = Definition(reg, codeloc) + # values = MultiValues(offset_to_values={0: {claripy.BVV(0, 4 * 8)}}) + # reach_def.kill_and_add_definition(reg, codeloc, values) + + # dependency_graph = networkx.DiGraph() + + # with self.assertRaises(AssertionError) as cm: + # represents_constant_data(definition, values, reach_def, dependency_graph) + + # ex = cm.exception + # self.assertEqual( + # str(ex), "The given Definition must be present in the given graph." + # ) + + def test_represents_constant_data_returns_True_if_definition_is_a_memory_location_and_its_data_is_a_string( + self, + ): + reach_def = _init_reach_def() + memloc = MemoryLocation(0x42, len(self.STRING_IN_MEMORY)) + codeloc = CodeLocation(0, 0) + definition = Definition(memloc, codeloc) + + values = MultiValues( + offset_to_values={ + 0: {claripy.BVV(self.STRING_IN_MEMORY, len(self.STRING_IN_MEMORY) * 8)} + } + ) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(definition) + + self.assertTrue( + represents_constant_data(definition, values, reach_def) + ) + + def test_represents_constant_data_returns_False_if_definition_is_a_memory_location_and_its_data_contains_undefined( + self, + ): + reach_def = _init_reach_def() + memloc = MemoryLocation(0x42, len(self.STRING_IN_MEMORY)) + codeloc = CodeLocation(0, 0) + definition = Definition(memloc, codeloc) + values = MultiValues( + offset_to_values={ + 0: { + claripy.BVV(self.STRING_IN_MEMORY, len(self.STRING_IN_MEMORY) * 8), + claripy.BVS( + "TOP", len(self.STRING_IN_MEMORY) * 8, explicit_name=True + ), + } + } + ) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(definition) + self.assertFalse( + represents_constant_data(definition, values, reach_def) + ) + + def test_represents_constant_data_returns_True_if_definition_is_a_memory_location_and_its_data_is_an_integer( + self, + ): + reach_def = _init_reach_def() + memloc = MemoryLocation(0x42, 1) + codeloc = CodeLocation(0, 0) + definition = Definition(memloc, codeloc) + + values = MultiValues(offset_to_values={0: {claripy.BVV(0x84, 8)}}) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(definition) + + self.assertTrue( + represents_constant_data(definition, values, reach_def) + ) + + def test_represents_constant_data_returns_True_if_definition_is_a_memory_location_and_its_data_is_a_concat_completely_resolved( + self, + ): + reach_def = _init_reach_def() + memloc = MemoryLocation(0x42, 8) + codeloc = CodeLocation(0, 0) + definition = Definition(memloc, codeloc) + + values = MultiValues( + offset_to_values={ + 0: {claripy.BVV("cons", 8 * 4)}, + 4: {claripy.BVV("tant", 8 * 4)}, + } + ) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(definition) + self.assertTrue( + represents_constant_data(definition, values, reach_def) + ) + + def 
test_represents_constant_data_returns_False_if_definition_is_a_memory_location_and_its_data_is_a_concat_containing_undefined( + self, + ): + reach_def = _init_reach_def() + memloc = MemoryLocation(0x42, 9) + codeloc = CodeLocation(0, 0) + definition = Definition(memloc, codeloc) + + values = MultiValues( + offset_to_values={ + 0: {reach_def.top(8)}, + 1: {claripy.BVV("constant", 8 * 8)}, + } + ) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(definition) + + self.assertFalse( + represents_constant_data(definition, values, reach_def) + ) + + def test_represents_constant_data_returns_True_when_it_is_a_register_that_is_a_memory_address_pointing_to_a_constant_string( + self, + ): + reach_def = _init_reach_def() + mem_address = 0x42 + memloc = MemoryLocation(mem_address, 8) + codeloc = CodeLocation(0, 0) + mem_loc_definition = Definition(memloc, codeloc) + + values = MultiValues(claripy.BVV("constant", 8 * 8)) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + reg = Register(0, reach_def.arch.bits // 8) + codeloc = CodeLocation(10, 0) + reg_definition = Definition(reg, codeloc) + + values = MultiValues(claripy.BVV(mem_address, reach_def.arch.bits)) + reach_def.kill_and_add_definition(reg, codeloc, values) + + dependency_graph = networkx.DiGraph([(mem_loc_definition, reg_definition)]) + + self.assertTrue( + represents_constant_data( + reg_definition, values, reach_def + ) + ) + + def test_represents_constant_data_returns_True_when_it_is_a_register_that_is_a_stack_offset_pointing_to_a_constant_string( + self, + ): + reach_def = _init_reach_def() + sp_offset = reach_def.stack_address(0x8) + sp_offset_loc = MemoryLocation(sp_offset, len(self.STRING_IN_MEMORY)) + codeloc = CodeLocation(0, 0) + sp_offset_definition = Definition(sp_offset_loc, codeloc) + + values = MultiValues( + offset_to_values={ + 0: {claripy.BVV(self.STRING_IN_MEMORY, len(self.STRING_IN_MEMORY) * 8)} + } + ) + reach_def.kill_and_add_definition(sp_offset_loc, codeloc, values) + + reg = Register(0, reach_def.arch.bytes) + codeloc = CodeLocation(10, 0) + register_definition = Definition(reg, codeloc) + + values = MultiValues(offset_to_values={0: {sp_offset}}) + reach_def.kill_and_add_definition(reg, codeloc, values) + + dependency_graph = networkx.DiGraph( + [(sp_offset_definition, register_definition)] + ) + + self.assertTrue( + represents_constant_data( + register_definition, values, reach_def + ) + ) + + def test_represents_constant_data_returns_False_when_it_is_a_register_taking_at_least_an_unknown_value( + self, + ): + reach_def = _init_reach_def() + memory_address = 0x42 + memloc = MemoryLocation(memory_address, len(self.STRING_IN_MEMORY)) + codeloc = CodeLocation(0, 0) + memory_location_definition = Definition(memloc, codeloc) + + values = MultiValues( + offset_to_values={ + 0: {claripy.BVV(self.STRING_IN_MEMORY, len(self.STRING_IN_MEMORY) * 8)} + } + ) + reach_def.kill_and_add_definition(memloc, codeloc, values) + + reg = Register(0, reach_def.arch.bytes) + codeloc = CodeLocation(10, 0) + register_definition = Definition(reg, codeloc) + + values = MultiValues( + offset_to_values={ + 0: { + claripy.BVV(memory_address, reach_def.arch.bits), + reach_def.top(reach_def.arch.bits), + } + } + ) + reach_def.kill_and_add_definition(reg, codeloc, values) + + self.assertFalse( + represents_constant_data( + register_definition, values, reach_def + ) + ) + + def 
test_represents_constant_data_returns_False_when_it_is_a_register_that_can_take_memory_address_not_defined_earlier( + self, + ): + reach_def = _init_reach_def() + reg = Register(0, reach_def.arch.bytes) + codeloc = CodeLocation(10, 0) + register_definition = Definition(reg, codeloc) + + values = MultiValues( + offset_to_values={0: {claripy.BVV(0xBEEF, reach_def.arch.bits)}} + ) + reach_def.kill_and_add_definition(reg, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(register_definition) + + self.assertFalse( + represents_constant_data( + register_definition, values, reach_def + ) + ) + + def test_represents_constant_data_returns_False_when_it_is_a_register_that_can_take_sp_offset_not_defined_earlier( + self, + ): + reach_def = _init_reach_def() + reg = Register(0, reach_def.arch.bytes) + codeloc = CodeLocation(10, 0) + register_definition = Definition(reg, codeloc) + + values = MultiValues(offset_to_values={0: {reach_def.stack_address(0x4)}}) + reach_def.kill_and_add_definition(reg, codeloc, values) + + dependency_graph = networkx.DiGraph() + dependency_graph.add_node(register_definition) + + self.assertFalse( + represents_constant_data( + register_definition, values, reach_def + ) + ) diff --git a/package/tests/test_utils.py b/package/tests/test_utils.py new file mode 100644 index 0000000..138f280 --- /dev/null +++ b/package/tests/test_utils.py @@ -0,0 +1,212 @@ +import claripy +import logging + +from unittest import TestCase + +from archinfo import ArchAMD64 +from angr.calling_conventions import SimStackArg + +from angr.storage.memory_mixins.paged_memory.pages.multi_values import MultiValues +from angr.knowledge_plugins.key_definitions.atoms import MemoryLocation, Register +from angr.code_location import CodeLocation +from angr.analyses.reaching_definitions.reaching_definitions import LiveDefinitions + +from argument_resolver.utils.utils import Utils + +LOGGER = logging.getLogger("argument_resolver/test_utils") + + +def _write(string, memory, at_address): + i = 0 + for char in string: + memory[at_address + i] = char + i += 1 + + +class TestUtils(TestCase): + def test_get_values_from_cc_arg_with_sp_offset(self): + arch = ArchAMD64() + reach_def = LiveDefinitions(arch=arch) + + sp = Register(arch.sp_offset, arch.bytes) + sp_offset = reach_def.stack_address(arch.sp_offset) + + # Assume argument is located 4 words away from the stack pointer, + # and has some arbitrary value. + arg_stack_offset = 4 + arbitrary_value = 0 + + # Segment of memory to represent a portion of the stack containing our parameter. + ml = MemoryLocation(sp_offset + arg_stack_offset, arch.bytes) + ml_value = MultiValues( + offset_to_values={0: {claripy.BVV(arbitrary_value, arch.bits)}} + ) + + # We need to setup two components in the `LiveDefinitions`: + # - the `sp` register value to make offset computations possible; + # - the portion of the stack containing our parameter. 
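+        # (The SimStackArg lookup below resolves its offset against this stored sp value.)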
+ reach_def.registers.store(sp.reg_offset, sp_offset, sp.size) + reach_def.stack.store( + reach_def.get_stack_address(ml.addr), ml_value, ml.size + ) + + definitions = Utils.get_values_from_cc_arg( + SimStackArg(arg_stack_offset, arch.bytes), reach_def, arch + ) + + self.assertEqual(definitions.one_value()._model_concrete.value, 0x0) + + def test_get_values_from_cc_arg_with_invalid_first_arg(self): + with self.assertRaises(TypeError) as cm: + Utils.get_values_from_cc_arg([], None, None) + + ex = cm.exception + self.assertEqual(str(ex), "Expected SimRegArg or SimStackArg, got list") + + def test_get_prototypes_from_format_string(self): + prototypes = Utils.get_prototypes_from_format_string( + "foo: %s, bar: %#x, baz: %10i" + ) + + self.assertEqual(prototypes[0].prototype, "%s") + self.assertEqual(prototypes[0].specifier, "s") + self.assertEqual(prototypes[0].position, 5) + + self.assertEqual(prototypes[1].prototype, "%#x") + self.assertEqual(prototypes[1].specifier, "x") + self.assertEqual(prototypes[1].position, 14) + + self.assertEqual(prototypes[2].prototype, "%10i") + self.assertEqual(prototypes[2].specifier, "i") + self.assertEqual(prototypes[2].position, 24) + + def test_is_stack_address(self): + base = claripy.BVS("stack_base", 64, explicit_name=True) + offset = 0xBEEF + base += offset + + not_base = claripy.BVS("TOP", 64, explicit_name=True) + + self.assertTrue(Utils.is_stack_address(base)) + self.assertFalse(Utils.is_stack_address(not_base)) + + def test_bytes_from_int(self): + string = b"Hello World!" + byte_string = claripy.BVV(string, len(string) * 8) + + result = Utils.bytes_from_int(byte_string) + self.assertTrue(result == string) + + def test_get_strings_from_pointer_concrete_memory_address(self): + arch = ArchAMD64() + reach_def = LiveDefinitions(arch=arch) + string = b"Hello World!" + + sp = Register(arch.sp_offset, arch.bytes) + sp_offset = reach_def.stack_address(arch.sp_offset) + + reach_def.registers.store(sp.reg_offset, sp_offset, sp.size) + + mem_loc = MemoryLocation(claripy.BVV(0x40000, arch.bits), len(string)) + code_loc = CodeLocation(0, 0) + concrete_mv = MultiValues( + offset_to_values={0: {claripy.BVV(string, len(string) * 8)}} + ) + + reach_def.kill_and_add_definition(mem_loc, code_loc, concrete_mv) + + strings = Utils.get_strings_from_pointer(mem_loc.addr, reach_def, code_loc) + self.assertEqual(Utils.bytes_from_int(strings.one_value()), string) + + def test_get_strings_from_pointer_concrete_stack_address(self): + arch = ArchAMD64() + reach_def = LiveDefinitions(arch=arch) + string = b"Hello World!" + + sp = Register(arch.sp_offset, arch.bytes) + sp_offset = reach_def.stack_address(arch.sp_offset) + + reach_def.registers.store(sp.reg_offset, sp_offset, sp.size) + + mem_loc = MemoryLocation( + claripy.BVV(reach_def.get_stack_address(sp_offset), arch.bits), len(string) + ) + code_loc = CodeLocation(0, 0) + concrete_mv = MultiValues( + offset_to_values={0: {claripy.BVV(string, len(string) * 8)}} + ) + + reach_def.kill_and_add_definition(mem_loc, code_loc, concrete_mv) + + strings = Utils.get_strings_from_pointer(mem_loc.addr, reach_def, code_loc) + self.assertEqual(Utils.bytes_from_int(strings.one_value()), string) + + def test_get_strings_from_pointer_symbolic_stack_address(self): + arch = ArchAMD64() + reach_def = LiveDefinitions(arch=arch) + string = b"Hello World!" 
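+        # Unlike the concrete-address tests above, the pointer here is the symbolic
+        # sp-relative stack address itself, so resolution must go through the stack model.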
+
+        sp = Register(arch.sp_offset, arch.bytes)
+        sp_offset = reach_def.stack_address(arch.sp_offset)
+
+        reach_def.registers.store(sp.reg_offset, sp_offset, sp.size)
+
+        mem_loc = MemoryLocation(sp_offset, len(string))
+        code_loc = CodeLocation(0, 0)
+        concrete_mv = MultiValues(
+            offset_to_values={0: {claripy.BVV(string, len(string) * 8)}}
+        )
+
+        reach_def.kill_and_add_definition(mem_loc, code_loc, concrete_mv)
+
+        strings = Utils.get_strings_from_pointer(mem_loc.addr, reach_def, code_loc)
+        self.assertEqual(Utils.bytes_from_int(strings.one_value()), string)
+
+    def test_get_strings_from_pointer_unknown_address(self):
+        arch = ArchAMD64()
+        reach_def = LiveDefinitions(arch=arch)
+        string = b"Hello World!"
+
+        sp = Register(arch.sp_offset, arch.bytes)
+        sp_offset = reach_def.stack_address(arch.sp_offset)
+
+        reach_def.registers.store(sp.reg_offset, sp_offset, sp.size)
+
+        mem_loc = MemoryLocation(reach_def.top(arch.bits), len(string))
+        code_loc = CodeLocation(0, 0)
+        concrete_mv = MultiValues(
+            offset_to_values={0: {claripy.BVV(string, len(string) * 8)}}
+        )
+
+        reach_def.kill_and_add_definition(mem_loc, code_loc, concrete_mv)
+
+        strings = Utils.get_strings_from_pointer(mem_loc.addr, reach_def, code_loc)
+        self.assertTrue(reach_def.is_top(strings.one_value()))
+
+    def test_get_values_from_multivalues(self):
+        values = MultiValues(
+            offset_to_values={
+                0: {claripy.BVV(0x0, 8), claripy.BVV(0x1, 8)},
+                8: {claripy.BVV(0x2, 8)},
+            }
+        )
+        all_vals = Utils.get_values_from_multivalues(values)
+        self.assertEqual(
+            sorted([2, 258]), sorted([x._model_concrete.value for x in all_vals])
+        )
+
+    def test_get_size_from_multivalues(self):
+        string = "Hello World!"
+        values = MultiValues(
+            offset_to_values={
+                0: {claripy.BVV(string[: len(string) // 2], len(string) // 2 * 8)},
+                6: {claripy.BVV(string[len(string) // 2 :], len(string) // 2 * 8)},
+            }
+        )
+        size = Utils.get_size_from_multivalue(values)
+        self.assertEqual(size, len(string))
+
+    def test_strip_null_from_string(self):
+        string = "Hello World!\x00"
+        value = claripy.BVV(string, len(string) * 8)
+        self.assertTrue(value[7:]._model_concrete.value == 0)
diff --git a/pipeline/MANIFEST.in b/pipeline/MANIFEST.in
new file mode 100644
index 0000000..f0d4a9c
--- /dev/null
+++ b/pipeline/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include mango_pipeline/configs *
diff --git a/pipeline/README.md b/pipeline/README.md
new file mode 100644
index 0000000..257bce4
--- /dev/null
+++ b/pipeline/README.md
@@ -0,0 +1,70 @@
+# Mango Pipeline
+
+This package is specifically for running operation mango on large datasets with maximum parallelization.
+
+## Installation
+
+Run `pip install -e .` (Note: mango must already be installed, and this package can only be installed with `-e`.)
+
+The `mango-pipeline` utility is now available for your use.
+
+## Layout
+
+`mango-pipeline` expects either a flat directory filled with binaries or a structured directory.
+This tool was built specifically to tackle firmware; you can try your luck with other use cases.
+If you are attempting to run this tool on multiple firmware images, this is the directory structure you will want to use:
+```
+root_dir/
+    vendor_name/
+        firmware_name/
+            .....
+                squashfs-root/
+```
+
+## Usage
+
+This package allows you to run a parallelized version of mango either locally or remotely.
+
+You'll always want these two options filled out:
+
+`mango-pipeline --path /directory/to/analyze --results /output_dir`
+
+If you want to run a parallelized workload locally, try something like this:
+
+```bash
+mango-pipeline --path /directory/to/analyze --results /output_dir --build-docker --full --parallel NUM_CONTAINERS --categories cmdi
+```
+
+If you're attempting to run this on a remote Kubernetes cluster, first edit `mango_pipeline/configs/pipeline.toml` with your cluster information.
+> [!CAUTION]
+> This is probably super broken as it was tailored to my exact setup; try the docker container version instead.
+> Attempting this route is pain, and you have been warned.
+
+First run the environment analysis:
+
+```bash
+mango-pipeline --path /directory/to/analyze --results /output_dir --build-docker --kube --env --parallel NUM_CONTAINERS
+```
+
+Download the results:
+
+```bash
+mango-pipeline --path /directory/to/analyze --results /output_dir --kube --download-results
+```
+
+Run the analysis:
+
+```bash
+mango-pipeline --path /directory/to/analyze --results /output_dir --kube --mango --categories cmdi --parallel NUM_CONTAINERS
+```
+
+Download the final results:
+
+```bash
+mango-pipeline --path /directory/to/analyze --results /output_dir --kube --download-results
+```
+
+### Local
+`--build-docker` - You only need to run this once unless you're editing the codebase.
+This will build the docker container found in the `docker` folder in the root of this project.
+
+### Remote
+`--kube` - This option forces all work to be done on a remote Kubernetes setup.
+With `--build-docker`, it will attempt to push the container to the remote docker registry configured in the `mango_pipeline/configs` folder. (unstable)
+
+### Output
+`--status` - A static printout of how many results there are, how many have errored, and which vendor/firmware the results belong to.
+
+`--gen-csv` - Generate a CSV file of all the results.
+ +`--aggregate-results AGG_FOLDER` - Generate a folder of all the unique results with potential bugs diff --git a/pipeline/mango_pipeline.egg-info/PKG-INFO b/pipeline/mango_pipeline.egg-info/PKG-INFO new file mode 100644 index 0000000..fbb8cad --- /dev/null +++ b/pipeline/mango_pipeline.egg-info/PKG-INFO @@ -0,0 +1,21 @@ +Metadata-Version: 2.1 +Name: mango_pipeline +Version: 0.0.1 +Summary: A utility to facilitate parallelization across multiple target files for argument_resolver +Author-email: Wil Gibbs +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Python: >=3.10 +Requires-Dist: argument_resolver +Requires-Dist: rich==13.7.1 +Requires-Dist: docker==7.0.0 +Requires-Dist: toml==0.10.2 +Requires-Dist: kubernetes==29.0.0 +Requires-Dist: esprima==4.0.1 +Requires-Dist: phply==1.2.6 +Requires-Dist: bs4==0.0.2 +Requires-Dist: lxml==5.1.0 +Requires-Dist: binwalk@ git+https://github.com/ReFirmLabs/binwalk +Requires-Dist: pyyaml==6.0.1 diff --git a/pipeline/mango_pipeline.egg-info/SOURCES.txt b/pipeline/mango_pipeline.egg-info/SOURCES.txt new file mode 100644 index 0000000..6eeed9f --- /dev/null +++ b/pipeline/mango_pipeline.egg-info/SOURCES.txt @@ -0,0 +1,30 @@ +MANIFEST.in +README.md +pyproject.toml +mango_pipeline/__init__.py +mango_pipeline/base.py +mango_pipeline/kube.py +mango_pipeline/local.py +mango_pipeline/remote.py +mango_pipeline/run.py +mango_pipeline.egg-info/PKG-INFO +mango_pipeline.egg-info/SOURCES.txt +mango_pipeline.egg-info/dependency_links.txt +mango_pipeline.egg-info/entry_points.txt +mango_pipeline.egg-info/requires.txt +mango_pipeline.egg-info/top_level.txt +mango_pipeline/firmware/__init__.py +mango_pipeline/firmware/elf_finder.py +mango_pipeline/firmware/elf_info.py +mango_pipeline/firmware/keyword_finder.py +mango_pipeline/scripts/__init__.py +mango_pipeline/scripts/ablation.py +mango_pipeline/scripts/aggregate.py +mango_pipeline/scripts/data_printer.py +mango_pipeline/scripts/de-dup.py +mango_pipeline/scripts/extract.py +mango_pipeline/scripts/get_tp_from_sheet.py +mango_pipeline/scripts/path_context_aggregator.py +mango_pipeline/scripts/show_data.py +mango_pipeline/scripts/show_table.py +mango_pipeline/scripts/symbols_and_vendors.py \ No newline at end of file diff --git a/pipeline/mango_pipeline.egg-info/dependency_links.txt b/pipeline/mango_pipeline.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/pipeline/mango_pipeline.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/pipeline/mango_pipeline.egg-info/entry_points.txt b/pipeline/mango_pipeline.egg-info/entry_points.txt new file mode 100644 index 0000000..aac90cf --- /dev/null +++ b/pipeline/mango_pipeline.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +mango-pipeline = mango_pipeline.run:main diff --git a/pipeline/mango_pipeline.egg-info/requires.txt b/pipeline/mango_pipeline.egg-info/requires.txt new file mode 100644 index 0000000..03635e5 --- /dev/null +++ b/pipeline/mango_pipeline.egg-info/requires.txt @@ -0,0 +1,11 @@ +argument_resolver +rich==13.7.1 +docker==7.0.0 +toml==0.10.2 +kubernetes==29.0.0 +esprima==4.0.1 +phply==1.2.6 +bs4==0.0.2 +lxml==5.1.0 +binwalk@ git+https://github.com/ReFirmLabs/binwalk +pyyaml==6.0.1 diff --git a/pipeline/mango_pipeline.egg-info/top_level.txt b/pipeline/mango_pipeline.egg-info/top_level.txt new file mode 100644 index 0000000..c78bbfc --- 
/dev/null +++ b/pipeline/mango_pipeline.egg-info/top_level.txt @@ -0,0 +1 @@ +mango_pipeline diff --git a/pipeline/mango_pipeline/__init__.py b/pipeline/mango_pipeline/__init__.py new file mode 100644 index 0000000..b2fcbb8 --- /dev/null +++ b/pipeline/mango_pipeline/__init__.py @@ -0,0 +1,10 @@ +from pathlib import Path +PROJECT_DIR = Path(__file__).resolve(strict=True).parent.parent.parent.absolute() + +from mango_pipeline.local import PipelineLocal +from mango_pipeline.remote import PipelineRemote +from mango_pipeline.kube import PipelineKube + + +__version__ = "0.0.1" + diff --git a/pipeline/mango_pipeline/base.py b/pipeline/mango_pipeline/base.py new file mode 100644 index 0000000..c2f949a --- /dev/null +++ b/pipeline/mango_pipeline/base.py @@ -0,0 +1,842 @@ +import csv +import datetime +import hashlib +import os +import subprocess +import shutil +import collections.abc + +import elftools.common.exceptions +import toml +import json + +from typing import Set, Dict, List, Tuple +from pathlib import Path + +import docker +import rich + +from docker.errors import ContainerError + +from rich.progress import ( + Progress, + MofNCompleteColumn, + TimeElapsedColumn, + TimeRemainingColumn, + BarColumn, + TextColumn, + SpinnerColumn, +) + +from argument_resolver.external_function.sink import ENV_SINKS +from argument_resolver.analysis.base import ScriptBase + +from rich.console import Console, Group + +from elftools.elf.elffile import ELFFile + +from . import PROJECT_DIR +from .firmware import ELFInfo, FirmwareFinder +from .scripts import data_printer + + +class MyProgress(Progress): + def __init__(self, *args, renderable_callback=None, **kwargs): + self.renderable_callback = renderable_callback + super().__init__( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + *args, + **kwargs, + ) + + def get_renderables(self): + if self.renderable_callback: + yield Group(self.make_tasks_table(self.tasks), self.renderable_callback()) + else: + yield self.make_tasks_table(self.tasks) + + +class Pipeline: + docker_file: Path + + def __init__( + self, + target: Path, + results_dir: Path, + parallel: int = 1, + is_env=False, + is_mango=False, + category: str = None, + brand="", + firmware="", + extra_args=None, + job_name=None, + py_spy=False, + timeout=120 * 60, + rda_timeout=5 * 60, + bin_prep=False, + exclude_libs=True, + show_dups=False, + ): + + self.target = target + self.results_dir = results_dir + self.is_env = is_env + self.is_mango = is_mango + self.brand = brand + self.firmware = firmware + self.category = category + self.parallel = parallel + self.total_mango_results = {} + self.total_env_results = {} + self.extra_args = extra_args + self.job_name = job_name + self.py_spy = py_spy + self.timeout = timeout + self.rda_timeout = rda_timeout + self.bin_prep = bin_prep + self.container_name = "mango_user" + self.docker_file = Path(__file__).parent.parent.parent / "docker" / "Dockerfile" + self.vendor_dict = {} + self.exclude_libs = exclude_libs + self.show_dups = show_dups + + @staticmethod + def create_default_config(config_path: Path): + """ + Creates Empty/Default Config Files + :return: + """ + + pipeline_config_file = config_path + pipeline_config = dict() + pipeline_config["remote"] = { + "local_image": "mango_kube", + "auth_config": {"username": "", "password": ""}, + "registry": "", + "registry_image": "", + "registry_secret_name": "", + } + + pipeline_config_file.parent.mkdir(parents=True, 
exist_ok=True) + with pipeline_config_file.open("w") as f: + toml.dump(pipeline_config, f) + + def get_experiment_targets(self, remote=False) -> Tuple[Set[ELFInfo], Dict]: + if self.target.is_dir() or remote: + if self.vendor_dict: + + if self.brand and self.brand not in self.vendor_dict: + rich.get_console().log(f"[red]FAILED TO FIND BRAND {self.brand}") + return set(), dict() + + if ( + self.firmware + and self.firmware not in self.vendor_dict[self.brand]["firmware"] + ): + rich.get_console().log( + f"[red]FAILED TO FIND FIRMWARE {self.firmware} in {self.brand}" + ) + return set(), dict() + + exp_list = set() + duplicates = dict() + for brand, firmware_dict in self.vendor_dict.items(): + if self.brand and brand != self.brand: + continue + (self.results_dir / brand).mkdir(parents=True, exist_ok=True) + for firmware, bin_dict in firmware_dict["firmware"].items(): + if self.firmware and firmware != self.firmware: + continue + (self.results_dir / brand / firmware).mkdir( + parents=True, exist_ok=True + ) + for sha, elf_dict in bin_dict["elfs"].items(): + info = ELFInfo( + path=elf_dict["path"], + brand=brand, + firmware=firmware, + sha=sha, + ) + if sha not in duplicates: + duplicates[sha] = [] + duplicates[sha].append(info) + exp_list.add(info) + for sha, elfs in duplicates.copy().items(): + if len(elfs) < 2: + duplicates.pop(sha) + return exp_list, duplicates + + with Progress() as progress: + elfs = FirmwareFinder.find_elf_files( + self.target, progress, exclude_libs=self.exclude_libs + ) + return { + ELFInfo(sha=sha, path=path["path"], firmware="", brand="") + for sha, path in elfs.items() + }, dict() + + return { + ELFInfo( + sha=self.get_sha(self.target), + path=str(self.target), + firmware="", + brand="", + ) + }, dict() + + def get_symbols_and_targets(self) -> Tuple[Dict[str, set[str]], Set[ELFInfo]]: + targets, duplicates = self.get_experiment_targets() + self.link_duplicates(targets, duplicates) + + symbols = self.get_target_symbols(targets) + + return symbols, targets + + def run_experiment(self): + """ + Function to run all experiments + """ + + symbols, targets = self.get_symbols_and_targets() + + if self.is_env: + env_targets = self.filter_env_targets(targets, symbols) + self.run_env_resolve(env_targets) + + if self.is_mango: + mango_targets = self.filter_mango_targets(targets, symbols) + self.run_mango(mango_targets) + self.mango_results_to_csv() + + def link_duplicates( + self, targets: Set[ELFInfo], duplicates: Dict[str, List[ELFInfo]] + ): + for target in targets: + if target.sha in duplicates: + target_file = ( + self.results_dir / target.brand / target.firmware / target.sha + ) + for dup in duplicates[target.sha]: + dup_file = self.results_dir / dup.brand / dup.firmware / dup.sha + dup_file.mkdir(parents=True, exist_ok=True) + if ( + not (dup_file / "env.json").exists() + and (target_file / "env.json").exists() + ): + os.link( + (target_file / "env.json").absolute().resolve(), + (dup_file / "env.json").absolute().resolve(), + ) + if ( + not (dup_file / f"{self.category}_results.json").exists() + and (target_file / f"{self.category}_results.json").exists() + ): + os.link( + (target_file / f"{self.category}_results.json") + .absolute() + .resolve(), + (dup_file / f"{self.category}_results.json") + .absolute() + .resolve(), + ) + if ( + not (dup_file / f"{self.category}_mango.out").exists() + and (target_file / f"{self.category}_mango.out").exists() + ): + os.link( + (target_file / f"{self.category}_mango.out") + .absolute() + .resolve(), + (dup_file / 
f"{self.category}_mango.out") + .absolute() + .resolve(), + ) + + def get_target_symbols(self, targets: Set[ELFInfo]) -> Dict[str, Set[str]]: + """ + Analyzes target and creates a set of available functions + :param targets: + :return: + """ + symbols = {} + symbol_file = self.results_dir / self.brand / self.firmware / "symbols.json" + prev_data = {} + if symbol_file.exists(): + try: + prev_data = json.loads(symbol_file.read_text()) + if not self.bin_prep: + return prev_data + except json.decoder.JSONDecodeError: + pass + + progressbar = Progress() + progressbar.start() + symbol_task = progressbar.add_task("Getting Symbols ...", total=len(targets)) + for target in targets: + if target.brand not in symbols: + symbols[target.brand] = {} + if target.firmware not in symbols[target.brand]: + symbols[target.brand][target.firmware] = {} + + with open(target.path, "rb") as f: + try: + elf = ELFFile(f) + symbols_sections = [elf.get_section_by_name(".dynsym")] + except ( + elftools.common.exceptions.ELFParseError, + elftools.common.exceptions.ELFError, + ): + progressbar.update(symbol_task, advance=1) + continue + symbols_sections += [x for x in elf.iter_segments(type="PT_DYNAMIC")] + symbols_sections = [x for x in symbols_sections if x] + symbols[target.brand][target.firmware][target.sha] = [] + if symbols_sections: + for symbols_section in symbols_sections: + try: + symbols[target.brand][target.firmware][target.sha] = list( + set(symbols[target.brand][target.firmware][target.sha]) + | { + symbol.name + for symbol in symbols_section.iter_symbols() + } + ) + except ( + elftools.common.exceptions.ELFError, + elftools.common.exceptions.ELFParseError, + ValueError, + AttributeError, + AssertionError, + ): + pass + progressbar.update(symbol_task, advance=1) + + progressbar.stop() + final_symbols = {} + for brand, firmware_dict in symbols.items(): + for firmware, sha_dict in firmware_dict.items(): + symbol_dict = { + "brand": brand, + "firmware": firmware, + "symbols": {k: list(v) for k, v in sha_dict.items()}, + } + final_symbols.update(sha_dict) + with open( + self.results_dir / brand / firmware / "symbols.json", "w+" + ) as f: + json.dump(symbol_dict, f, indent=4) + + final_symbols.update(prev_data) + (self.results_dir / "symbols.json").write_text(json.dumps(final_symbols)) + + return final_symbols + + def run_mango(self, targets: Set[ELFInfo]): + """ + Function to run operation mango on targets + :return: + """ + raise NotImplementedError("Run Mango Function Must Be Implemented") + + def run_env_resolve(self, targets: Set[ELFInfo]): + """ + Function to run env_resolve on targets + :return: + """ + raise NotImplementedError("Run EnvResolve Function Must Be Implemented") + + def build_container(self): + """ + Builds docker container for target. 
+ """ + cli = docker.APIClient() + resp = cli.build( + path=str(PROJECT_DIR), + dockerfile=str(self.docker_file), + tag=self.container_name, + decode=True, + ) + console = Console() + output = "" + with console.screen(): + for line in resp: + if "stream" in line: + console.print(line["stream"], end="") + output += line["stream"] + elif "errorDetail" in line: + break + if "errorDetail" in line: + console.print(output) + console.print(f"[red bold]{line['error']}") + + @staticmethod + def get_sha(file: Path) -> str: + """ + Get SHA256 sum for given file + """ + with file.open("rb") as f: + return hashlib.file_digest(f, "sha256").hexdigest() + + def mango_results_to_csv(self, full=False): + csv_file = open(self.results_dir / "results.csv", "w", newline="") + results_writer = csv.writer( + csv_file, delimiter="\t", quotechar="|", quoting=csv.QUOTE_MINIMAL + ) + + titles = [ + "Brand", + "Firmware", + "SHA256", + "Name", + "Sink", + "Addr", + "TP", + "CFG Time", + "VRA Time", + "Analysis Time", + "Checked By", + "Notes", + ] + rows = [titles] + results = {} + shas = set() + brands = sorted( + [x for x in self.results_dir.iterdir() if x.is_dir()], key=lambda x: x.name + ) + for brand in brands: + results[brand.name] = {} + firmwares = sorted( + [x for x in brand.iterdir() if x.is_dir()], key=lambda x: x.name + ) + for firmware in firmwares: + results[brand.name][firmware.name] = { + "env_time": 0, + "cfg_time": 0, + "vra_time": 0, + "analysis_time": 0, + } + + elfs = sorted( + [x for x in firmware.iterdir() if x.is_dir()], key=lambda x: x.name + ) + for elf in elfs: + results_file = elf / f"{self.category}_results.json" + env_file = elf / "env.json" + + if not results_file.exists(): + continue + + if env_file.exists(): + try: + env_data = json.loads(results_file.read_text()) + if env_data["error"] is None: + results[brand.name][firmware.name]["env_time"] += ( + env_data["cfg_time"] + + env_data["vra_time"] + + env_data["mango_time"] + ) + except: + pass + + data = json.loads(results_file.read_text()) + if data["error"] is not None or not data["has_sinks"]: + continue + + results[brand.name][firmware.name]["cfg_time"] += data["cfg_time"] + results[brand.name][firmware.name]["vra_time"] += data["vra_time"] + results[brand.name][firmware.name]["analysis_time"] += data[ + "mango_time" + ] + + if not full and data["sha256"] in shas: + continue + else: + shas.add(data["sha256"]) + + if not full and len(data["closures"]) == 0: + continue + + rows.append( + [ + brand.name, + firmware.name, + elf.name, + data["name"], + "", + "", + "", + f"{data['cfg_time']:.2f}", + f"{data['vra_time']:.2f}", + f"{data['mango_time']:.2f}", + "", + "", + ] + ) + for closure in sorted( + data["closures"], key=lambda x: x["sink"]["function"] + ): + sources = {x.split("(")[0].lower() for x in closure["inputs"]} + sources = {x if "nvram" not in x else "nvram" for x in sources} + sources = {x if "recv" not in x else "recv" for x in sources} + rows.append( + [ + "", + "", + "", + "", + closure["sink"]["function"], + closure["sink"]["ins_addr"], + "", + "", + "", + "", + "", + "", + "", + ",".join(sorted(sources)), + str(closure["reachable_from_main"]), + ] + ) + + rows.append([] * 4) + rows.append([]) + rows.insert( + 0, + [ + "Completed", + f'=CountA(G5:G{len(rows)}) & " of " & CountA(F5:F{len(rows)})', + ], + ) + rows.insert( + 0, + [ + "True Positives", + f'=CountIF(H5:H{len(rows)-1}, "Y")/(CountIF(H5:H{len(rows)-1}, "Y") + CountIF(H5:H{len(rows)-1}, "N"))', + ], + ) + rows.insert(0, []) + rows.append([] * 4) + rows.append( + 
[ + "Firmware", + "ENV Time", + "CFG Time", + "VRA Time", + "Analysis Time", + "Total (Minutes)", + ] + ) + for firmware_dict in results.values(): + for firmware, vals in firmware_dict.items(): + rows.append( + [ + firmware, + vals["env_time"], + vals["cfg_time"], + vals["vra_time"], + vals["analysis_time"], + f"=SUM(B{len(rows)+1}:E{len(rows)+1})/60", + ] + ) + results_writer.writerows(rows) + csv_file.close() + + def filter_mango_targets( + self, targets: Set[ELFInfo], symbols: Dict[str, Set[str]] + ) -> Set[ELFInfo]: + if "symbols" in symbols: + symbols = symbols["symbols"] + mango_targets = set() + mango_sinks = ScriptBase.load_sinks(category=self.category) + + known_shas = set() + for target in targets: + is_dup = target.sha in known_shas + res_file = ( + self.result_dir_from_target(target) / f"{self.category}_results.json" + ) + known_shas.add(target.sha) + if res_file.exists(): + data_printer.parse_mango_result( + self.total_mango_results, res_file, target, dup=is_dup + ) + mango_targets.discard(target) + else: + if target.sha not in symbols: + continue + if not any(sink.name in symbols[target.sha] for sink in mango_sinks): + res_file.parent.mkdir(exist_ok=True, parents=True) + self.save_error_result( + res_file, target, "mango", None, has_sinks=False + ) + else: + mango_targets.add(target) + + return mango_targets + + def result_dir_from_target(self, target: ELFInfo) -> Path: + return self.results_dir / target.brand / target.firmware / target.sha + + def filter_env_targets( + self, targets: Set[ELFInfo], symbols: Dict[str, Set[str]] + ) -> Set[ELFInfo]: + env_targets = set() + known_shas = set() + for target in targets: + is_dup = target.sha in known_shas + known_shas.add(target.sha) + res_file = self.result_dir_from_target(target) / "env.json" + if res_file.exists(): + data_printer.parse_env_result( + self.total_env_results, res_file, target, dup=is_dup + ) + else: + if not any( + target.sha not in symbols or sink.name in symbols[target.sha] + for sink in ENV_SINKS + ): + res_file.parent.mkdir(exist_ok=True, parents=True) + self.save_error_result( + res_file, target, "env_resolve", None, has_sinks=False + ) + else: + env_targets.add(target) + return env_targets + + @staticmethod + def save_error_result( + result_path: Path, target: ELFInfo, script: str, error: str, has_sinks=True + ): + if script == "mango": + out_dict = { + "sha256": target.sha, + "name": Path(target.path).name, + "path": target.path, + "closures": [], + "cfg_time": 0, + "vra_time": 0, + "mango_time": 0, + "error": error, + "ret_code": 0 if not error else 1, + "has_sinks": has_sinks, + } + + elif script == "env_resolve": + out_dict = { + "sha256": target.sha, + "name": Path(target.path).name, + "path": target.path, + "results": {}, + "cfg_time": 0, + "vra_time": 0, + "mango_time": 0, + "ret_code": 0 if not error else 1, + "has_sinks": has_sinks, + "error": error, + } + + else: + out_dict = {"error": error} + + with result_path.open("w+") as f: + json.dump(out_dict, f, indent=4) + + @staticmethod + def console_subprocess(command: list, cwd=None): + if cwd is None: + cwd = os.getcwd() + with Progress(SpinnerColumn(), TextColumn("{task.description}")) as progressbar: + command_task = progressbar.add_task( + description="[bold]Running ...", total=None + ) + progressbar.update( + command_task, + description="[bold]Running " + (" ".join(str(x) for x in command)), + ) + with Console() as c: + with subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd + ) as p: + output = [] + while 
p.poll() is None: + line = p.stdout.readline().decode().strip() + if line: + c.print(line) + output.append(line) + + ret_code = p.returncode + progressbar.update(command_task, completed=1, total=1) + return ret_code, output + + @staticmethod + def prep_results(directory: Path, result_path: Path, category: str): + results_files = ( + subprocess.check_output( + [ + "find", + str(directory.resolve()), + "-type", + "f", + "-name", + f"{category}_results.json", + ] + ) + .decode() + .strip() + .split("\n") + ) + results_files = [ + Path(x) + for x in results_files + if Path(x).is_file() and not Path(x).is_symlink() + ] + + for results_file in results_files: + data = json.loads(results_file.read_text()) + if data["has_sinks"] is False or data["error"]: + continue + + out_dir = result_path / ( + data["sha256"] if "sha256" in data else data["sha"] + ) + if out_dir.exists(): + pass + else: + out_dir.mkdir(parents=True) + shutil.copy(results_file, out_dir / results_file.name) + shutil.copy( + str(data["path"]).replace( + "/shared/clasm", "/home/clasm/projects/angr-squad" + ), + out_dir / data["name"], + ) + if (results_file.parent / f"{category}_mango.out").exists(): + shutil.copy( + results_file.parent / f"{category}_mango.out", + out_dir / f"{category}_mango.out", + ) + + def env_merge(self): + for brand in [x for x in self.results_dir.iterdir() if x.is_dir()]: + if self.brand and self.brand != brand.name: + continue + + with MyProgress(transient=True) as progress: + firmwares = [ + x + for x in brand.iterdir() + if x.is_dir() + and ( + not self.firmware or (self.firmware and self.firmware == x.name) + ) + ] + firm_task = progress.add_task( + description="Merging Env Results", total=len(firmwares) + ) + for firmware in firmwares: + merge_task = progress.add_task( + description=f"Merging {firmware.name}", total=None + ) + self.env_merge_firmware(firmware) + progress.update(merge_task, visible=False) + progress.update(firm_task, advance=1) + + def env_merge_firmware(self, result_dir: Path): + + docker_client = docker.from_env() + + docker_res = Path("/tmp") / result_dir.name + + volumes = dict() + volumes[str(result_dir.absolute())] = {"bind": str(docker_res), "mode": "rw"} + + escaped_docker_res = str(docker_res) + escaped_out_file = str(docker_res / "env.json") + command = [ + "/angr/bin/env_resolve", + escaped_docker_res, + "--merge", + "--results", + escaped_out_file, + ] + try: + docker_client.containers.run( + self.container_name, + name=f"env_merge_{result_dir.name.replace('%', '_').replace('(', '_').replace(')', '_')}", + command=command, + volumes=volumes, + stdout=True, + stderr=True, + auto_remove=True, + ) + except ContainerError as e: + print(e) + + @staticmethod + def dict_update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = Pipeline.dict_update(d.get(k, {}), v) + else: + d[k] = v + return d + + def merge_symbols(self): + symbols = ( + subprocess.check_output( + ["find", self.results_dir, "-type", "f", "-name", "symbols.json"] + ) + .decode() + .strip() + .split("\n") + ) + + symbol_out = {} + for symbol_file in symbols: + if symbol_file == str(self.results_dir / "symbols.json"): + continue + with open(symbol_file, "r") as f: + symbol_data = json.load(f) + if "symbols" not in symbol_data: + continue + symbol_out.update(symbol_data["symbols"]) + + with open(self.results_dir / "symbols.json", "w+") as f: + json.dump(symbol_out, f, indent=4) + + return symbol_out + + def merge_vendors(self): + vendors = ( + subprocess.check_output( + ["find", 
self.results_dir, "-type", "f", "-name", "vendor.json"] + ) + .decode() + .strip() + .split("\n") + ) + vendor_out = {} + for vendor_file in vendors: + if ( + not vendor_file + or vendor_file == str(self.results_dir / "vendor.json") + or not Path(vendor_file).exists() + ): + continue + with open(vendor_file, "r") as f: + vendor_data = json.load(f) + p = Path(vendor_file) + vendor_data = { + p.parent.parent.name: {"firmware": {p.parent.name: vendor_data}} + } + vendor_out = self.dict_update(vendor_out, vendor_data) + + with open(self.results_dir / "vendor.json", "w+") as f: + json.dump(vendor_out, f, indent=4) + + return vendor_out diff --git a/pipeline/mango_pipeline/firmware/__init__.py b/pipeline/mango_pipeline/firmware/__init__.py new file mode 100644 index 0000000..bea4ca3 --- /dev/null +++ b/pipeline/mango_pipeline/firmware/__init__.py @@ -0,0 +1,2 @@ +from .elf_info import ELFInfo +from .elf_finder import FirmwareFinder diff --git a/pipeline/mango_pipeline/firmware/elf_finder.py b/pipeline/mango_pipeline/firmware/elf_finder.py new file mode 100644 index 0000000..24a051c --- /dev/null +++ b/pipeline/mango_pipeline/firmware/elf_finder.py @@ -0,0 +1,176 @@ +import os +import sys +import subprocess +import json +import hashlib + +from pathlib import Path + +import binwalk + +from rich.progress import Progress + +from .keyword_finder import find_keywords +from .elf_info import FS_LIST + + +class FirmwareFinder: + """ + The Finder assumes that each Vendor is a top-level directory in the given target directory. + """ + + def __init__( + self, target_dir: Path, results_dir: Path, bin_prep=False, exclude_libs=True + ): + self.target_dir = target_dir.absolute().resolve() + self.results_dir = results_dir + self.exclude_libs = exclude_libs + vendor_file = results_dir / "vendors.json" + if vendor_file.exists() and not bin_prep: + try: + self.vendor_dict = json.loads(vendor_file.read_text()) + except json.decoder.JSONDecodeError: + self.vendor_dict = {} + else: + self.vendor_dict = {} + if bin_prep or not self.vendor_dict: + new_vendor_dict = self.search() + self.vendor_dict.update(new_vendor_dict) + with open(vendor_file, "w+") as f: + json.dump(self.vendor_dict, f, indent=4) + + def extract_firmware(self, vendor): + for root, _, files in os.walk(vendor): + for file in files: + f = Path(root) / file + if f.is_file(): + modules = binwalk.scan(str(f), signature=True, quiet=True) + if any("filesystem" in x.description for result in modules for x in result.results): + binwalk.scan(str(f), signature=True, extract=True, quiet=True) + + def search(self): + vendors = [x for x in self.target_dir.iterdir() if x.is_dir()] + vendor_dict = dict() + with Progress() as progress: + progress.stop() + vendor_task_str = "[red]Scanning Vendor" + vendor_task = progress.add_task( + f"{vendor_task_str} ...", total=len(vendors) + ) + for idx, vendor in enumerate(vendors): + progress.update( + vendor_task, + description=f"{vendor_task_str} {vendor.name} [{idx}/{len(vendors)}]", + ) + found_fs = self.find_extracted_fs(vendor) + if not found_fs: + self.extract_firmware(vendor) + found_fs = self.find_extracted_fs(vendor) + vendor_dict[vendor.name] = {"path": str(vendor), "firmware": dict()} + fs_task_str = "[green]Iterating FS" + fs_task = progress.add_task(f"{fs_task_str} ...", total=len(found_fs)) + for fs_idx, fs in enumerate(found_fs): + keywords = find_keywords(fs, progress=progress) + firm_name = self.firm_name_from_path(fs) + if firm_name is None: + continue + + progress.update( + fs_task, + 
description=f"{fs_task_str} {firm_name} [{fs_idx}/{len(found_fs)}]", + ) + elf_dict = self.find_elf_files( + fs, progress, exclude_libs=self.exclude_libs + ) + + if firm_name in vendor_dict[vendor.name]["firmware"]: + vendor_dict[vendor.name]["firmware"][firm_name]["elfs"].update( + elf_dict + ) + else: + vendor_dict[vendor.name]["firmware"][firm_name] = { + "path": str(fs.parent), + "elfs": elf_dict, + } + firmware = self.results_dir / vendor.name / firm_name + firmware.mkdir(exist_ok=True, parents=True) + progress.print("WRITING FILE", str(firmware / "vendor.json")) + with (firmware / "vendor.json").open("w+") as f: + json.dump( + vendor_dict[vendor.name]["firmware"][firm_name], f, indent=4 + ) + with (firmware / "keywords.json").open("w+") as f: + json.dump(keywords, f, indent=4) + progress.update(fs_task, advance=1) + + progress.update(fs_task, visible=False) + progress.update(vendor_task, advance=1) + return vendor_dict + + @staticmethod + def firm_name_from_path(path: Path): + if path.parent.name == "fw" or path.parent.name == "firmware": + path = path.parent + firm_name = ( + path.parent.name.replace(".bin", "") + .replace(".extracted", "") + .replace(".chk", "") + .strip("_") + ) + black_list = ["functions", "kernel", "qemu", "net", "squashfs-root"] + if "qemu" in firm_name.lower() or firm_name.lower() in black_list: + return None + return firm_name + + @staticmethod + def find_extracted_fs(root_dir: Path): + found_fs = [] + for fs in FS_LIST: + command = ["find", str(root_dir), "-type", "d", "-name", f"{fs}*"] + output = subprocess.check_output(command) + current_fs = [Path(x) for x in output.decode().split("\n") if x] + # if any(fs.name in FS_LIST for fs in current_fs): + # current_fs = [fs for fs in current_fs if fs.name in FS_LIST] + + found_fs.extend(current_fs) + return found_fs + + @staticmethod + def find_elf_files(root_dir: Path, progress, exclude_libs=True): + + BANNED_LIST = ["busybox"] + + elf_task_str = "[cyan]Finding ELFs" + elf_task = progress.add_task(f"{elf_task_str} ...", total=None) + output = subprocess.check_output( + ["find", str(root_dir), "-type", "f", "-exec", "file", "{}", ";"] + ) + elfs = [ + x + for x in output.decode().split("\n") + if "ELF" in x and (not exclude_libs or "shared object" not in x) + ] + progress.update( + elf_task, description=f"{elf_task_str} [0/{len(elfs)}]", total=len(elfs) + ) + progress.start_task(elf_task) + elf_dict = {} + for idx, elf in enumerate(elfs): + path = Path(elf.split(":")[0].strip()) + if path.is_symlink(): + continue + if path.name in BANNED_LIST: + continue + with path.open("rb") as f: + sha256 = hashlib.file_digest(f, "sha256").hexdigest() + elf_dict[sha256] = {"path": str(path)} + progress.update( + elf_task, description=f"{elf_task_str} [{idx+1}/{len(elfs)}]", advance=1 + ) + progress.update(elf_task, visible=False) + + return elf_dict + + +if __name__ == "__main__": + FirmwareFinder(Path(sys.argv[1])) diff --git a/pipeline/mango_pipeline/firmware/elf_info.py b/pipeline/mango_pipeline/firmware/elf_info.py new file mode 100644 index 0000000..ca230b9 --- /dev/null +++ b/pipeline/mango_pipeline/firmware/elf_info.py @@ -0,0 +1,79 @@ +from typing import List +from pathlib import Path + +FS_LIST = ["squashfs-root", "ubifs-root", "cpio-root", "fs"] + + +class ELFInfo: + possible_lib_locations = [ + "/dumaos/ngcompat", + "/etc", + "/iQoS/R8900/TM", + "/iQoS/R8900/tm_key", + "/iQoS/R9000/TM", + "/iQoS/R9000/tm_key", + "/lib", + "/lib/lua", + "/lib/pptpd", + "/tmp/root/lib", + "/tmp/root/usr/lib", + "/usr/lib", + 
"/usr/lib/ebtables", + "/usr/lib/forked-daapd", + "/usr/lib/iptables", + "/usr/lib/lua", + "/usr/lib/lua/socket", + "/usr/lib/pppd/2.4.3", + "/usr/lib/tc", + "/usr/lib/uams", + "/usr/lib/xtables", + "/usr/local/lib/openvpn/plugins", + "/usr/share", + ] + + def __init__( + self, path: str, brand: str, firmware: str, sha: str, ld_paths: list = None + ): + self.path = path + self.brand = brand + self.firmware = firmware + self.sha = sha + if ld_paths is None: + self.ld_paths = self.get_lib_locations() + else: + self.ld_paths = ld_paths + + def get_lib_locations(self) -> List[str]: + firmware_fs = None + for fs in FS_LIST: + if fs in self.path: + firmware_fs = fs + break + firmware_root = ( + Path(self.path[: self.path.index(firmware_fs)] + firmware_fs) + .absolute() + .resolve() + ) + return [ + str(firmware_root / x) + for x in ELFInfo.possible_lib_locations + if (firmware_root / x).exists + ] + + def __hash__(self): + return hash(self.path) + + def to_dict(self): + return { + "path": self.path, + "brand": self.brand, + "firmware": self.firmware, + "sha": self.sha, + "ld_paths": self.ld_paths, + } + + def __repr__(self): + return str(self) + + def __str__(self): + return f"" \ No newline at end of file diff --git a/pipeline/mango_pipeline/firmware/keyword_finder.py b/pipeline/mango_pipeline/firmware/keyword_finder.py new file mode 100644 index 0000000..7048077 --- /dev/null +++ b/pipeline/mango_pipeline/firmware/keyword_finder.py @@ -0,0 +1,311 @@ +import sys +import re +import subprocess +import json +import string +import ipdb +import inspect + + +from typing import Set, Union +from pathlib import Path + +import yaml +import esprima + +from rich.progress import ( + Progress, + TimeElapsedColumn, + TextColumn, +) +from bs4 import BeautifulSoup + +MIN_STRING_LENGTH = 3 + +text_endings = { + ".cfg", + ".conf", + ".config", + ".ini", + ".init", + ".txt", +} + +bash_endings = { + ".sh", +} + +php_endings = { + ".php", + ".cgi", +} + +js_endings = { + ".js", +} + +html_endings = { + ".htm", + ".html", + ".asp", +} + +object_endings = {".xml", ".json", ".yaml", ".yml"} + +all_endings = ( + text_endings + | php_endings + | js_endings + | html_endings + | object_endings + | bash_endings +) + +current_progress = None + +AVOID_KEYWORDS = {"true", "false", "static", "radio", "none", "disabled", "fixed"} + + +def bp_hook(*args, **kwargs): + if current_progress is not None: + current_progress.stop() + + frame = inspect.currentframe().f_back + ipdb.set_trace(frame) + + +sys.breakpointhook = bp_hook + + +def strip_non_alpha(s) -> str: + s = s.split("=")[0] + s = s.split(":")[0] + s = s.split(";")[0] + stripped = re.sub(r"^[^a-zA-Z]+|[^a-zA-Z0-9]+$", "", s) + if any(x in stripped for x in string.whitespace): + return "" + left = stripped.find("<") + if left != -1: + stripped = stripped[left + 1 :] + + right = stripped.rfind(">") + if right != -1: + stripped = stripped[:right] + backslash = stripped.find("\\") + if backslash != -1: + stripped = stripped[:backslash] + question = stripped.find("?") + if question != -1 and question != len(stripped) - 1: + stripped = stripped[question + 1 :] + return stripped + + +def is_php(filename: Path) -> bool: + with filename.open("r") as f: + for line in f.readlines(): + line = line.strip() + if not line: + continue + + if line.startswith(" Set[str]: + try: + if isinstance(file, str): + tokens = esprima.tokenize(file) + script = esprima.parse(file) + else: + tokens = esprima.tokenize(file.read_text()) + script = esprima.parse(file.read_text()) + except 
(esprima.error_handler.Error, UnicodeDecodeError): + return set() + + pattern = re.compile(r'property:\s*{[^}]*name:\s*"([^"]*)"[^}]*}', re.DOTALL) + + # Find matches + valid_strings = set() + for obj_string in pattern.findall(str(script)): + stripped = strip_non_alpha(obj_string) + if len(stripped) > MIN_STRING_LENGTH: + valid_strings.add(stripped) + + for token in tokens: + if token.type == "String": + stripped = strip_non_alpha(token.value) + if len(stripped) > MIN_STRING_LENGTH: + valid_strings.add(stripped) + + return valid_strings + + +def get_strings_from_html(file_path: Path, parser=None) -> Set[str]: + try: + soup = BeautifulSoup(file_path.read_text(), parser or "html.parser") + except UnicodeDecodeError: + return set() + valid_strings = set() + + # Find all input tags and extract name and value attributes + for input_tag in soup.find_all("input"): + name = input_tag.get("name") + if name: + valid_strings.add(name) + id_ = input_tag.get("id") + if id_: + valid_strings.add(id_) + + # Find all select tags and extract name attribute + for select_tag in soup.find_all("select"): + name = select_tag.get("name") + if name: + valid_strings.add(name) + id_ = select_tag.get("id") + if id_: + valid_strings.add(id_) + + for script in soup.find_all("script"): + if script.string: + valid_strings |= get_strings_from_javascript(script.string) + + return valid_strings + + +def get_strings_from_php(file_path: Path) -> Set[str]: + valid_strings = get_strings_from_html(file_path, parser="lxml") + matches = re.findall( + """\$_(?:GET|POST|SERVER)\[(?:"|')(.*)(?:"|')\]""", file_path.read_text() + ) + valid_strings |= set(matches) + + return valid_strings + + +def get_strings_from_json(file_path: Path, data=None, strings=None) -> Set[str]: + # Parses JSON and returns all keys that are strings + strings = strings or set() + try: + data = data or json.loads(file_path.read_text()) + except json.decoder.JSONDecodeError: + return strings + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(key, str): + if len(key) > MIN_STRING_LENGTH: + strings.add(key) + get_strings_from_json(file_path, value, strings) + elif isinstance(data, list): + for item in data: + get_strings_from_json(file_path, item, strings) + return strings + + +def get_strings_from_xml(file_path: Path) -> Set[str]: + soup = BeautifulSoup(file_path.read_text(), "lxml-xml") + tag_names = {tag.name for tag in soup.find_all(True)} + return tag_names + + +def get_strings_from_object(file_path: Path) -> Set[str]: + # if file_path.suffix == ".json": + # strings = get_strings_from_json(file_path) + # return strings + # elif file_path.suffix == ".yml" or file_path.suffix == ".yaml": + # data = yaml.safe_load(file_path.read_text()) + # strings = get_strings_from_json(file_path, data) + # return strings + if file_path.suffix == ".xml": + strings = get_strings_from_xml(file_path) + return strings + + return set() + + +def find_potential_files(directory: Path): + endings = [] + for idx, tup in enumerate([["-iname", "*" + ending] for ending in all_endings]): + if idx > 0: + endings.append("-o") + endings.extend(tup) + files = ( + subprocess.check_output(["find", directory, "-type", "f", *endings]) + .decode() + .split("\n") + ) + final_files = [] + for file in files: + fp = Path(file) + if not fp.exists(): + continue + + file_out = subprocess.check_output(["file", fp]) + is_ascii = b"ASCII" in file_out or b"Unicode text" in file_out + if is_ascii: + final_files.append(fp) + + return final_files + + +def find_keywords(firmware_dir: 
Path, progress=None):
+    global current_progress
+    had_progress = progress is not None
+    progress = progress or Progress(transient=True)
+    current_progress = progress
+    find_task = progress.add_task("[green]Finding potential keyword files", total=1)
+    files = find_potential_files(firmware_dir)
+    progress.update(find_task, visible=False)
+
+    string_dict = {}
+    task = progress.add_task("[green]Scanning files for keywords", total=len(files))
+    for file in files:
+        strings = set()
+        progress.advance(task)
+
+        if file.suffix in js_endings:
+            if not file.exists():
+                continue
+
+            strings |= get_strings_from_javascript(file)
+
+        elif file.suffix in php_endings:
+            if is_php(file):
+                strings |= get_strings_from_php(file)
+
+        elif file.suffix in object_endings:
+            strings |= get_strings_from_object(file)
+
+        elif file.suffix in html_endings:
+            strings |= get_strings_from_html(file)
+
+        elif file.suffix in bash_endings:
+            # Shell scripts are matched but not scanned.
+            pass
+
+        else:
+            # Plain-text configs fall through unscanned as well.
+            pass
+        for s in strings:
+            if s.lower() in AVOID_KEYWORDS:
+                continue
+            if s not in string_dict:
+                string_dict[s] = []
+            string_dict[s].append(file.name)
+
+    progress.update(task, visible=False)
+    if not had_progress:
+        progress.stop()
+
+    return string_dict
+
+
+if __name__ == "__main__":
+    import pprint
+
+    print("Searching", sys.argv[1])
+    keywords = find_keywords(Path(sys.argv[1]))
+    pprint.pprint(keywords, indent=4)
+    with open("keywords.json", "w+") as f:
+        json.dump(keywords, f, indent=4)
diff --git a/pipeline/mango_pipeline/kube.py b/pipeline/mango_pipeline/kube.py
new file mode 100644
index 0000000..3deca87
--- /dev/null
+++ b/pipeline/mango_pipeline/kube.py
@@ -0,0 +1,623 @@
+import base64
+import json
+import os
+import time
+import toml
+
+from pathlib import Path
+from zipfile import ZipFile
+
+from typing import Tuple, Dict, Set
+
+import docker
+
+
+from rich.console import Console
+from rich.table import Table
+from rich.progress import (
+    Progress,
+    BarColumn,
+    TextColumn,
+    TaskProgressColumn,
+    TransferSpeedColumn,
+    TimeRemainingColumn,
+)
+
+from kubernetes import config, watch
+from kubernetes import client as kube_client
+
+from kubernetes.client import (
+    V1ResourceRequirements,
+    V1PersistentVolumeClaimVolumeSource,
+    V1VolumeMount,
+    V1Volume,
+    V1Container,
+    V1EnvVar,
+    V1PodTemplateSpec,
+    V1PodSpec,
+    V1ObjectMeta,
+    V1Job,
+    V1JobSpec,
+    V1LocalObjectReference,
+    V1DeleteOptions,
+)
+
+from .base import Pipeline, MyProgress, ELFInfo
+
+
+class PipelineKube(Pipeline):
+    """
+    Pipeline process for running experiments on a remote kubernetes setup
+    (Tailored specifically for @Clasm's setup: buyer beware)
+    """
+
+    REMOTE_DIR = Path("/tank/kubernetes/clasm")
+    REMOTE_RESULT_DIR = REMOTE_DIR / "mango-results"
+    KUBE_MOUNT_DIR = Path("/shared")
+    KUBE_RESULT_DIR = KUBE_MOUNT_DIR / "clasm" / "mango-results"
+    SSH_SERVER = "nfs_server"
+    ZIP_DEST = "/tmp/firmware"
+
+    def __init__(self, *args, quiet=True, parallel=500, watch=None, **kwargs):
+        super().__init__(*args, parallel=parallel, **kwargs)
+
+        config_path = Path(__file__).parent.absolute() / "configs" / "pipeline.toml"
+        if not config_path.exists():
+            self.create_default_config(config_path)
+
+        self.config = toml.loads(config_path.read_text())
+
+        self.results_dir.mkdir(parents=True, exist_ok=True)
+
+        self.REMOTE_RESULT_DIR = self.REMOTE_DIR / self.results_dir.name
+        self.KUBE_RESULT_DIR = self.KUBE_MOUNT_DIR / "clasm" / self.results_dir.name
+        config.load_kube_config()
+        self.api_client = kube_client.ApiClient()
+        self.k8_client =
kube_client.CoreV1Api(self.api_client) + self._get_time = time.time + self.job_status = { + "name": "Unknown", + "status": "Unknown", + "active": None, + "failed": None, + "succeeded": None, + } + + if not self.check_if_path_exists(self.target, is_dir=True): + Console().print(f"[bold red]Error: {self.target} does not exist on remote") + + if watch is not None: + self.watch_job(job_name=watch, namespace="clasm") + + def check_if_path_exists(self, path, is_dir): + ret_code, _ = self.remote_subprocess(["test", "-d" if is_dir else "-f", path]) + if ret_code == 0: + return True + else: + return False + + def build_container(self): + super().build_container() + self.push_container() + + def push_container(self): + """ + Pushes built container to repository + """ + + client = docker.from_env() + image = client.images.get(self.container_name) + success = image.tag(self.config["remote"]["registry_image"]) + if not success: + with Console() as console: + console.print("[red bold]Failed to tag container") + return + + push_info = {} + resp = client.images.push( + self.config["remote"]["registry_image"], + auth_config=self.config["remote"]["auth_config"], + stream=True, + decode=True, + ) + + with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + transient=True, + ) as progress: + for line in resp: + if "status" in line and line["status"] == "Pushing": + item_id = line["id"] + + info = line["progressDetail"] + current = info["current"] + total = info["total"] if "total" in info else None + if item_id in push_info: + progress.update( + push_info[item_id], completed=current, total=total + ) + else: + push_info[item_id] = progress.add_task( + description=f"Pushing {item_id}", + completed=current, + total=total, + ) + else: + if "errorDetail" in line: + progress.print(line) + Console().print( + f"[green]Pushed Container to {self.config['remote']['registry_image']}!" 
+ ) + + def get_symbols_and_targets(self) -> Tuple[Dict[str, set[str]], Set[ELFInfo]]: + if ( + not self.check_if_path_exists( + self.REMOTE_RESULT_DIR / "vendors.json", is_dir=False + ) + or not self.check_if_path_exists( + self.REMOTE_RESULT_DIR / "symbols.json", is_dir=False + ) + or self.bin_prep + ): + self.run_remote_bin_prep() + + self.vendor_dict = self.merge_vendors() + symbols = self.merge_symbols() + + targets, duplicates = self.get_experiment_targets(remote=True) + self.link_duplicates(targets, duplicates) + + return symbols, targets + + def run_remote_bin_prep(self): + _, remote_targets = self.remote_subprocess(["ls", self.target]) + _, found_symbols = self.console_subprocess( + ["find", self.results_dir, "-type", "f", "-iname", "symbols.json"] + ) + _, found_vendors = self.console_subprocess( + ["find", self.results_dir, "-type", "f", "-iname", "vendor.json"] + ) + found_symbols = [Path(x.strip()).parent.name + ".tar.gz" for x in found_symbols] + found_vendors = [Path(x.strip()).parent.name + ".tar.gz" for x in found_vendors] + invalid = set(found_symbols) - set(found_vendors) + invalid |= set(found_vendors) - set(found_symbols) + remote_targets = [x.strip() for x in remote_targets if x] + if not self.bin_prep: + remote_targets = [x for x in remote_targets if x in invalid] + if len(remote_targets) > 0: + old_timeout = self.timeout + self.timeout = 999999 + remote_targets_translated = [ + str( + str(self.target / x).replace( + "/tank/kubernetes", str(self.KUBE_MOUNT_DIR) + ) + ) + for x in remote_targets + ] + self.create_experiment_list("bin_prep", remote_targets_translated) + self.timeout = old_timeout + job_name = self.job_name or "bin-prep-job" + self.create_job( + completions=len(remote_targets), + job_name=job_name, + env_dict={"ZIP_DEST": self.ZIP_DEST}, + ) + self.watch_job(job_name=job_name, namespace="clasm") + self.download_new_results() + + def remote_subprocess(self, command: list): + command = ["ssh", self.SSH_SERVER] + command + return self.console_subprocess(command) + + def run_mango(self, targets: Set[ELFInfo]): + # self.env_merge() + job_name = "mango-job" if self.job_name is None else self.job_name + # self.upload_targets(targets) + self.create_experiment_list("mango", targets) + # self.upload_current_results(targets, "mango") + + self.create_job( + completions=len(targets), + job_name=job_name, + env_dict={"ZIP_DEST": self.ZIP_DEST}, + ) + + self.watch_job(job_name=job_name, namespace="clasm") + self.download_new_results() + + def translate_local_to_remote_targets(self, targets: Set[ELFInfo]): + remote_targets = set() + for target in targets: + root_path = self.target.absolute().resolve() + remote_root_path = self.KUBE_MOUNT_DIR / "clasm" / self.target.name + target_path = Path(target.path).absolute().resolve() + + remote_path = str(target_path).replace( + str(root_path), str(remote_root_path) + ) + remote_ld_paths = [ + x.replace(str(root_path), str(remote_root_path)) + for x in target.get_lib_locations() + ] + remote_target = ELFInfo( + path=remote_path, + brand=target.brand, + firmware=target.firmware, + sha=target.sha, + ld_paths=remote_ld_paths, + ) + + remote_targets.add(remote_target) + return remote_targets + + def upload_targets(self, targets: Set[ELFInfo]): + """ + Zips targets and extracts it at the remote dir, preserving path + :param local_dir_loc: + :param remote_dir_loc: + :param targets: + :return: + """ + temp_loc = Path("/tmp") / f"{self.target.name}.zip" + + with MyProgress() as progress: + zip_task = progress.add_task("Zipping...", 
total=len(targets)) + with ZipFile(temp_loc, "w") as zip_obj: + resolved_path = self.target.absolute().resolve() + for target in targets: + p = Path( + target.path.replace(str(resolved_path), resolved_path.name) + ) + zip_obj.write(target.path, str(p)) + progress.update(zip_task, advance=1) + + self.console_subprocess(["scp", temp_loc, f"{self.SSH_SERVER}:{temp_loc}"]) + self.console_subprocess( + ["ssh", self.SSH_SERVER, "unzip", "-o", temp_loc, "-d", self.REMOTE_DIR] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", temp_loc]) + os.remove(temp_loc) + + def remote_copy(self, local_src: Path, remote_dst: Path): + self.console_subprocess( + ["scp", "-r", local_src, f"{self.SSH_SERVER}:{remote_dst}"] + ) + + def remote_download(self, local_dst: Path, remote_src: Path): + self.console_subprocess( + ["scp", "-r", f"{self.SSH_SERVER}:{remote_src}", str(local_dst)] + ) + + def create_experiment_list(self, script, remote_targets): + experiment_json = dict() + experiment_json["script"] = script + experiment_json["timeout"] = self.timeout + experiment_json["rda_timeout"] = self.rda_timeout + experiment_json["category"] = self.category + experiment_json["targets"] = { + idx: target.to_dict() if isinstance(target, ELFInfo) else target + for idx, target in enumerate(remote_targets) + } + experiment_json["result_dest"] = str(self.KUBE_RESULT_DIR) + experiment_json["target_dir"] = str(self.target).replace( + "/tank/kubernetes", "/shared" + ) + + self.remote_subprocess(["mkdir", "-p", self.REMOTE_RESULT_DIR]) + experiment_file = Path(f"/tmp/{self.category}_experiment_list.json") + with experiment_file.open("w+") as f: + json.dump(experiment_json, f, indent=4) + + self.remote_copy(experiment_file, self.REMOTE_RESULT_DIR / experiment_file.name) + + def upload_current_results(self, targets: Set[ELFInfo], script: str): + + remote_targets = self.translate_local_to_remote_targets(targets) + + experiment_json = dict() + experiment_json["script"] = script + experiment_json["timeout"] = self.timeout + experiment_json["category"] = self.category + experiment_json["targets"] = { + idx: target.to_dict() for idx, target in enumerate(remote_targets) + } + experiment_json["result_dest"] = str(self.KUBE_RESULT_DIR) + experiment_json["target_dir"] = "/tmp" + + self.results_dir.mkdir(exist_ok=True, parents=True) + with (self.results_dir / f"{self.category}_experiment_list.json").open( + "w+" + ) as f: + json.dump(experiment_json, f, indent=4) + + results_zip = Path("/tmp/results.zip") + self.console_subprocess( + ["zip", "-r", results_zip, self.results_dir.name], + cwd=self.results_dir.parent, + ) + self.console_subprocess( + ["scp", results_zip, f"{self.SSH_SERVER}:{results_zip}"] + ) + self.console_subprocess( + ["ssh", self.SSH_SERVER, "unzip", "-o", results_zip, "-d", self.REMOTE_DIR] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", results_zip]) + os.remove(results_zip) + + def download_new_results(self): + results_zip = Path("/tmp/results.zip") + self.console_subprocess( + [ + "ssh", + self.SSH_SERVER, + "cd", + self.REMOTE_DIR, + "&&", + "zip", + "-r", + results_zip, + self.REMOTE_RESULT_DIR.name, + ] + ) + self.console_subprocess( + ["scp", f"{self.SSH_SERVER}:{results_zip}", results_zip] + ) + self.console_subprocess( + ["unzip", "-o", results_zip, "-d", self.results_dir.parent] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", results_zip]) + try: + os.remove(results_zip) + except FileNotFoundError: + Console().print(f"[red]{results_zip} Not Found") + + def run_env_resolve(self, 
targets: Set[ELFInfo]): + + job_name = "env-resolve-job" if self.job_name is None else self.job_name + # self.upload_targets(targets) + self.create_experiment_list("env_resolve", targets) + # self.upload_current_results(targets, "env_resolve") + + self.create_job( + completions=len(targets), + job_name=job_name, + env_dict={"ZIP_DEST": self.ZIP_DEST}, + ) + + self.watch_job(job_name=job_name, namespace="clasm") + self.download_new_results() + + def parse_mango_result(self, res_file, target): + pass + + def create_job( + self, + completions: int, + job_name="mango-job", + cpu_min="1000m", + cpu_max="1000m", + mem_min="5Gi", + env_dict=None, + ): + # Container should grab the indexed job + share_name = "nfs-shared" + claim_name = "nfs" + + self.create_registry_secret() + + resources = V1ResourceRequirements( + requests={"memory": mem_min, "cpu": cpu_min}, limits={"cpu": cpu_max} + ) + + container_volume_mounts = [ + V1VolumeMount(name=share_name, mount_path=str(self.KUBE_MOUNT_DIR)) + ] + + env = [V1EnvVar(name="KUBE", value="True")] + if self.extra_args: + env += [ + V1EnvVar( + name="EXTRA_ARGS", + value=json.dumps(["--" + x for x in self.extra_args]), + ) + ] + if self.py_spy: + env += [V1EnvVar(name="PYSPY", value="")] + + if env_dict: + for k, v in env_dict.items(): + env += [V1EnvVar(name=str(k), value=str(v))] + + container = V1Container( + name="mango-worker", + image=self.config["remote"]["registry_image"], + env=env, + command=[ + "/entrypoint.py", + str(self.KUBE_RESULT_DIR / f"{self.category}_experiment_list.json"), + ], + resources=resources, + volume_mounts=container_volume_mounts, + ) + + # Create and configure a spec section + nfs_volume = V1Volume( + name=share_name, + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + claim_name=claim_name + ), + ) + + template = V1PodTemplateSpec( + metadata=V1ObjectMeta(labels={"app": "operation-mango"}), + spec=V1PodSpec( + restart_policy="Never", + containers=[container], + volumes=[nfs_volume], + image_pull_secrets=[ + V1LocalObjectReference( + name=self.config["remote"]["registry_secret_name"] + ) + ], + ), + ) + + # Create the specification of deployment + spec = V1JobSpec( + template=template, + completions=completions, + completion_mode="Indexed", + backoff_limit=completions, + parallelism=self.parallel, + ) + + # Instantiate the job object + job = V1Job( + api_version="batch/v1", + kind="Job", + metadata=V1ObjectMeta(name=job_name), + spec=spec, + ) + + batch_api = kube_client.BatchV1Api() + while True: + try: + batch_api.create_namespaced_job(body=job, namespace="clasm") + break + except kube_client.exceptions.ApiException as e: + if e.reason == "Conflict": + Console().print(f"Attempting to delete {job_name}") + self.delete_job(batch_api, job_name) + time.sleep(30) + batch_api.create_namespaced_job(body=job, namespace="clasm") + + def create_registry_secret(self): + auth = base64.b64encode( + f"{self.config['remote']['auth_config']['username']}:{self.config['remote']['auth_config']['password']}".encode( + "utf-8" + ) + ).decode("utf-8") + + docker_config_dict = { + "auths": { + self.config["remote"]["registry"]: { + "username": self.config["remote"]["auth_config"]["username"], + "password": self.config["remote"]["auth_config"]["password"], + "auth": auth, + } + } + } + + docker_config = base64.b64encode( + json.dumps(docker_config_dict).encode("utf-8") + ).decode("utf-8") + + try: + self.k8_client.create_namespaced_secret( + namespace="clasm", + body=kube_client.V1Secret( + metadata=kube_client.V1ObjectMeta( + 
name=self.config["remote"]["registry_secret_name"], + ), + type="kubernetes.io/dockerconfigjson", + data={".dockerconfigjson": docker_config}, + ), + ) + except kube_client.exceptions.ApiException as e: + if e.reason != "Conflict": + raise e + + @staticmethod + def delete_job(api_instance, job_name): + api_response = api_instance.delete_namespaced_job( + name=job_name, + namespace="clasm", + body=V1DeleteOptions( + propagation_policy="Foreground", grace_period_seconds=5 + ), + ) + Console().print(f"[bold red]Job deleted. status='{api_response.status}'") + + def gen_job_table(self): + # Display the job status in a table + table = Table( + title=f"{self.job_status['name']} Status - {self.job_status['status']}", + min_width=100, + ) + table.add_column("Status", style="cyan") + table.add_column("Value", style="magenta") + table.add_row("Active", str(self.job_status["active"] or 0)) + table.add_row("Succeeded", str(self.job_status["succeeded"] or 0)) + table.add_row("Failed", str(self.job_status["failed"] or 0)) + + return table + + def get_time(self): + return self._get_time() + + def watch_job(self, job_name, namespace): + # Load the kubeconfig file + config.load_kube_config() + + # Set up the Kubernetes API client + api = kube_client.BatchV1Api() + + # Create a watch object + job_watch = watch.Watch() + + # Watch for events related to the specified job + self.job_status["name"] = job_name + with MyProgress( + renderable_callback=self.gen_job_table, get_time=time.time + ) as progress: + progress_bar = None + for event in job_watch.stream( + api.list_namespaced_job, namespace=namespace, timeout_seconds=None + ): + job = event["object"] + + if job.metadata.name == job_name: + job_status = job.status + self.job_status["status"] = event["type"] + self.job_status["active"] = job_status.active + self.job_status["succeeded"] = job_status.succeeded + self.job_status["failed"] = job_status.failed + + completed = job_status.succeeded or 0 + total = job.spec.completions or 1 + if progress_bar is None: + self._get_time = ( + lambda: time.time() - job_status.start_time.timestamp() + ) + progress_bar = progress.add_task( + description=f"[cyan] Watching job..." + ) + try: + progress._tasks[ + progress_bar + ].start_time = job_status.start_time.timestamp() + if total == completed: + start_time = job_status.start_time.timestamp() + end_time = job_status.completion_time.timestamp() + progress._tasks[ + progress_bar + ].start_time = time.time() - (end_time - start_time) + except AttributeError: + pass + progress.start_task(progress_bar) + + progress.update(progress_bar, completed=completed, total=total) + + # Stop watching the job when it's complete + if total == completed: + job_watch.stop() + print(f"Job {job_name} completed.") + break + + time.sleep(10) diff --git a/pipeline/mango_pipeline/local.py b/pipeline/mango_pipeline/local.py new file mode 100644 index 0000000..68fae8c --- /dev/null +++ b/pipeline/mango_pipeline/local.py @@ -0,0 +1,259 @@ +import json +from typing import Set, Tuple, Dict + +import docker +from docker.errors import ContainerError, NotFound + +from rich.console import Console +from rich.table import Table + +from multiprocessing import Pool +from pathlib import Path + +from . 
import PROJECT_DIR +from .base import Pipeline, ELFInfo, MyProgress +from .firmware.elf_finder import FirmwareFinder +from .scripts import data_printer + + +class PipelineLocal(Pipeline): + """ + Pipeline Process for running experiments in local docker containers + """ + + docker_file = PROJECT_DIR / "./docker/mango_user/Dockerfile" + + def __init__(self, *args, quiet=True, **kwargs): + + super().__init__(*args, **kwargs) + self.quiet = quiet + + self.results_dir.mkdir(parents=True, exist_ok=True) + if self.target and self.target.is_dir(): + finder = FirmwareFinder( + self.target, + self.results_dir, + bin_prep=self.bin_prep, + exclude_libs=self.exclude_libs, + ) + self.vendor_dict = finder.vendor_dict + else: + self.vendor_dict = {} + + def run_mango(self, targets: Set[ELFInfo]): + with MyProgress( + renderable_callback=self.mango_table_wrapper, transient=True + ) as progress: + mango_task = progress.add_task( + description=f"Mango Analysis", total=len(targets) + ) + + with Pool(self.parallel) as p: + for idx, results in enumerate( + p.imap_unordered(self.mango_wrapper, targets) + ): + data_printer.parse_mango_result(self.total_mango_results, *results) + progress.update(mango_task, advance=1) + + Console().print(self.mango_table_wrapper()) + Console().print("[bold green]MANGO ANALYSIS COMPLETE") + + def mango_wrapper(self, target: ELFInfo) -> Tuple[Path, ELFInfo]: + return self.run_analysis_container(target, "mango") + + def mango_table_wrapper(self): + return data_printer.generate_mango_table( + self.total_mango_results, show_dups=self.show_dups + ) + + def env_table_wrapper(self): + return data_printer.generate_env_table( + self.total_env_results, show_dups=self.show_dups + ) + + def run_env_resolve(self, targets: Set[ELFInfo]): + with MyProgress( + renderable_callback=self.env_table_wrapper, transient=True + ) as progress: + env_task = progress.add_task( + description=f"ENV Analysis", total=len(targets) + ) + with Pool(self.parallel) as p: + for idx, results in enumerate( + p.imap_unordered(self.env_wrapper, targets) + ): + data_printer.parse_env_result(self.total_env_results, *results) + progress.update(env_task, advance=1) + + Console().print(self.env_table_wrapper()) + self.env_merge() + Console().print("[bold green]ENV RESOLVE COMPLETE") + + def env_wrapper(self, target: ELFInfo) -> Tuple[Path, ELFInfo]: + return self.run_analysis_container(target, "env_resolve") + + def run_analysis_container(self, *args) -> Tuple[Path, ELFInfo]: + target, script = args + docker_client = docker.from_env() + docker_bin_path = Path("/tmp") / target.path + docker_res = Path("/tmp/results") / target.sha + + volumes = { + Path(target.path).absolute(): {"bind": str(docker_bin_path), "mode": "ro"} + } + local_res_dir = self.result_dir_from_target(target) + local_res_dir.mkdir(parents=True, exist_ok=True) + + if script == "mango": + results_file = local_res_dir / f"{self.category}_results.json" + if (local_res_dir.parent / "env.json").exists(): + local_env = local_res_dir.parent / "env.json" + env_dict = Path("/tmp/results/env.json") + volumes[str(local_env.absolute())] = { + "bind": str(env_dict), + "mode": "ro", + } + else: + results_file = local_res_dir / "env.json" + + if results_file.exists(): + return results_file, target + + volumes[str(local_res_dir.absolute())] = {"bind": str(docker_res), "mode": "rw"} + environment = { + "SCRIPT": script, + "TIMEOUT": self.timeout, + "RDA_TIMEOUT": self.rda_timeout, + "CATEGORY": json.dumps(self.category), + "RESULT_DEST": str(docker_res), + "TARGET_PATH": 
str(docker_bin_path),
+            "TARGET_SHA": target.sha,
+            "TARGET_BRAND": target.brand,
+            "TARGET_FIRMWARE": target.firmware,
+            "LD_PATHS": json.dumps(target.ld_paths),
+            "EXTRA_ARGS": json.dumps(["--" + x for x in self.extra_args]),
+        }
+
+        try:
+            output = docker_client.containers.run(
+                self.container_name,
+                name=f"{script}_{Path(target.path).name}_{target.sha[:6]}_{self.category}_{Path(self.results_dir).name}".replace(
+                    "+", "plus"
+                ),
+                command=f"/entrypoint.py {script}",
+                volumes=volumes,
+                environment=environment,
+                stdout=True,
+                stderr=True,
+                auto_remove=True,
+            )
+        except ContainerError as e:
+            try:
+                self.save_error_result(
+                    results_file, target, script, e.container.logs().decode()
+                )
+            except NotFound as e:
+                self.save_error_result(results_file, target, script, str(e))
+        except Exception as e:
+            self.save_error_result(results_file, target, script, str(e))
+
+        if not results_file.exists():
+            self.save_error_result(results_file, target, script, "EARLY TERMINATION")
+
+        return results_file, target
+
+    def print_status(self):
+        targets, duplicates = self.get_experiment_targets()
+        symbols = self.get_target_symbols(targets)
+        self.filter_env_targets(targets, symbols)
+        self.filter_mango_targets(targets, symbols)
+
+        env_table = data_printer.generate_env_table(
+            self.total_env_results, show_dups=self.show_dups
+        )
+        mango_table = data_printer.generate_mango_table(
+            self.total_mango_results, show_dups=self.show_dups
+        )
+
+        with Console() as console:
+            console.print(env_table)
+            console.print(mango_table)
+
+    def print_errors(self):
+        targets, _ = self.get_experiment_targets()
+        error_table = {}
+
+        paths = []
+        for target in targets:
+            result_path = (
+                self.results_dir
+                / target.brand
+                / target.firmware
+                / target.sha
+                / f"{self.category}_results.json"
+            )
+            res_data = (
+                json.loads(result_path.read_text()) if result_path.exists() else {}
+            )
+            # Only true errors belong in the table: skip clean runs (0),
+            # OOM kills (-9), and timeouts (124).
+            if (
+                "ret_code" in res_data
+                and res_data["ret_code"] != 0
+                and res_data["ret_code"] != -9
+                and res_data["ret_code"] != 124
+            ):
+                mango_file = (
+                    self.results_dir
+                    / target.brand
+                    / target.firmware
+                    / target.sha
+                    / f"{self.category}_mango.out"
+                )
+                if not mango_file.exists():
+                    continue
+                for line in reversed(mango_file.read_text().split("\n")):
+                    line = line.strip()
+                    if line:
+                        if "Finished Running Analysis" in line:
+                            break
+                        if "angr.errors.SimMemoryMissingError" in line:
+                            line = "angr.errors.SimMemoryMissingError"
+                        elif any(
+                            line.startswith(x)
+                            for x in [
+                                "INFO |",
+                                "ERROR |",
+                                "WARNING |",
+                            ]
+                        ):
+                            paths.append(
+                                (
+                                    str(result_path),
+                                    len(mango_file.read_text().split("\n")),
+                                )
+                            )
+                            line = "UNKNOWN"
+                        if line not in error_table:
+                            error_table[line] = []
+                        error_table[line].append(result_path)
+                        break
+        print(
+            "\n".join(
+                f"{x[0]} {x[1]}"
+                for x in sorted(paths, reverse=True, key=lambda x: x[1])
+            )
+        )
+        console = Console()
+
+        table = Table(title="Mango Errors")
+        table.add_column("Error")
+        table.add_column("Amount")
+
+        worst = None
+        for error, count in sorted(error_table.items(), key=lambda x: len(x[1])):
+            if not worst:
+                worst = count
+            table.add_row(error, str(len(count)))
+
+        console.print(table)
+        console.print(worst)
diff --git a/pipeline/mango_pipeline/remote.py b/pipeline/mango_pipeline/remote.py
new file mode 100644
index 0000000..208509a
--- /dev/null
+++ b/pipeline/mango_pipeline/remote.py
@@ -0,0 +1,520 @@
+import base64
+import json
+import os
+import time
+import toml
+
+from typing import Set
+from pathlib import Path
+from zipfile import ZipFile
+
+import docker
+
+from
rich.console import Console +from rich.table import Table +from rich.progress import ( + Progress, + BarColumn, + TextColumn, + TaskProgressColumn, + TransferSpeedColumn, + TimeRemainingColumn, +) + +from kubernetes import config, watch +from kubernetes import client as kube_client + +from kubernetes.client import ( + V1ResourceRequirements, + V1PersistentVolumeClaimVolumeSource, + V1VolumeMount, + V1Volume, + V1Container, + V1EnvVar, + V1PodTemplateSpec, + V1PodSpec, + V1ObjectMeta, + V1Job, + V1JobSpec, + V1LocalObjectReference, + V1DeleteOptions, +) + +from .base import Pipeline, MyProgress, ELFInfo +from .firmware.elf_finder import FirmwareFinder + + +class PipelineRemote(Pipeline): + """ + Pipeline process for running experiments on a remote kubernetes setup + (Tailored specifically for @Clasm's setup: buyer beware) + """ + + REMOTE_DIR = Path("/tank/kubernetes/clasm") + REMOTE_RESULT_DIR = REMOTE_DIR / "mango-results" + KUBE_MOUNT_DIR = Path("/shared") + KUBE_RESULT_DIR = KUBE_MOUNT_DIR / "clasm" / "mango-results" + SSH_SERVER = "nfs_server" + + def __init__(self, *args, quiet=True, parallel=500, watch=None, **kwargs): + super().__init__(*args, parallel=parallel, **kwargs) + + config_path = Path(__file__).parent.absolute() / "configs" / "pipeline.toml" + if not config_path.exists(): + self.create_default_config(config_path) + + self.config = toml.loads(config_path.read_text()) + + self.REMOTE_RESULT_DIR = self.REMOTE_DIR / self.results_dir.name + self.KUBE_RESULT_DIR = self.KUBE_MOUNT_DIR / "clasm" / self.results_dir.name + config.load_kube_config() + self.api_client = kube_client.ApiClient() + self.k8_client = kube_client.CoreV1Api(self.api_client) + self._get_time = time.time + self.job_status = { + "name": "Unknown", + "status": "Unknown", + "active": None, + "failed": None, + "succeeded": None, + } + if self.target: + self.remote_target = self.REMOTE_DIR / self.target.name + + if self.target and self.target.is_dir(): + finder = FirmwareFinder( + self.target, + self.results_dir, + bin_prep=self.bin_prep, + exclude_libs=self.exclude_libs, + ) + self.vendor_dict = finder.vendor_dict + else: + self.vendor_dict = {} + + if watch is not None: + self.watch_job(job_name=watch, namespace="clasm") + + def build_container(self): + super().build_container() + self.push_container() + + def push_container(self): + """ + Pushes built container to repository + """ + + client = docker.from_env() + image = client.images.get(self.container_name) + success = image.tag(self.config["remote"]["registry_image"]) + if not success: + with Console() as console: + console.print("[red bold]Failed to tag container") + return + + push_info = {} + resp = client.images.push( + self.config["remote"]["registry_image"], + auth_config=self.config["remote"]["auth_config"], + stream=True, + decode=True, + ) + + with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TransferSpeedColumn(), + TimeRemainingColumn(), + transient=True, + ) as progress: + for line in resp: + if "status" in line and line["status"] == "Pushing": + item_id = line["id"] + + info = line["progressDetail"] + current = info["current"] + total = info["total"] if "total" in info else None + if item_id in push_info: + progress.update( + push_info[item_id], completed=current, total=total + ) + else: + push_info[item_id] = progress.add_task( + description=f"Pushing {item_id}", + completed=current, + total=total, + ) + else: + if "errorDetail" in line: + progress.print(line) + Console().print( + 
f"[green]Pushed Container to {self.config['remote']['registry_image']}!" + ) + + def run_mango(self, targets: Set[ELFInfo]): + + # self.env_merge() + job_name = "mango-job" if self.job_name is None else self.job_name + self.upload_targets(targets) + self.upload_current_results(targets, "mango") + + self.create_job(completions=len(targets), job_name=job_name) + + self.watch_job(job_name=job_name, namespace="clasm") + self.download_new_results() + + def translate_local_to_remote_targets(self, targets: Set[ELFInfo]): + remote_targets = set() + for target in targets: + root_path = self.target.absolute().resolve() + remote_root_path = self.KUBE_MOUNT_DIR / "clasm" / self.target.name + target_path = Path(target.path).absolute().resolve() + + remote_path = str(target_path).replace( + str(root_path), str(remote_root_path) + ) + remote_ld_paths = [ + x.replace(str(root_path), str(remote_root_path)) + for x in target.get_lib_locations() + ] + remote_target = ELFInfo( + path=remote_path, + brand=target.brand, + firmware=target.firmware, + sha=target.sha, + ld_paths=remote_ld_paths, + ) + + remote_targets.add(remote_target) + return remote_targets + + def upload_targets(self, targets: Set[ELFInfo]): + """ + Zips targets and extracts it at the remote dir, preserving path + :param local_dir_loc: + :param remote_dir_loc: + :param targets: + :return: + """ + temp_loc = Path("/tmp") / f"{self.target.name}.zip" + + with MyProgress() as progress: + zip_task = progress.add_task("Zipping...", total=len(targets)) + with ZipFile(temp_loc, "w") as zip_obj: + resolved_path = self.target.absolute().resolve() + for target in targets: + p = Path( + target.path.replace(str(resolved_path), resolved_path.name) + ) + zip_obj.write(target.path, str(p)) + progress.update(zip_task, advance=1) + + self.console_subprocess(["scp", temp_loc, f"{self.SSH_SERVER}:{temp_loc}"]) + self.console_subprocess( + ["ssh", self.SSH_SERVER, "unzip", "-o", temp_loc, "-d", self.REMOTE_DIR] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", temp_loc]) + os.remove(temp_loc) + + def upload_current_results(self, targets: Set[ELFInfo], script: str): + + remote_targets = self.translate_local_to_remote_targets(targets) + + experiment_json = dict() + experiment_json["script"] = script + experiment_json["timeout"] = self.timeout + experiment_json["rda_timeout"] = self.rda_timeout + experiment_json["category"] = self.category + experiment_json["targets"] = { + idx: target.to_dict() for idx, target in enumerate(remote_targets) + } + experiment_json["result_dest"] = str(self.KUBE_RESULT_DIR) + experiment_json["target_dir"] = "/tmp" + + self.results_dir.mkdir(exist_ok=True, parents=True) + with (self.results_dir / f"{self.category}_experiment_list.json").open( + "w+" + ) as f: + json.dump(experiment_json, f, indent=4) + + results_zip = Path("/tmp/results.zip") + self.console_subprocess( + ["zip", "-r", results_zip, self.results_dir.name], + cwd=self.results_dir.parent, + ) + self.console_subprocess( + ["scp", results_zip, f"{self.SSH_SERVER}:{results_zip}"] + ) + self.console_subprocess( + ["ssh", self.SSH_SERVER, "unzip", "-o", results_zip, "-d", self.REMOTE_DIR] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", results_zip]) + os.remove(results_zip) + + def download_new_results(self): + results_zip = Path("/tmp/results.zip") + self.console_subprocess( + [ + "ssh", + self.SSH_SERVER, + "cd", + self.REMOTE_DIR, + "&&", + "zip", + "-r", + results_zip, + self.REMOTE_RESULT_DIR.name, + ] + ) + self.console_subprocess( + ["scp", 
f"{self.SSH_SERVER}:{results_zip}", results_zip] + ) + self.console_subprocess( + ["unzip", "-o", results_zip, "-d", self.results_dir.parent] + ) + self.console_subprocess(["ssh", self.SSH_SERVER, "rm", results_zip]) + os.remove(results_zip) + + def run_env_resolve(self, targets: Set[ELFInfo]): + + job_name = "env-resolve-job" if self.job_name is None else self.job_name + self.upload_targets(targets) + self.upload_current_results(targets, "env_resolve") + + self.create_job(completions=len(targets), job_name=job_name) + + self.watch_job(job_name=job_name, namespace="clasm") + self.download_new_results() + + def parse_mango_result(self, res_file, target): + pass + + def create_job( + self, + completions: int, + job_name="mango-job", + cpu_min="1000m", + cpu_max="1000m", + mem_min="5Gi", + ): + # Container should grab the indexed job + share_name = "nfs-shared" + claim_name = "nfs" + + self.create_registry_secret() + + resources = V1ResourceRequirements( + requests={"memory": mem_min, "cpu": cpu_min}, limits={"cpu": cpu_max} + ) + + container_volume_mounts = [ + V1VolumeMount(name=share_name, mount_path=str(self.KUBE_MOUNT_DIR)) + ] + + env = [V1EnvVar(name="KUBE", value="True")] + if self.extra_args: + env += [ + V1EnvVar( + name="EXTRA_ARGS", + value=json.dumps(["--" + x for x in self.extra_args]), + ) + ] + if self.py_spy: + env += [V1EnvVar(name="PYSPY", value="")] + + container = V1Container( + name="mango-worker", + image=self.config["remote"]["registry_image"], + env=env, + command=[ + "/entrypoint.py", + str(self.KUBE_RESULT_DIR / f"{self.category}_experiment_list.json"), + ], + resources=resources, + volume_mounts=container_volume_mounts, + ) + + # Create and configure a spec section + nfs_volume = V1Volume( + name=share_name, + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + claim_name=claim_name + ), + ) + + template = V1PodTemplateSpec( + metadata=V1ObjectMeta(labels={"app": "operation-mango"}), + spec=V1PodSpec( + containers=[container], + volumes=[nfs_volume], + restart_policy="Never", + image_pull_secrets=[ + V1LocalObjectReference( + name=self.config["remote"]["registry_secret_name"] + ) + ], + ), + ) + + # Create the specification of deployment + + spec = V1JobSpec( + template=template, + completions=completions, + completion_mode="Indexed", + backoff_limit=completions, + parallelism=self.parallel, + ) + + # Instantiate the job object + job = V1Job( + api_version="batch/v1", + kind="Job", + metadata=V1ObjectMeta(name=job_name), + spec=spec, + ) + + batch_api = kube_client.BatchV1Api() + while True: + try: + batch_api.create_namespaced_job(body=job, namespace="clasm") + break + except kube_client.exceptions.ApiException as e: + if e.reason == "Conflict": + Console().print(f"Attempting to delete {job_name}") + self.delete_job(batch_api, job_name) + time.sleep(30) + batch_api.create_namespaced_job(body=job, namespace="clasm") + + def create_registry_secret(self): + auth = base64.b64encode( + f"{self.config['remote']['auth_config']['username']}:{self.config['remote']['auth_config']['password']}".encode( + "utf-8" + ) + ).decode("utf-8") + + docker_config_dict = { + "auths": { + self.config["remote"]["registry"]: { + "username": self.config["remote"]["auth_config"]["username"], + "password": self.config["remote"]["auth_config"]["password"], + "auth": auth, + } + } + } + + docker_config = base64.b64encode( + json.dumps(docker_config_dict).encode("utf-8") + ).decode("utf-8") + + try: + self.k8_client.create_namespaced_secret( + namespace="clasm", + 
body=kube_client.V1Secret( + metadata=kube_client.V1ObjectMeta( + name=self.config["remote"]["registry_secret_name"], + ), + type="kubernetes.io/dockerconfigjson", + data={".dockerconfigjson": docker_config}, + ), + ) + except kube_client.exceptions.ApiException as e: + if e.reason != "Conflict": + raise e + + @staticmethod + def delete_job(api_instance, job_name): + api_response = api_instance.delete_namespaced_job( + name=job_name, + namespace="clasm", + body=V1DeleteOptions( + propagation_policy="Foreground", grace_period_seconds=5 + ), + ) + Console().print(f"[bold red]Job deleted. status='{api_response.status}'") + + def gen_job_table(self): + # Display the job status in a table + table = Table( + title=f"{self.job_status['name']} Status - {self.job_status['status']}", + min_width=100, + ) + table.add_column("Status", style="cyan") + table.add_column("Value", style="magenta") + table.add_row("Active", str(self.job_status["active"] or 0)) + table.add_row("Succeeded", str(self.job_status["succeeded"] or 0)) + table.add_row("Failed", str(self.job_status["failed"] or 0)) + + return table + + def get_time(self): + return self._get_time() + + def print_status(self): + self.watch_job(self.job_name, "clasm") + + def watch_job(self, job_name, namespace): + # Load the kubeconfig file + config.load_kube_config() + + # Set up the Kubernetes API client + api = kube_client.BatchV1Api() + + # Create a watch object + job_watch = watch.Watch() + + # Watch for events related to the specified job + self.job_status["name"] = job_name + with MyProgress( + renderable_callback=self.gen_job_table, get_time=time.time + ) as progress: + progress_bar = None + done = False + while not done: + for event in job_watch.stream( + api.list_namespaced_job, namespace=namespace, timeout_seconds=None + ): + job = event["object"] + + if job.metadata.name == job_name: + job_status = job.status + self.job_status["status"] = event["type"] + self.job_status["active"] = job_status.active + self.job_status["succeeded"] = job_status.succeeded + self.job_status["failed"] = job_status.failed + + completed = job_status.succeeded or 0 + total = job.spec.completions + + if progress_bar is None: + self._get_time = ( + lambda: time.time() - job_status.start_time.timestamp() + ) + progress_bar = progress.add_task( + description=f"[cyan] Running {job_name}..." 
+ ) + try: + progress._tasks[ + progress_bar + ].start_time = job_status.start_time.timestamp() + if total == completed: + start_time = job_status.start_time.timestamp() + end_time = job_status.completion_time.timestamp() + progress._tasks[ + progress_bar + ].start_time = time.time() - (end_time - start_time) + except AttributeError: + pass + progress.start_task(progress_bar) + + progress.update(progress_bar, completed=completed, total=total) + + # Stop watching the job when it's complete + if total == completed: + job_watch.stop() + progress.print(f"Job {job_name} completed.") + done = True + break + + time.sleep(5) diff --git a/pipeline/mango_pipeline/run.py b/pipeline/mango_pipeline/run.py new file mode 100755 index 0000000..176b6a9 --- /dev/null +++ b/pipeline/mango_pipeline/run.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python +import argparse +import sys + +from pathlib import Path + +from mango_pipeline import PipelineRemote, PipelineLocal, PipelineKube +from argument_resolver.external_function.sink import VULN_TYPES + + +def cli_args(): + parser = argparse.ArgumentParser() + path_group = parser.add_argument_group( + "Path Args", "Deciding source and result destination" + ) + run_group = parser.add_argument_group( + "Running", "Options that modify how mango runs" + ) + output_group = parser.add_argument_group( + "Output", "Options to increase or modify output" + ) + path_group.add_argument( + "--path", default=None, help="Binary or Directory of binaries to analyze" + ) + path_group.add_argument( + "--results", + dest="result_folder", + default="./results", + help="Where to store the results of the analysis. (Default: ./results)", + ) + path_group.add_argument( + "--download-results", + dest="download", + action="store_true", + default=False, + help="Download the latest results from remote server", + ) + output_group.add_argument( + "--status", + dest="status", + default=False, + action="store_true", + help="Display current status of either results folder or kube", + ) + + output_group.add_argument( + "--show-dups", + dest="show_dups", + default=False, + action="store_true", + help="Include duplicates in status (Only applies to --status flag)", + ) + + output_group.add_argument( + "--verbose", + dest="verbose", + action="store_true", + default=False, + help="Print STDOUT as operation runs (single target only)", + ) + output_group.add_argument( + "--build-docker", + dest="build_docker", + action="store_true", + default=False, + help="Build Docker Container (Requires internet access if kube)", + ) + output_group.add_argument( + "--gen-csv", + dest="csv", + action="store_true", + default=False, + help="Explicitly generate CSV with current results", + ) + + output_group.add_argument( + "--show-errors", + dest="show_errors", + action="store_true", + default=False, + help="Show table of errors for local results", + ) + output_group.add_argument( + "--aggregate-results", + dest="agg_results", + default=None, + help="Aggregate all results into a single folder [Requires: --results]", + ) + + output_group.add_argument( + "--py-spy", + dest="py_spy", + default=False, + action="store_true", + help="Enable PySpy data logging", + ) + + run_group.add_argument( + "--category", + dest="sink_category", + default="cmdi", + choices=VULN_TYPES.keys(), + help="Analyze sinks from category", + ) + + run_group.add_argument( + "--kube", + dest="kube", + action="store_true", + default=False, + help="Run experiment on the cluster", + ) + run_group.add_argument( + "--env", + dest="is_env", + action="store_true", + 
default=False,
+        help="Run env resolver",
+    )
+    run_group.add_argument(
+        "--mango",
+        dest="is_mango",
+        action="store_true",
+        default=False,
+        help="Run mango",
+    )
+    run_group.add_argument(
+        "--full",
+        dest="is_full",
+        action="store_true",
+        default=False,
+        help="Run full pipeline (Not recommended for remote workloads)",
+    )
+    run_group.add_argument(
+        "--parallel",
+        dest="parallel",
+        default=1,
+        type=int,
+        help="Amount of containers/pods to run in parallel (Default: 1)",
+    )
+    run_group.add_argument(
+        "--brand",
+        dest="brand",
+        default="",
+        type=str,
+        help="Select specific brand to run experiments on",
+    )
+    run_group.add_argument(
+        "--firmware",
+        dest="firmware",
+        default="",
+        type=str,
+        help="Select specific firmware to run experiments on (Brand Must Be Set)",
+    )
+
+    run_group.add_argument(
+        "--extra-args",
+        dest="extra_args",
+        default=[],
+        nargs="+",
+        help="Extra args to run analysis with",
+    )
+
+    run_group.add_argument(
+        "--job-name",
+        dest="job_name",
+        default="mango-job",
+        type=str,
+        help="Job Name used for kubernetes",
+    )
+
+    run_group.add_argument(
+        "--timeout",
+        dest="timeout",
+        default=3 * 60 * 60,
+        type=int,
+        help="Timeout for each container/pod (Default: 3hrs)",
+    )
+    run_group.add_argument(
+        "--rda-timeout",
+        dest="rda_timeout",
+        default=5 * 60,
+        type=int,
+        help="Timeout for each sub analysis in a job (Default: 5min)",
+    )
+
+    run_group.add_argument(
+        "--bin-prep",
+        dest="bin_prep",
+        default=False,
+        action="store_true",
+        help="Find binaries and symbols in firmware (Happens automatically with other options)",
+    )
+    run_group.add_argument(
+        "--giga-kube",
+        dest="giga_kube",
+        default=False,
+        action="store_true",
+        help="Reserved for only the largest datasets",
+    )
+
+    run_group.add_argument(
+        "--include-libs",
+        dest="exclude_libs",
+        default=True,
+        action="store_false",
+        help="Include libraries in the analysis",
+    )
+
+    return parser
+
+
+def main():
+    parser = cli_args()
+    args = parser.parse_args()
+
+    if len(sys.argv) == 1:
+        parser.print_help(sys.stderr)
+        exit(-2)
+
+    path = Path(args.path) if args.path else None
+    results_dir = Path(args.result_folder)
+
+    if args.kube:
+        pipeline_class = PipelineRemote
+    elif args.giga_kube:
+        pipeline_class = PipelineKube
+        args.kube = True
+    else:
+        pipeline_class = PipelineLocal
+
+    pipeline = pipeline_class(
+        path,
+        results_dir,
+        parallel=args.parallel,
+        quiet=not args.verbose,
+        is_env=args.is_env or args.is_full,
+        is_mango=args.is_mango or args.is_full,
+        category=args.sink_category,
+        brand=args.brand,
+        firmware=args.firmware,
+        extra_args=args.extra_args,
+        job_name=args.job_name,
+        py_spy=args.py_spy,
+        timeout=args.timeout,
+        rda_timeout=args.rda_timeout,
+        bin_prep=args.bin_prep,
+        exclude_libs=args.exclude_libs,
+        show_dups=args.show_dups,
+    )
+
+    if args.build_docker:
+        pipeline.build_container()
+
+    if args.kube and args.status and args.job_name:
+        pipeline.watch_job(args.job_name, "clasm")
+        return
+
+    if results_dir.exists():
+        if args.status:
+            pipeline.print_status()
+            return
+        elif args.show_errors:
+            if not args.kube:
+                pipeline.print_errors()
+
+    if args.download and args.kube:
+        pipeline.download_new_results()
+
+    if args.agg_results is not None and results_dir.exists():
+        # --category stores its value under dest="sink_category"
+        pipeline.prep_results(results_dir, Path(args.agg_results), args.sink_category)
+
+    if args.csv or args.agg_results is not None:
+        if args.agg_results is not None:
+            pipeline.results_dir = Path(args.agg_results)
+        pipeline.mango_results_to_csv()
+
+    if path is None:
+        exit(-1)
+
+    pipeline.run_experiment()
+
+
+if
__name__ == "__main__": + main() diff --git a/pipeline/mango_pipeline/scripts/__init__.py b/pipeline/mango_pipeline/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pipeline/mango_pipeline/scripts/ablation.py b/pipeline/mango_pipeline/scripts/ablation.py new file mode 100644 index 0000000..4c6dd1a --- /dev/null +++ b/pipeline/mango_pipeline/scripts/ablation.py @@ -0,0 +1,213 @@ +import argparse +import subprocess +import json +import statistics +import subprocess +from collections import Counter +from dataclasses import dataclass +from typing import Dict, Tuple +from functools import reduce + +from rich.console import Console +from rich.table import Table, Column + +from pathlib import Path + + +@dataclass +class AblationInfo: + path: Path + assumed_execution: bool + reverse_trace: bool + timeouts: int + errors: int + average_time: float + total_time: float + alerts: int = 0 + oom: int = 0 + analysis_times: Dict[str, Tuple[int, int, int]] = None + closures: dict = None + trupocs: int = 0 + desc: str = "" + invalid_shas: list = None + + def sort_score(self): + if self.assumed_execution and self.reverse_trace: + score = 0 + elif self.assumed_execution: + score = 1 + elif self.reverse_trace: + score = 2 + else: + score = 3 + return score + + def get_files(self, firmware: Path, filename: str): + all_files = subprocess.check_output(["find", firmware, "-type", "f", "-name", filename]).strip().decode().split("\n") + return {Path(x) for x in all_files} + + def get_time_data(self, f: Path): + data = json.loads(f.read_text()) + if data["error"] is not None and data["ret_code"] == 124: + if "sinks" in data: + return sum(data["sinks"].values()) + return 40*60 + + if "cfg_time" not in data or data["cfg_time"] is None: + return 0 + + if "mango_time" in data: + if isinstance(data["mango_time"], list): + t = sum(data["mango_time"]) + else: + t = data["mango_time"] + if t == 0 and data["has_sinks"]: + t = sum(data["sinks"]) + return sum([data["cfg_time"], data["vra_time"], t]) + else: + return sum([data["cfg_time"], data["vra_time"], data["analysis_time"]]) + + def get_run_time(self): + files = self.get_files(self.path, "cmdi_results.json") + result_files = {f.parent.name: self.get_time_data(f) for f in files} + result_files = {k: v for k, v in result_files.items() if v != 0} + + for f in files: + data = json.loads(f.read_text()) + if data["error"] is not None: + if data["ret_code"] == 124: + self.timeouts += 1 + elif data["ret_code"] == -9: + self.oom += 1 + else: + self.errors += 1 + + self.files = files + self.analysis_times = result_files + self.closures = {f.parent.name: json.loads(f.read_text())["closures"] for f in files} + #self.timeouts = errors["timeout"] + #self.errors = errors["early_termination"] + #self.oom = errors["OOMKILLED"] + self.alerts = sum(len(v) for v in self.closures.values()) + self.total_time = sum(self.analysis_times.values()) + (40*60*self.timeouts) + self.average_time = (self.total_time / len(self.analysis_times))# if len(self.analysis_times) != 0 else -1 + self.trupocs = sum(1 for closure_list in self.closures.values() for closure in closure_list if closure["rank"] >= 7) + return self + + def filter_run_time(self, valid_files, valid_alert_files): + self.total_time = sum(v for k, v in self.analysis_times.items() if k in valid_files) + self.average_time = (self.total_time / len(valid_files))# if len(valid_files) != 0 else -1 + self.alerts = sum(len(v) for k, v in self.closures.items() if k in valid_alert_files) + +def de_duplicate(files: list, 
valid_files): + default = files[0] + other = files[1:] + + for sha in valid_files: + default_paths = {tuple( + sorted({x['ins_addr'] for x in closure['trace']} | {closure['sink']['ins_addr']}, key=lambda x: int(x, 16))) + for closure in default.closures[sha]} + if not default_paths: + continue + for a_f in other: + minimum_paths = {tuple(sorted({x['ins_addr'] for x in closure['trace']} | {closure['sink']['ins_addr']}, key=lambda x: int(x, 16))) for closure in a_f.closures[sha]} + changed = True + while changed: + for closure in minimum_paths.copy(): + count = 0 + for dc in default_paths: + if set(dc).issubset(set(closure)): + minimum_paths.discard(closure) + minimum_paths.add(dc) + count += 1 + break + else: + changed = False + + a_f.closures[sha] = minimum_paths + + return files + + + +if __name__ == '__main__': + brand = "NetGear" + firmware = "R6400v2-V1.0.4.84_10.0.58" + parser = argparse.ArgumentParser() + parser.add_argument(dest="ablation_dir", help="Location of Ablation Directory") + args = parser.parse_args() + all_res = subprocess.check_output(["find", args.ablation_dir, "-type", "f", "-iname", "cmdi_results.json"]).decode().strip().split("\n") + + ablation_files = [ + AblationInfo(path=f"{args.ablation_dir}/ablation-default", + assumed_execution=True, + reverse_trace=True, + timeouts=0, + errors=0, + average_time=-1, + total_time=-1, + desc="Default").get_run_time(), + + AblationInfo(path=f"{args.ablation_dir}/ablation-assumed", + assumed_execution=False, + reverse_trace=True, + timeouts=0, + errors=0, + average_time=-1, + total_time=-1, + desc="Assumed").get_run_time(), + + AblationInfo(path=f"{args.ablation_dir}/ablation-trace", + assumed_execution=True, + reverse_trace=False, + timeouts=0, + errors=0, + average_time=-1, + total_time=-1, + desc="Trace").get_run_time(), + + AblationInfo(path=f"{args.ablation_dir}/ablation-all", + assumed_execution=False, + reverse_trace=False, + timeouts=0, + errors=0, + average_time=-1, + total_time=-1, + desc="All").get_run_time(), + ] + + valid_files = set(reduce(lambda x, y: set(x).intersection(y), [x.analysis_times for x in ablation_files])) + valid_alert_files = set(reduce(lambda x, y: set(x).intersection(y), [[k for k, v in x.analysis_times.items() if v != 40*60] for x in ablation_files])) + + ablation_files = de_duplicate(ablation_files, valid_files) + table = Table(f"Desc [Binaries {len(valid_alert_files)}]", + Column("Assumed\nExecution", justify="center"), + Column("Reverse\nTrace", justify="center"), + Column("Average\n(seconds)", justify="center"), + Column("Total\n(minutes)" , justify="center"), + Column("Alerts", justify="center"), + Column("TruPoCs", justify="center"), + Column("Errors" , justify="center"), + Column("OOMKilled", justify="center"), + Column("Timeouts", justify="center"), + show_lines=True, + safe_box=True) + + good = "[green]:heavy_check_mark:" + bad = "[red]:x:" + + for a_f in sorted(ablation_files, key=lambda x: x.sort_score()): + a_f.filter_run_time(valid_files, valid_alert_files) + row = [a_f.desc] + row.append(good if a_f.assumed_execution else bad) + row.append(good if a_f.reverse_trace else bad) + row.append(f"{a_f.average_time:.2f}") + row.append(f"{a_f.total_time/60:.2f}") + row.append(f"[green]{a_f.alerts}") + row.append(f"[bold green]{a_f.trupocs}") + row.append(f"[red]{a_f.errors}") + row.append(f"[yellow]{a_f.oom}") + row.append(f"[blue]{a_f.timeouts}") + table.add_row(*row) + + Console().print(table) \ No newline at end of file diff --git a/pipeline/mango_pipeline/scripts/aggregate.py 
b/pipeline/mango_pipeline/scripts/aggregate.py new file mode 100644 index 0000000..470a61d --- /dev/null +++ b/pipeline/mango_pipeline/scripts/aggregate.py @@ -0,0 +1,110 @@ +import sys +import subprocess +import json +import shutil +import csv + +from pathlib import Path + + +def prep_results(directory: Path, result_path: Path, prefix=None): + results_files = subprocess.check_output( + ["find", str(directory.resolve()), "-type", "f", "-name", "results.json"]).decode().strip().split('\n') + results_files = [Path(x) for x in results_files if Path(x).is_file() and not Path(x).is_symlink()] + + for results_file in results_files: + data = json.loads(results_file.read_text()) + if data["has_sinks"] is False or data["error"] or len(data["closures"]) == 0: + continue + + out_dir = result_path / (data["sha256"] if "sha256" in data else data["sha"]) + out_dir.mkdir(parents=True, exist_ok=True) + prefix = "" if prefix is None else prefix + filename = prefix + "-" + results_file.name + shutil.copy(results_file, out_dir / filename) + try: + local_path = str(data["path"]).replace("/shared/clasm", "/home/clasm/projects/angr-squad") + output = subprocess.check_output(["file", local_path]).decode() + if "shared object" in output: + shutil.rmtree(out_dir) + continue + shutil.copy(local_path, out_dir / data["name"]) + except PermissionError: + pass + if (results_file.parent / "mango.out").exists(): + filename = prefix + "-" + "mango.out" + shutil.copy(results_file.parent / "mango.out", out_dir / filename) + + +def combine_csv(csv_paths): + out_data = {} + title = [] + vendor = None + firmware = None + sha = None + name = None + cfg_time = None + vra_time = None + analysis_time = None + found_bins = set() + found_header = False + for path in csv_paths: + with open(path / "results.csv", newline='') as csvfile: + spamreader = csv.reader(csvfile, delimiter='\t', quotechar='|') + for line in spamreader: + + if line and line[0].strip() in ["Brand", "Firmware"]: + title = line + found_header = not found_header + continue + + if not found_header: + continue + + if line and line[0]: + vendor, firmware, sha, name = line[:4] + cfg_time, vra_time, analysis_time = [float(x) for x in line[7:10]] + + if not vendor in out_data: + out_data[vendor] = {} + if not firmware in out_data[vendor]: + out_data[vendor][firmware] = {} + if not sha in out_data[vendor][firmware]: + out_data[vendor][firmware][sha] = {"name": name, "cfg_time": cfg_time, "vra_time": vra_time, "analysis_time": analysis_time, "rows": []} + else: + out_data[vendor][firmware][sha]["analysis_time"] += analysis_time + out_data[vendor][firmware][sha]["cfg_time"] = max(analysis_time, out_data[vendor][firmware][sha]["cfg_time"]) + out_data[vendor][firmware][sha]["vra_time"] = max(analysis_time, out_data[vendor][firmware][sha]["vra_time"]) + continue + + if not any(x for x in line): + continue + + out_data[vendor][firmware][sha]["rows"].append(line) + + with open("../../aggregate_results/results.csv", "w+", newline="") as csvfile: + spamwriter = csv.writer(csvfile, delimiter='\t', + quotechar='|', quoting=csv.QUOTE_MINIMAL) + spamwriter.writerow(title) + for vendor in sorted(out_data.keys()): + for firmware in sorted(out_data[vendor].keys()): + for sha in sorted(out_data[vendor][firmware].keys()): + if sha in found_bins: + continue + if not Path("aggregate_results/"+sha).exists(): + continue + found_bins.add(sha) + data = out_data[vendor][firmware][sha] + if data['name'] == "busybox": + continue + spamwriter.writerow([vendor, firmware, sha, data['name'], "", "", 
"Blank", "", data["cfg_time"], data["vra_time"], data["analysis_time"]]) + for row in sorted(data["rows"], key=lambda x: x[4]): + row[6] = "Unfilled" + spamwriter.writerow(row) + +if __name__ == '__main__': + dataset_dir = Path(sys.argv[1]) + result_dirs = [(Path(x.split("|")[0]), x.split("|")[1]) for x in sys.argv[2:]] + for result_dir, prefix in result_dirs: + prep_results(result_dir, Path("../../aggregate_results"), prefix) + combine_csv([x[0] for x in result_dirs]) \ No newline at end of file diff --git a/pipeline/mango_pipeline/scripts/data_printer.py b/pipeline/mango_pipeline/scripts/data_printer.py new file mode 100644 index 0000000..2e6ddc1 --- /dev/null +++ b/pipeline/mango_pipeline/scripts/data_printer.py @@ -0,0 +1,311 @@ +import datetime +import json + +from typing import List, Dict +from pathlib import Path + +from rich.table import Table +from ..firmware import ELFInfo + + +class ResultAccumulator: + def __init__(self, fields: List[str]): + self.field_dict = {x: {"uniq": 0, "dup": 0} for x in fields} + + def __getattr__(self, name: str): + if name == "field_dict": + return self.__getattribute__(name) + return self.field_dict[name] + + def make_table_row(self, field, modifier="", show_dup=False): + if show_dup: + row = f'{modifier}{self.field_dict[field]["dup"]}' + else: + row = f'{modifier}{self.field_dict[field]["uniq"]}' + return row + + def make_table_rows(self, fields, show_dup=False): + row = [] + for modifier, field in fields: + row.append(self.make_table_row(field, modifier, show_dup)) + return row + + # update all fields based on provided dict + def update(self, data: dict, dup=False): + for field, value in data.items(): + if value is None: + continue + if field == "mango_time" and isinstance(value, list): + value = sum(value) + if not dup: + self.field_dict[field]["uniq"] += value + self.field_dict[field]["dup"] += value + + +def parse_mango_result( + mango_results: Dict[str, ResultAccumulator], + result_file: Path, + info: ELFInfo, + dup=False, +): + if info.brand not in mango_results: + mango_results[info.brand] = ResultAccumulator( + [ + "binaries", + "binaries_alerted", + "binaries_resolved", + "total_alerts", + "trupocs", + "no_sinks", + "timeout", + "oom", + "error", + "cfg_time", + "vra_time", + "mango_time", + ] + ) + + if result_file.exists(): + try: + result = json.loads(result_file.read_text()) + hits = len(result.get("closures", result.get("results", []))) + has_hits = hits > 0 + resolved = not has_hits - (not result["has_sinks"]) + no_sinks = 0 if result["has_sinks"] else 1 + error = 1 if result["error"] else 0 + timeout = 1 if "ret_code" in result and result["ret_code"] == 124 else 0 + error -= timeout + oom = 1 if "ret_code" in result and result["ret_code"] == -9 else 0 + cfg_time = result["cfg_time"] if result["cfg_time"] else 0 + vra_time = result["vra_time"] if result["vra_time"] else 0 + mango_time = ( + result["mango_time"] if "mango_time" in result else result["analysis_time"] + ) + trupocs = len([x for x in result.get("closures", [{"rank": 0}]) if x["rank"] >= 7]) + + update_dict = { + "binaries_alerted": has_hits, + "binaries_resolved": resolved, + "total_alerts": hits, + "trupocs": trupocs, + "no_sinks": no_sinks, + "error": error, + "timeout": timeout, + "oom": oom, + "cfg_time": cfg_time, + "vra_time": vra_time, + "mango_time": mango_time, + "binaries": 1, + } + except json.decoder.JSONDecodeError: + update_dict = {"error": 1, "binaries": 1} + + else: + update_dict = {"error": 1, "binaries": 1} + + 
mango_results[info.brand].update(update_dict, dup=dup) + + +def parse_env_result( + env_results: Dict[str, ResultAccumulator], + result_file: Path, + info: ELFInfo, + dup=False, +): + if info.brand not in env_results: + env_results[info.brand] = ResultAccumulator( + [ + "binaries", + "binaries_alerted", + "binaries_resolved", + "total_alerts", + "no_sinks", + "timeout", + "oom", + "error", + "cfg_time", + "vra_time", + "analysis_time", + ] + ) + + if result_file.exists(): + result = json.loads(result_file.read_text()) + hits = len(result.get("closures", result.get("results", []))) + has_hits = hits > 0 + resolved = not has_hits - (not result["has_sinks"]) + no_sinks = 0 if result["has_sinks"] else 1 + error = 1 if result["error"] else 0 + timeout = 1 if "ret_code" in result and result["ret_code"] == 124 else 0 + error -= timeout + oom = 1 if "ret_code" in result and result["ret_code"] == -9 else 0 + cfg_time = result["cfg_time"] if result["cfg_time"] else 0 + vra_time = result["vra_time"] if result["vra_time"] else 0 + mango_time = ( + result["mango_time"] if "mango_time" in result else result["analysis_time"] + ) + if mango_time == 0 and "sink_times" in result: + mango_time = sum(result["sink_times"].values()) + + + update_dict = { + "binaries_alerted": has_hits, + "binaries_resolved": resolved, + "total_alerts": hits, + "no_sinks": no_sinks, + "error": error, + "timeout": timeout, + "oom": oom, + "cfg_time": cfg_time, + "vra_time": vra_time, + "analysis_time": mango_time, + "binaries": 1, + } + + else: + update_dict = {"error": 1, "binaries": 1} + + env_results[info.brand].update(update_dict, dup=dup) + + +def generate_mango_table( + mango_results: Dict[str, ResultAccumulator], show_dups=False +) -> Table: + """Make a new table.""" + table = Table(title="MANGO RESULTS") + table.add_column("Vendor", vertical="middle", style="bold") + table.add_column("Binaries") + table.add_column("[green]Binaries Alerted") + table.add_column("Binaries Resolved") + table.add_column("No Sinks") + table.add_column("[green]Total Alerts") + table.add_column("[bold green]TruPoCs") + table.add_column("[red]Error") + table.add_column("[blue]Timeout") + table.add_column("[yellow]OOM") + table.add_column("Analysis Time", justify="right") + + selector = "dup" if show_dups else "uniq" + for idx, vendor in enumerate(sorted(list(mango_results))): + data = mango_results[vendor] + vendor_time = datetime.timedelta( + seconds=int( + data.mango_time[selector] + + data.cfg_time[selector] + + data.vra_time[selector] + ) + ) + row = [vendor] + styled_rows = data.make_table_rows( + [ + ("", "binaries"), + ("[green]", "binaries_alerted"), + ("", "binaries_resolved"), + ("", "no_sinks"), + ("[green]", "total_alerts"), + ("[bold green]", "trupocs"), + ("[red]", "error"), + ("[blue]", "timeout"), + ("[yellow]", "oom"), + ], + show_dup=show_dups, + ) + row.extend(styled_rows) + row.append(f"{vendor_time}") + table.add_row(*row, end_section=idx == len(list(mango_results)) - 1) + table.add_row( + "Total", + str(sum(x.binaries[selector] for x in mango_results.values())), + f"[green]{sum(x.binaries_alerted[selector] for x in mango_results.values())}", + f"{sum(x.binaries_resolved[selector] for x in mango_results.values())}", + f"{sum(x.no_sinks[selector] for x in mango_results.values())}", + f"[green]{sum(x.total_alerts[selector] for x in mango_results.values())}", + f"[bold green]{sum(x.trupocs[selector] for x in mango_results.values())}", + f"[red]{sum(x.error[selector] for x in mango_results.values())}", + 
f"[blue]{sum(x.timeout[selector] for x in mango_results.values())}", + f"[yellow]{sum(x.oom[selector] for x in mango_results.values())}", + str( + datetime.timedelta( + seconds=int( + sum( + x.mango_time[selector] + + x.cfg_time[selector] + + x.vra_time[selector] + for x in mango_results.values() + ) + ) + ) + ), + ) + return table + + +def generate_env_table( + env_results: Dict[str, ResultAccumulator], show_dups=False +) -> Table: + """Make a new table.""" + table = Table(title="ENV RESULTS") + table.add_column("Vendor") + table.add_column("Binaries") + table.add_column("[green]Binaries Alerted") + table.add_column("Binaries Resolved") + table.add_column("No Sinks") + table.add_column("Total Alerts") + table.add_column("[red]Error") + table.add_column("[blue]Timeout") + table.add_column("[yellow]OOM") + table.add_column("Analysis Time", justify="right") + + selector = "dup" if show_dups else "uniq" + for idx, vendor in enumerate(sorted(list(env_results))): + data = env_results[vendor] + vendor_time = datetime.timedelta( + seconds=int( + data.analysis_time[selector] + + data.cfg_time[selector] + + data.vra_time[selector] + ) + ) + row = [vendor] + styled_rows = data.make_table_rows( + [ + ("", "binaries"), + ("[green]", "binaries_alerted"), + ("", "binaries_resolved"), + ("", "no_sinks"), + ("[green]", "total_alerts"), + ("[red]", "error"), + ("[blue]", "timeout"), + ("[yellow]", "oom"), + ], + show_dup=show_dups, + ) + row.extend(styled_rows) + row.append(f"{vendor_time}") + table.add_row(*row, end_section=idx == len(list(env_results)) - 1) + + table.add_row( + "Total", + str(sum(x.binaries[selector] for x in env_results.values())), + f"[green]{sum(x.binaries_alerted[selector] for x in env_results.values())}", + str(sum(x.binaries_resolved[selector] for x in env_results.values())), + f"{sum(x.no_sinks[selector] for x in env_results.values())}", + f"[green]{sum(x.total_alerts[selector] for x in env_results.values())}", + f"[red]{sum(x.error[selector] for x in env_results.values())}", + f"[blue]{sum(x.timeout[selector] for x in env_results.values())}", + f"[yellow]{sum(x.oom[selector] for x in env_results.values())}", + str( + datetime.timedelta( + seconds=int( + sum( + x.analysis_time[selector] + + x.cfg_time[selector] + + x.vra_time[selector] + for x in env_results.values() + ) + ) + ) + ), + ) + return table diff --git a/pipeline/mango_pipeline/scripts/de-dup.py b/pipeline/mango_pipeline/scripts/de-dup.py new file mode 100644 index 0000000..4a02967 --- /dev/null +++ b/pipeline/mango_pipeline/scripts/de-dup.py @@ -0,0 +1,163 @@ +import sys +import csv + +from pathlib import Path + +def get_csv_data(csv_path: Path, prev=False, delim="\t"): + out_data = {} + title = [] + vendor = None + firmware = None + sha = None + name = None + cfg_time = None + vra_time = None + analysis_time = None + found_bins = set() + found_header = False + + with open(csv_path, newline='') as csvfile: + spamreader = csv.reader(csvfile, delimiter=delim, quotechar='|') + for line in spamreader: + + if line and line[0].strip() in ["Brand", "Firmware"]: + if not title: + title = line + found_header = not found_header + continue + + if not found_header: + continue + + if line and line[0]: + vendor, firmware, sha, name = line[:4] + cfg_time, vra_time, analysis_time = [float(x) if x else 0 for x in line[8:11]] + + if not vendor in out_data: + out_data[vendor] = {} + if not firmware in out_data[vendor]: + out_data[vendor][firmware] = {} + if not sha in out_data[vendor][firmware]: + out_data[vendor][firmware][sha] = 
{"name": name, "cfg_time": cfg_time, "vra_time": vra_time, "analysis_time": analysis_time, "rows": {}} + else: + out_data[vendor][firmware][sha]["analysis_time"] += analysis_time + out_data[vendor][firmware][sha]["cfg_time"] = max(analysis_time, out_data[vendor][firmware][sha]["cfg_time"]) + out_data[vendor][firmware][sha]["vra_time"] = max(analysis_time, out_data[vendor][firmware][sha]["vra_time"]) + continue + + if not any(x for x in line): + continue + + if prev: + line = line[2:] + addr = line[5] + out_data[vendor][firmware][sha]["rows"][addr] = line + return out_data, title + + +def get_total_data(prev_data, new_data): + sha_set = set() + all_lens = set() + for brand, firm_dict in new_data.items(): + if brand not in prev_data: + continue + for firmware, sha_dict in firm_dict.items(): + if firmware not in prev_data[brand]: + continue + for sha, row_dict in sha_dict.items(): + if sha in sha_set: + new_data[brand][firmware].pop(sha) + continue + else: + sha_set.add(sha) + + if sha not in prev_data[brand][firmware]: + continue + + for addr, row in row_dict["rows"].items(): + if addr in prev_data[brand][firmware][sha]["rows"]: + new_data[brand][firmware][sha]["rows"][addr] = prev_data[brand][firmware][sha]["rows"][addr] + new_data[brand][firmware][sha]["rows"][addr].extend(row[-2:]) + else: + while len(new_data[brand][firmware][sha]["rows"][addr]) < 15: + new_data[brand][firmware][sha]["rows"][addr].insert(-3, "") + + return new_data + + +def gen_csv(title, csv_data): + with open("./updated.csv", "w", newline="") as csvfile: + spamwriter = csv.writer(csvfile, delimiter='\t', + quotechar='|', quoting=csv.QUOTE_MINIMAL) + + title = title[:3] + ["Name"] + title[3:] + all_rows = [] + for vendor in sorted(csv_data.keys()): + for firmware in sorted(csv_data[vendor].keys()): + for sha in sorted(csv_data[vendor][firmware].keys()): + data = csv_data[vendor][firmware][sha] + all_rows.append([vendor, firmware, sha, data['name'], "", "", "", "", data["cfg_time"], data["vra_time"], data["analysis_time"]]) + for row in sorted(data["rows"], key=lambda x: int(x, 16)): + all_rows.append([""] + data["rows"][row]) + + for idx in range(len(all_rows)): + all_rows[idx][4] = f'=IF(ISBLANK(D{idx+5}), "", D{idx+5} & " [" & COUNTIF($D$5:$D, D{idx+5}) & "]")' + all_rows = [title] + [x for x in all_rows] + all_rows.insert(0, []) + all_rows.insert(0, ["Completed", f'=CountA(H5:H{len(all_rows)-1}) & " of " & CountA(F5:F{len(all_rows)-1})']) + all_rows.insert(0, ["True Positives", f'=CountIF(H5:H{len(all_rows)-1}, "Y")/(CountIF(H5:H{len(all_rows)-1}, "Y") + CountIF(H5:H{len(all_rows)-1}, "N"))']) + spamwriter.writerows(all_rows) + +def print_stats(total_data): + valid_rows = [] + from collections import Counter + paired_d = {} + shas = set() + for vendor, firm_dict in total_data.items(): + for firmware, sha_dict in firm_dict.items(): + for sha, data_dict in sha_dict.items(): + if any(y == "TP" or y == "FP" for x in data_dict["rows"].values() for y in x): + shas.add(sha) + valid_rows.extend([(x[6], x[13], x[14], vendor) for x in data_dict["rows"].values() if len(x) > 6 and x[6]]) + d = {"TP": {}, "FP": {}} + total = 0 + file_ops = ["fgets", "read", "open", "fread"] + network_ops = ["socket", "recv"] + vendor_dict = {} + for i, tags, xrefs, vendor in valid_rows: + if i not in d: + continue + if vendor not in vendor_dict: + vendor_dict[vendor] = {} + if i not in vendor_dict[vendor]: + vendor_dict[vendor][i] = 0 + + vendor_dict[vendor][i] += 1 + other = {"network_ops" if x in network_ops else "file_ops" if x in file_ops 
else "unknown" if not x else x for x in tags.split(",")} + other = sorted(other) + tup = tuple(other + [i]) + if tup not in paired_d: + paired_d[tup] = 0 + paired_d[tup] += 1 + for t in other: + if t not in d[i]: + d[i][t] = 0 + d[i][t] += 1 + total += 1 + for t in d["TP"]: + if t not in d["FP"]: + d["FP"][t] = 0 + print(f"{t.ljust(10, ' ')}: {d['TP'][t]}/{(d['TP'][t] + d['FP'][t])} = {(d['TP'][t]/(d['TP'][t] + d['FP'][t]))*100:.2f}%") + print(f"{len(shas) = }") + print(f"{total = }") + + +if __name__ == '__main__': + assert(len(sys.argv) == 3) + prev_version = Path(sys.argv[1]) + new_version = Path(sys.argv[2]) + prev_data, title = get_csv_data(prev_version, prev=True, delim="\t") + new_data, _ = get_csv_data(new_version) + total_data = get_total_data(prev_data, new_data) + print_stats(total_data) + gen_csv(title, total_data) \ No newline at end of file diff --git a/pipeline/mango_pipeline/scripts/get_tp_from_sheet.py b/pipeline/mango_pipeline/scripts/get_tp_from_sheet.py new file mode 100644 index 0000000..d9a6b17 --- /dev/null +++ b/pipeline/mango_pipeline/scripts/get_tp_from_sheet.py @@ -0,0 +1,120 @@ +import sys +import json +import subprocess +import csv + +from pathlib import Path +from rich.table import Table +from rich.console import Console + + + +def get_csv_data(csv_path: Path, prev=False, delim="\t"): + out_data = {} + title = [] + vendor = None + firmware = None + sha = None + name = None + cfg_time = None + vra_time = None + analysis_time = None + found_bins = set() + found_header = False + + with open(csv_path, newline='') as csvfile: + spamreader = csv.reader(csvfile, delimiter=delim, quotechar='|') + for line in spamreader: + + if line and line[0].strip() in ["Brand", "Firmware"]: + if not title: + title = line + found_header = not found_header + continue + + if not found_header: + continue + + if line and line[0]: + vendor, firmware, sha, name = line[:4] + cfg_time, vra_time, analysis_time = [float(x) if x else 0 for x in line[8:11]] + + if not vendor in out_data: + out_data[vendor] = {} + if not firmware in out_data[vendor]: + out_data[vendor][firmware] = {} + if not sha in out_data[vendor][firmware]: + out_data[vendor][firmware][sha] = {"name": name, "cfg_time": cfg_time, "vra_time": vra_time, "analysis_time": analysis_time, "rows": {}} + else: + out_data[vendor][firmware][sha]["analysis_time"] += analysis_time + out_data[vendor][firmware][sha]["cfg_time"] = max(analysis_time, out_data[vendor][firmware][sha]["cfg_time"]) + out_data[vendor][firmware][sha]["vra_time"] = max(analysis_time, out_data[vendor][firmware][sha]["vra_time"]) + continue + + if not any(x for x in line): + continue + + if prev: + line = line[1:] + addr = line[5] + out_data[vendor][firmware][sha]["rows"][addr] = line + return out_data, title + + +def get_tp_count(csv_data, results_path): + tp_dict = {} + for brand, firm_dict in csv_data.items(): + for firmware, bin_dict in firm_dict.items(): + for sha, vals in bin_dict.items(): + val_dict = {} + for row in vals["rows"].values(): + key = row[6].strip() + if key not in val_dict: + val_dict[key] = 0 + val_dict[key] += 1 + + tp_dict[sha] = val_dict + + vendor_data = json.loads((results_path / "vendors.json").read_text()) + data_dict = {} + for brand, firm_dict in vendor_data.items(): + data_dict[brand] = {} + for firmware, elf_dict in firm_dict["firmware"].items(): + data_dict[brand][firmware] = {} + for elf in elf_dict["elfs"]: + if elf in tp_dict: + for key in tp_dict[elf]: + if key not in data_dict[brand][firmware]: + data_dict[brand][firmware][key] = 
0 + data_dict[brand][firmware][key] += tp_dict[elf][key] + + data_dict.pop("huawei_fastboot") + data_dict.pop("lk") + data_dict.pop("NVIDIA") + all_keys = sorted({y for x in tp_dict.values() for y in x}) + + table = Table(title="Total Data Info") + table.add_column("Brand/Firmware") + for key in all_keys: + table.add_column(key) + + brand_data = {b: {x: 0 for x in all_keys} for b in data_dict} + for brand, firmware_dict in data_dict.items(): + for firmware, values in firmware_dict.items(): + firmware_dict = {x: 0 for x in all_keys} + for key, val in values.items(): + if key in firmware_dict: + brand_data[brand][key] += val + firmware_dict[key] += val + #table.add_row(firmware, *[str(firmware_dict[x]) for x in all_keys]) + table.add_row(brand, *[str(brand_data[brand][x]) for x in all_keys]) + + table.add_row("Total", *[str(sum(brand_data[b][x] for b in brand_data)) for x in all_keys]) + Console().print(table) + + +if __name__ == '__main__': + sheet = Path(sys.argv[1]) + data, _ = get_csv_data(sheet, prev=True) + results = Path(sys.argv[2]) + get_tp_count(data, results) diff --git a/pipeline/mango_pipeline/scripts/path_context_aggregator.py b/pipeline/mango_pipeline/scripts/path_context_aggregator.py new file mode 100644 index 0000000..c7b1b3f --- /dev/null +++ b/pipeline/mango_pipeline/scripts/path_context_aggregator.py @@ -0,0 +1,94 @@ +import subprocess +import json + +import sys +import networkx as nx +import multiprocessing +import logging +import argparse +import pprint + +from pathlib import Path + +import angr +from rich.progress import track +from argument_resolver.external_function.sink import BUFFER_OVERFLOW_SINKS, COMMAND_INJECTION_SINKS + +logging.getLogger('angr').setLevel('CRITICAL') + +def get_path_contexts(context_args): + target_folder, bin_path = context_args + overflow_count = {} + cmdi_count = {} + if target_folder.exists(): + if (target_folder / "cmdi_results.json").exists(): + try: + cmdi_data = json.loads((target_folder / "cmdi_results.json").read_text()) + if "sinks" in cmdi_data: + cmdi_count = cmdi_data["sinks"] + except: + pass + if (target_folder / "overflow_results.json").exists(): + try: + overflow_data = json.loads((target_folder / "overflow_results.json").read_text()) + if "sinks" in overflow_data: + overflow_count = overflow_data["sinks"] + except: + pass + + contexts = {} + if not cmdi_count and not overflow_count: + return contexts + try: + project = angr.Project(str(bin_path), auto_load_libs=False) + cfg = project.analyses.CFGFast(normalize=True, data_references=True, show_progressbar=False) + valid_sinks = [cfg.functions[x] for x in cmdi_count | overflow_count if x in cfg.functions] + contexts = {k.name: {"paths": 0, "count": 0} for k in valid_sinks} + contexts["cmdi"] = {"paths": 0, "count": sum(cmdi_count.values() or [0])} + contexts["overflow"] = {"paths": 0, "count": sum(overflow_count.values() or [0])} + for sink in valid_sinks: + g = nx.dfs_tree(cfg.functions.callgraph.reverse(), source=sink.addr, depth_limit=7) + leaf_nodes = {x for x in g.nodes() if g.out_degree(x) == 0} + + agg = "overflow" if sink.name == "strcpy" else "cmdi" + if sink.name in cmdi_count: + contexts[sink.name]["count"] += cmdi_count[sink.name] + elif sink.name in overflow_count: + contexts[sink.name]["count"] += overflow_count[sink.name] + for node in leaf_nodes: + paths = nx.all_simple_paths(cfg.functions.callgraph, node, sink.addr) + total_paths = len(list(paths)) + contexts[sink.name]["paths"] += total_paths + contexts[agg]["paths"] += total_paths + + except Exception as e: 
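+ # CFG recovery or path enumeration can fail on unusual binaries; log the error and return whatever contexts were collected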
+ print(e) + return contexts + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Angr Based Path Context Aggregator') + parser.add_argument('result_folder', type=str, help='Result Directory of Mango Analysis') + parser.add_argument('--cores', type=int, default=int(multiprocessing.cpu_count() * 3 / 4), help="Amount of cores to dedicate to this (could take several hours)") + args = parser.parse_args() + vendor_file = json.loads(Path(args.result_folder + "/vendors.json").read_text()) + targets = set() + for brand, firmwares in vendor_file.items(): + for firmware, vals in firmwares["firmware"].items(): + for sha, elf in vals["elfs"].items(): + targets.add((Path(args.result_folder) / brand / firmware / sha, Path(elf["path"]))) + + context_counter = {} + with multiprocessing.Pool(args.cores) as p: + for res in track(p.imap_unordered(get_path_contexts, targets), total=len(targets), description="Running analysis..."): + for key, values in res.items(): + if key not in context_counter: + context_counter[key] = {"paths": 0, "count": 0} + context_counter[key]["paths"] += values["paths"] + context_counter[key]["count"] += values["count"] + pprint.pprint(context_counter, indent=4) + + with open("counter.agg", "w+") as f: + json.dump(context_counter, f, indent=4) + print("FINAL COUNTER:", context_counter) + diff --git a/pipeline/mango_pipeline/scripts/show_data.py b/pipeline/mango_pipeline/scripts/show_data.py new file mode 100644 index 0000000..26e8708 --- /dev/null +++ b/pipeline/mango_pipeline/scripts/show_data.py @@ -0,0 +1,331 @@ +import rich +import subprocess +import sys +import json +import datetime + +from collections import Counter +from functools import reduce + +from pathlib import Path +from rich.progress import track +from rich.table import Table +from rich.console import Console + + +def get_all_data(result_dir: Path, de_dup=True): + possible_starts = { + "/home/clasm/projects/angr-squad/", + "/shared/clasm/", + "/tmp/firmwar", + } + all_jsons = ( + subprocess.check_output( + ["find", result_dir, "-type", "f", "-iname", "cmdi_results.json"] + ) + .decode() + .strip() + .split("\n") + ) + all_jsons += ( + subprocess.check_output( + ["find", result_dir, "-type", "f", "-iname", "overflow_results.json"] + ) + .decode() + .strip() + .split("\n") + ) + result_data = [] + vendor_data = json.loads((result_dir / "vendors.json").read_text()) + known_cmdi_shas = set() + known_overflow_shas = set() + for x in track(all_jsons, description="Extracting JSON Data", total=len(all_jsons)): + if not x: + continue + + p = Path(x) + try: + d = json.loads(p.read_text()) + except Exception: + continue + + symbols = json.loads((p.parent.parent / "symbols.json").read_text()) + d["vendor"] = symbols["brand"] + d["firmware"] = symbols["firmware"] + d["firm_path"] = "/".join(Path(x).parts[-3:]) + + if "vendor" not in d: + breakpoint() + + if de_dup and ( + d["path"].replace("shared/clasm", "home/clasm/projects/angr-squad") + != vendor_data[d["vendor"]]["firmware"][d["firmware"]]["elfs"][d["sha256"]][ "path" ] + ): + continue + + if de_dup: + if "cmdi" in d["firm_path"]: + if d["sha256"] in known_cmdi_shas: + continue + known_cmdi_shas.add(d["sha256"]) + elif "overflow" in d["firm_path"]: + if d["sha256"] in known_overflow_shas: + continue + known_overflow_shas.add(d["sha256"]) + result_data.append(d) + return result_data + + +def generate_mango_table(raw_data, show_firm, title) -> Table: + """Make a new table.""" + table = Table(title=title) + table.add_column("Vendor") + if show_firm: + 
table.add_column("Device") + table.add_column("Binaries") + table.add_column("[green]Binaries Hit") + table.add_column("Binaries Resolved") + table.add_column("No Sinks") + table.add_column("[green]Total Hits") + table.add_column("[red]Error") + table.add_column("[blue]Timeout") + table.add_column("[yellow]OOM") + table.add_column("Analysis Time", justify="right") + + seen_bins = set() + table_data = {} + for bin_data in raw_data: + vendor = bin_data["vendor"] + if show_firm: + vendor = bin_data["firmware"] + seen = bin_data["firm_path"] in seen_bins + if not seen: + seen_bins.add(bin_data["firm_path"]) + if vendor not in table_data: + table_data[vendor] = { + "binaries": 0, + "binaries_with_alerts": 0, + "binaries_resolved": 0, + "no_sinks": 0, + "total_alerts": 0, + "errors": 0, + "timeouts": 0, + "oom": 0, + "time_taken": 0, + "unique_sinks": Counter(), + "alerted_sinks": Counter(), + "sinks": Counter(), + "vendor": bin_data["vendor"], + } + + alerts = len(bin_data["closures"]) + has_sinks = bin_data["has_sinks"] + resolved = has_sinks and not alerts > 0 + error = bin_data["error"] is not None + timeout = error and bin_data["ret_code"] == 124 + oom = error and bin_data["ret_code"] == -9 + + mango_time = bin_data["mango_time"] + if isinstance(mango_time, list): + mango_time = sum(mango_time) + + table_data[vendor]["binaries"] += 1 if not seen else 0 + table_data[vendor]["binaries_with_alerts"] += alerts > 0 and not seen + table_data[vendor]["binaries_resolved"] += resolved and not seen + table_data[vendor]["no_sinks"] += not has_sinks and not seen + table_data[vendor]["total_alerts"] += alerts + table_data[vendor]["errors"] += error and not timeout and not oom + table_data[vendor]["timeouts"] += timeout + table_data[vendor]["oom"] += oom + table_data[vendor]["time_taken"] += ( + (mango_time or 0) + + (bin_data["cfg_time"] or 0) + + (bin_data["vra_time"] or 0) + ) + table_data[vendor]["alerted_sinks"].update( + Counter([c["sink"]["function"] for c in bin_data["closures"]]) + ) + unique_sinks = { + c["sink"]["function"] + "-" + c["sink"]["ins_addr"] + for c in bin_data["closures"] + } + for c in unique_sinks: + name = c.split("-")[0] + table_data[vendor]["unique_sinks"][name] += 1 + if "sinks" in bin_data: + table_data[vendor]["sinks"].update(bin_data["sinks"]) + + for idx, data in enumerate( + sorted(table_data.items(), key=lambda r: r[1]["vendor"] + r[0]) + ): + vendor, row = data + row_data = [ + f"{row['vendor']}", + f"{row['binaries']}", + f"[green]{row['binaries_with_alerts']}", + f"{row['binaries_resolved']}", + f"{row['no_sinks']}", + f"[green]{row['total_alerts']}", + f"[red]{row['errors']}", + f"[blue]{row['timeouts']}", + f"[yellow]{row['oom']}", + str(datetime.timedelta(seconds=int(row["time_taken"]))), + ] + if show_firm: + row_data.insert(1, vendor) + table.add_row(*row_data, end_section=idx == len(list(table_data)) - 1) + final_sinks = Counter() + final_unique_sinks = Counter() + final_alerted_sinks = Counter() + for x in table_data.values(): + final_sinks.update(x["sinks"]) + final_unique_sinks.update(x["unique_sinks"]) + final_alerted_sinks.update(x["alerted_sinks"]) + table_data["total"] = { + "binaries": sum(x["binaries"] for x in table_data.values()), + "binaries_with_alerts": sum( + x["binaries_with_alerts"] for x in table_data.values() + ), + "binaries_resolved": sum(x["binaries_resolved"] for x in table_data.values()), + "no_sinks": sum(x["no_sinks"] for x in table_data.values()), + "total_alerts": sum(x["total_alerts"] for x in table_data.values()), + "errors": 
sum(x["errors"] for x in table_data.values()), + "timeouts": sum(x["timeouts"] for x in table_data.values()), + "oom": sum(x["oom"] for x in table_data.values()), + "time_taken": sum(x["time_taken"] for x in table_data.values()), + "sinks": final_sinks, + "unique_sinks": final_unique_sinks, + "alerted_sinks": final_alerted_sinks, + } + row_data = [ + "Total", + str(table_data["total"]["binaries"]), + f"[green]{table_data['total']['binaries_with_alerts']}", + f"{table_data['total']['binaries_resolved']}", + f"{table_data['total']['no_sinks']}", + f"[green]{table_data['total']['total_alerts']}", + f"[red]{table_data['total']['errors']}", + f"[blue]{table_data['total']['timeouts']}", + f"[yellow]{table_data['total']['oom']}", + str(datetime.timedelta(seconds=int(table_data["total"]["time_taken"]))), + ] + if show_firm: + row_data.insert(1, "-") + table.add_row(*row_data) + return table, table_data + + +def print_table( + result_data: list, + orig_dir: Path, + show_firm=False, + unique=False, + show_sinks=False, + print_latex=False, + title="Mango Results", +): + if unique: + unique_data = {} + for d in result_data: + if "sha" in d: + sha = d["sha"] + del d["sha"] + d["sha256"] = sha + unique_data[d["sha256"]] = d + + result_data = list(unique_data.values()) + + with open(orig_dir.name + ".list", "w+") as f: + shas = [ + x["firm_path"] + for x in result_data + if x["has_sinks"] and x["error"] is None and len(x["closures"]) > 0 + ] + f.write("\n".join(sorted(shas))) + table, table_data = generate_mango_table( + result_data, show_firm=show_firm, title=title + ) + with Console() as console: + console.print(table) + if show_sinks: + console.print("Total Sinks", table_data["total"]["sinks"]) + console.print("Alerted Sinks", table_data["total"]["alerted_sinks"]) + console.print("Unique Sinks", table_data["total"]["unique_sinks"]) + if print_latex: + name = orig_dir.name.replace("-", "").replace("_", "") + for vendor, data in table_data.items(): + console.print(f"%{'-'*50}") + console.print(f"%{name.upper()} {vendor.upper()} DATA") + console.print(f"%{'-'*50}") + for k, v in data.items(): + if k == "time_taken": + analyzed_bins = ( + data["binaries_with_alerts"] + data["binaries_resolved"] + ) + out = f"\\newcommand{{\\{name}{vendor}AVGTimePerBin}}{{{v/analyzed_bins:.2f}\\xspace}}" + console.print(out) + v = str(datetime.timedelta(seconds=int(v))) + elif k in {"unique_sinks", "sinks", "alerted_sinks"}: + continue + out = f"\\newcommand{{\\{name}{vendor}{k.replace('_','')}}}{{{v}\\xspace}}" + console.print(out) + + +if __name__ == "__main__": + only_unique = False + print_latex = False + combine = False + firmware = False + sinks = False + if any(x == "-u" or x == "--unique" for x in sys.argv[1:]): + only_unique = True + if any(x == "-l" or x == "--latex" for x in sys.argv[1:]): + print_latex = True + if any(x == "-c" or x == "--combine" for x in sys.argv[1:]): + combine = True + if any(x == "-f" or x == "--firmware" for x in sys.argv[1:]): + firmware = True + if any(x == "-s" or x == "--sinks" for x in sys.argv[1:]): + sinks = True + all_data = [] + for res_dir in sys.argv[1:]: + if res_dir in { + "-u", + "--unique", + "-l", + "--latex", + "-c", + "--combine", + "-f", + "--firmware", + "-s", + "--sinks", + }: + continue + res_d = Path(res_dir).absolute() + data = get_all_data(res_d, de_dup=False) + + if not combine: + print_table( + data, + res_d, + unique=only_unique, + show_firm=firmware, + show_sinks=sinks, + print_latex=print_latex, + title=res_d.name, + ) + else: + all_data.extend(data) + if 
combine: + print_table( + all_data, + Path("combined"), + unique=only_unique, + show_firm=firmware, + show_sinks=sinks, + print_latex=print_latex, + title="combined", + ) diff --git a/pipeline/mango_pipeline/scripts/show_table.py b/pipeline/mango_pipeline/scripts/show_table.py new file mode 100644 index 0000000..883265c --- /dev/null +++ b/pipeline/mango_pipeline/scripts/show_table.py @@ -0,0 +1,178 @@ +import sys +import subprocess +import json +import argparse + +from rich.table import Table +from rich.console import Console +from rich.progress import track +from pathlib import Path + + +def parse_mango_result(data_path: Path, dataset: Path): + output = subprocess.check_output(["find", data_path, "-type", "f", "-name", "*_results.json"]).decode().strip().split("\n") + + total_dict = {"Vendors": {}, "Firmware": {}} + for file in track(output, description="Loading files", total=len(output)): + fp = Path(file) + if not fp.is_file(): + continue + try: + data = json.loads(fp.read_text()) + path = Path(data["path"]) + if dataset is not None: + new_path = dataset / '/'.join(path.parts[path.parts.index(dataset.name)+1:]) + mango_time = data["mango_time"] + if isinstance(mango_time, list): + mango_time = sum(mango_time) + file_res = { + "brand": fp.parent.parent.parent.name, + "firmware": fp.parent.parent.name, + "has_sinks": data["has_sinks"], + "time": (mango_time + (data["vra_time"] or 0) + (data["cfg_time"] or 0)) if data["error"] is None else 0, + "hits": len(data["closures"]), + "trupocs": len([x for x in data["closures"] if x["rank"] >= 7]), + "error": data["ret_code"] if data["error"] else 0, + "size": new_path.stat().st_size if dataset is not None else 0, + } + except json.decoder.JSONDecodeError: + file_res = { + "brand": fp.parent.parent.parent.name, + "firmware": fp.parent.parent.name, + "has_sinks": False, + "time": 0, + "hits": 0, + "trupocs": 0, + "error": "early_termination", + "size": 0, + } + firm = file_res["firmware"] + brand = file_res["brand"] + + if firm in total_dict["Firmware"]: + total_dict["Firmware"][firm].append(file_res) + else: + total_dict["Firmware"][firm] = [file_res] + + if brand in total_dict["Vendors"]: + total_dict["Vendors"][brand].append(file_res) + else: + total_dict["Vendors"][brand] = [file_res] + + return total_dict + + +def show_table(results: dict, show_firm): + """Make a new table.""" + rows = [] + + brand_data = {} + for firm, res in results["Firmware"].items(): + # binaries starts at 0; the loop below counts each result exactly once + firm_data = {"binaries": 0, "bin_hits": 0, "resolved": 0, "no_sinks": 0, "hits": 0, "errors": 0, "timeouts": 0, "time": 0, "OOMKILLED": 0, "size": 0, "trupocs": 0, "trupoc_bins": 0} + for r in res: + firm_data["binaries"] += 1 + firm_data["size"] += r["size"] + if r["hits"] > 0: + firm_data["bin_hits"] += 1 + firm_data["hits"] += r["hits"] + firm_data["trupocs"] += r["trupocs"] + if r["trupocs"] > 0: + firm_data["trupoc_bins"] += 1 + elif r["error"] == 0: + firm_data["resolved"] += 1 + elif r["error"] != 0: + if r["error"] == 124: + firm_data["timeouts"] += 1 + firm_data["time"] += 3 * 60 * 60 + elif r["error"] == -9: + firm_data["OOMKILLED"] += 1 + else: + firm_data["errors"] += 1 + + if r["time"] < 0: + r["time"] = 3 * 60 * 60 + firm_data["time"] += r["time"] + if r["brand"] not in brand_data: + brand_data[r["brand"]] = [firm_data] + else: + brand_data[r["brand"]].append(firm_data) + + firm_data["no_sinks"] = firm_data["binaries"] - (firm_data["bin_hits"] + firm_data["resolved"]) + if show_firm: + rows.append([firm, "", f'{firm_data["binaries"]:,}', f'{firm_data["bin_hits"]:,}', 
f'{firm_data["trupoc_bins"]:,}', f'{firm_data["resolved"]:,}', f'{firm_data["no_sinks"]:,}', f'{firm_data["hits"]:,}', f'{firm_data["trupocs"]}', f'{firm_data["errors"]:,}', f'{firm_data["OOMKILLED"]:,}', f'{firm_data["timeouts"]:,}', f"{firm_data['time']/60:,.2f} Min", "", "", f"{firm_data['time']/(firm_data['binaries'] + firm_data['timeouts']):.2f} Sec"]) + + total_data = {"binaries": 0, "bin_hits": 0, "resolved": 0, "no_sinks": 0, "hits": 0, "errors": 0, "timeouts": 0, "time": 0, "OOMKILLED": 0, "firm": 0, "size": 0, "trupocs": 0, "trupoc_bins": 0} + for brand, data in brand_data.items(): + bins = sum(x["binaries"] for x in data) + bin_hits = sum(x["bin_hits"] for x in data) + resolved = sum(x["resolved"] for x in data) + no_sinks = sum(x["no_sinks"] for x in data) + hits = sum(x["hits"] for x in data) + trupocs = sum(x["trupocs"] for x in data) + trupoc_bins = sum(x["trupoc_bins"] for x in data) + errors = sum(x["errors"] for x in data) + oomkilled = sum(x["OOMKILLED"] for x in data) + timeouts = sum(x["timeouts"] for x in data) + size = f"{(sum(x['size'] for x in data)/int(bins)):,.2f} B" + time = f"{(sum(x['time'] for x in data)/60):,.2f} Min" + avg = f"{(sum(x['time'] for x in data)/60)/len(data):,.2f} Min" + avg_bin = f"{(sum(x['time'] for x in data))/(int(bins) + int(timeouts)):.2f} Sec" + + total_data["binaries"] += int(bins) + total_data["bin_hits"] += int(bin_hits) + total_data["resolved"] += int(resolved) + total_data["no_sinks"] += int(no_sinks) + total_data["hits"] += int(hits) + total_data["errors"] += int(errors) + total_data["OOMKILLED"] += int(oomkilled) + total_data["timeouts"] += int(timeouts) + total_data["time"] += sum(x['time'] for x in data) + total_data["firm"] += len(data) + total_data["size"] += sum(x["size"] for x in data) + total_data["trupocs"] += int(trupocs) + total_data["trupoc_bins"] += int(trupoc_bins) + + + rows.append([brand, f"{len(data):,}", f"{bins:,}", f"{bin_hits:,}", f"{trupoc_bins:,}", f"{resolved:,}", f"{no_sinks:,}", f"{hits:,}", f"{trupocs:,}", f"{errors:,}", f"{oomkilled:,}", f"{timeouts:,}", time, avg, size, avg_bin]) + rows.append(["Total", f'{total_data["firm"]:,}', f'{total_data["binaries"]:,}', f'{total_data["bin_hits"]:,}', f'{total_data["trupoc_bins"]:,}', f'{total_data["resolved"]:,}', f'{total_data["no_sinks"]:,}', f'{total_data["hits"]:,}', f'{total_data["trupocs"]:,}', f'{total_data["errors"]:,}', str(total_data["OOMKILLED"]), str(total_data["timeouts"]), f"{total_data['time']/60:,.2f} Min", f"{(total_data['time']/(sum(len(x) for x in brand_data.values()) or 1))/60:.2f} Min", f"{total_data['size']/(total_data['binaries'] or 1):.2f} B", f"{(total_data['time']/(total_data['binaries'] or 1)):.2f} Sec"]) + + table = Table(title="MANGO RESULTS", show_footer=True) + table.add_column("Name", rows[-1][0]) + offset = 0 + table.add_column("# Firm", rows[-1][1], justify="right") + offset = 1 + table.add_column("Binaries", rows[-1][1+offset], justify="right") + table.add_column("[green]Alerted Bins", rows[-1][2+offset], justify="right") + table.add_column("[bold green]TruPoC Bins", rows[-1][3+offset], justify="right") + table.add_column("Binaries Resolved", rows[-1][4+offset], justify="right") + table.add_column("No Sinks", rows[-1][5+offset], justify="right") + table.add_column("[green]Alerts", rows[-1][6+offset], justify="right") + table.add_column("[bold green]TruPoCs", rows[-1][7+offset], justify="right") + table.add_column("[red]Error", rows[-1][8+offset], justify="right") + table.add_column("[yellow]OOMKilled", rows[-1][9+offset], 
justify="right") + table.add_column("[blue]Timeout", rows[-1][10+offset], justify="right") + table.add_column("Analysis Time", rows[-1][11+offset], justify="right") + table.add_column("AVG Time", rows[-1][12+offset], justify="right") + #table.add_column("Avg Size", rows[-1][13+offset], justify="right") + table.add_column("AVG Bin Time", rows[-1][14+offset], justify="right") + found_start = False + for idx, x in enumerate(rows[:-1]): + if show_firm and not found_start and idx+1 < len(rows) and rows[idx+1][0] in brand_data: + table.add_row(*x[:-2], x[-1], end_section=True) + found_start = True + else: + table.add_row(*x[:-2], x[-1]) + + Console().print(table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("result_dir", type=str, help="Path to result directory") + parser.add_argument("--dataset", type=str, default=None, help="Path to dataset used in experiment (only useful for binary size data)") + parser.add_argument("--show-firmware", action="store_true", default=False, help="Show data by firmware instead of by vendor") + args = parser.parse_args() + + results = parse_mango_result(args.result_dir, args.dataset) + show_table(results, show_firm=args.show_firmware) + diff --git a/pipeline/mango_pipeline/scripts/symbols_and_vendors.py b/pipeline/mango_pipeline/scripts/symbols_and_vendors.py new file mode 100644 index 0000000..005a19e --- /dev/null +++ b/pipeline/mango_pipeline/scripts/symbols_and_vendors.py @@ -0,0 +1,46 @@ +import sys +import subprocess +import json + +import collections.abc +from pathlib import Path + +def update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = update(d.get(k, {}), v) + else: + d[k] = v + return d + + + +if __name__ == '__main__': + folder = sys.argv[1] + vendors = subprocess.check_output(["find", folder, "-type", "f", "-name", "vendor.json"]).decode().strip().split("\n") + symbols = subprocess.check_output(["find", folder, "-type", "f", "-name", "symbols.json"]).decode().strip().split("\n") + vendor_out = {} + for vendor_file in vendors: + if vendor_file == Path(folder) / "vendor.json": + continue + with open(vendor_file, "r") as f: + vendor_data = json.load(f) + p = Path(vendor_file) + vendor_data = {p.parent.parent.name: {"firmware": {p.parent.name: vendor_data}}} + vendor_out = update(vendor_out, vendor_data) + + with open(Path(folder) / "vendor.json", "w+") as f: + json.dump(vendor_out, f) + + symbol_out = {} + for symbol_file in symbols: + if symbol_file == Path(folder) / "symbols.json": + continue + with open(symbol_file, "r") as f: + symbol_data = json.load(f) + if "symbols" not in symbol_data: + continue + symbol_out.update(symbol_data["symbols"]) + + with open(Path(folder) / "symbols.json", "w+") as f: + json.dump(symbol_out, f) diff --git a/pipeline/pyproject.toml b/pipeline/pyproject.toml new file mode 100644 index 0000000..7725fa3 --- /dev/null +++ b/pipeline/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "setuptools-scm", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mango_pipeline" +version = "0.0.1" +authors = [ + {name = "Wil Gibbs", email = "wfgibbs@asu.edu"}, + ] +description = "A utility to facilitate parallelization across multiple target files for argument_resolver" +requires-python=">=3.10" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11" +] + +dependencies = [ 
+ "argument_resolver", + "rich==13.7.1", + "docker==7.0.0", + "toml==0.10.2", + "kubernetes==29.0.0", + "esprima==4.0.1", + "phply==1.2.6", + "bs4==0.0.2", + "lxml==5.1.0", + "binwalk @ git+https://github.com/ReFirmLabs/binwalk@cddfede795971045d99422bd7a9676c8803ec5ee", + "pyyaml==6.0.1" +] + +[project.scripts] +mango-pipeline = "mango_pipeline.run:main" + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3c991a4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,78 @@ +[build-system] +requires = ["setuptools>=42.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "argument_resolver" +version = "0.0.1" +authors = [ + {name = "Wil Gibbs", email = "wfgibbs@asu.edu"}, + {name = "Pamplemousse", email = "private@example.com"}, + {name = "Fish", email = "fishw@asu.edu"} + ] +description = "An RDA based static-analysis library for resolving function arguments" +readme="README.md" +requires-python=">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11" +] + +dependencies = [ + "angr==9.2.94", + "pydot==2.0.0", + "networkx==3.2.1", + "psutil==5.9.8", + "ipdb==0.13.13", + "rich==13.7.1" +] + +[project.optional-dependencies] +dev = [ + "ipdb", + "pytest", + "pytest-cov", + "mypy", + "flake8" +] +[project.scripts] +mango = "argument_resolver.analysis.mango:main" +env_resolve = "argument_resolver.analysis.env_resolve:main" + +[projects.package_data] +argument_resolver = "py.typed" + +[flake8] +max-line-length = 160 + +[tool.setuptools.packages.find] +where = ["package"] + +[tool.pytest.ini_options] +addopts = "--cov=argument_resolver --cov-fail-under 70" +testpaths = [ + "package/tests", +] + +[tool.coverage.run] +omit = [ + "package/argument_resolver/__main__.py", + "package/argument_resolver/mango.py", +] + +[tool.mypy] +mypy_path = "package" +check_untyped_defs = true +disallow_any_generies = true +ignore_missing_imports = true +no_implicit_optional = true +show_error_codes = true +strict_equality = true +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true +no_implicit_reexport = true + + diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..9efd9d0 --- /dev/null +++ b/shell.nix @@ -0,0 +1,56 @@ +with import { }; + +let python38WithCoolPackages = + # Use CPython, as: + # * PyPy does not seem to support >3.6; + # * `pythonPackages` in nixpkgs might not all target PyPy. + python38.withPackages(ps: with ps; [ + pygraphviz + z3 + ]); +in +stdenv.mkDerivation rec { + name = "operation-mango"; + + buildInputs = [ + python38Packages.virtualenvwrapper + python38WithCoolPackages + + nasm + nmap + libxml2 + libxslt + libffi + readline + libtool + glib + gcc + graphviz + debootstrap + pixman + openssl + jdk8 + ]; + + shellHook = '' + source $(command -v virtualenvwrapper.sh) + if [ -d "$HOME/.virtualenvs/venv3.8" ]; then + workon venv3.8 + else + mkvirtualenv venv3.8 -p $(which python3.8) + fi + + SOURCE_DATE_EPOCH=$(date +%s) + + # + # Insure that some dependencies are installed + # + pip list > pip_list.out + + grep unicorn pip_list.out 2>&1 >/dev/null || UNICORN_QEMU_FLAGS="--python=$(which python2)" pip install unicorn + + for local_dependency in "ailment" "archinfo" "claripy" "cle" "pyelftools" "pyvex" "angr"; do + grep $local_dependency pip_list.out || pip install -e "../$local_dependency" + done + ''; +}