From 240dcb06a4d522e728bb05d49896781d41387b9b Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Mon, 20 May 2024 09:55:26 -0500 Subject: [PATCH] feat: 08 --- .github/workflows/test.yaml | 2 +- README.md | 2 +- ape_solidity/_utils.py | 19 +++---- ape_solidity/compiler.py | 101 +++++++++++++++++++----------------- ape_solidity/exceptions.py | 4 +- pyproject.toml | 2 +- setup.py | 3 +- tests/test_compiler.py | 30 ++++++----- 8 files changed, 85 insertions(+), 78 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 73acc0c..3fa88b7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -65,7 +65,7 @@ jobs: # TODO: Replace with macos-latest when works again. # https://github.com/actions/setup-python/issues/808 os: [ubuntu-latest, macos-12] # eventually add `windows-latest` - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: [3.9, '3.10', '3.11', '3.12'] env: GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index 0f86cdf..6dcd2de 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Compile Solidity contracts. ## Dependencies -- [python3](https://www.python.org/downloads) version 3.8 up to 3.12. +- [python3](https://www.python.org/downloads) version 3.9 up to 3.12. ## Installation diff --git a/ape_solidity/_utils.py b/ape_solidity/_utils.py index e34739a..60b0a66 100644 --- a/ape_solidity/_utils.py +++ b/ape_solidity/_utils.py @@ -1,9 +1,10 @@ import json import os import re +from collections.abc import Iterable from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Sequence, Set, Union +from typing import Optional, Union from ape.exceptions import CompilerError from ape.utils import pragma_str_to_specifier_set @@ -45,7 +46,7 @@ def validate_entry(cls, value): return value @property - def _parts(self) -> List[str]: + def _parts(self) -> list[str]: return self.entry.split("=") # path normalization needed in case delimiter in remapping key/value @@ -96,8 +97,8 @@ class ImportRemappingBuilder: def __init__(self, contracts_cache: Path): # import_map maps import keys like `@openzeppelin/contracts` # to str paths in the compiler cache folder. - self.import_map: Dict[str, str] = {} - self.dependencies_added: Set[Path] = set() + self.import_map: dict[str, str] = {} + self.dependencies_added: set[Path] = set() self.contracts_cache = contracts_cache def add_entry(self, remapping: ImportRemapping): @@ -108,8 +109,8 @@ def add_entry(self, remapping: ImportRemapping): self.import_map[remapping.key] = str(path) -def get_import_lines(source_paths: Set[Path]) -> Dict[Path, List[str]]: - imports_dict: Dict[Path, List[str]] = {} +def get_import_lines(source_paths: set[Path]) -> dict[Path, list[str]]: + imports_dict: dict[Path, list[str]] = {} for filepath in source_paths: import_set = set() if not filepath.is_file(): @@ -168,7 +169,7 @@ def get_pragma_spec_from_str(source_str: str) -> Optional[SpecifierSet]: return pragma_str_to_specifier_set(pragma_match.groups()[0]) -def load_dict(data: Union[str, dict]) -> Dict: +def load_dict(data: Union[str, dict]) -> dict: return data if isinstance(data, dict) else json.loads(data) @@ -183,7 +184,7 @@ def add_commit_hash(version: Union[str, Version]) -> Version: return get_solc_version_from_binary(solc, with_commit_hash=True) -def verify_contract_filepaths(contract_filepaths: Sequence[Path]) -> Set[Path]: +def verify_contract_filepaths(contract_filepaths: Iterable[Path]) -> set[Path]: invalid_files = [p.name for p in contract_filepaths if p.suffix != Extension.SOL.value] if not invalid_files: return set(contract_filepaths) @@ -192,7 +193,7 @@ def verify_contract_filepaths(contract_filepaths: Sequence[Path]) -> Set[Path]: raise CompilerError(f"Unable to compile '{sources_str}' using Solidity compiler.") -def select_version(pragma_spec: SpecifierSet, options: Sequence[Version]) -> Optional[Version]: +def select_version(pragma_spec: SpecifierSet, options: Iterable[Version]) -> Optional[Version]: choices = sorted(list(pragma_spec.filter(options)), reverse=True) return choices[0] if choices else None diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index 6775c49..3e27948 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -1,7 +1,8 @@ import os import re +from collections.abc import Iterable, Iterator from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from typing import Any, Optional, Union, cast from ape.api import CompilerAPI, PluginConfig from ape.contracts import ContractInstance @@ -65,7 +66,7 @@ class SolidityConfig(PluginConfig): Configure the Solidity plugin. """ - import_remapping: List[str] = [] + import_remapping: list[str] = [] """ Configure re-mappings using a ``=`` separated-str, e.g. ``"@import_name=path/to/dependency"``. @@ -99,9 +100,9 @@ class SolidityConfig(PluginConfig): class SolidityCompiler(CompilerAPI): _import_remapping_hash: Optional[int] = None _cached_project_path: Optional[Path] = None - _cached_import_map: Dict[str, str] = {} - _libraries: Dict[str, Dict[str, AddressType]] = {} - _contracts_needing_libraries: Set[Path] = set() + _cached_import_map: dict[str, str] = {} + _libraries: dict[str, dict[str, AddressType]] = {} + _contracts_needing_libraries: set[Path] = set() @property def name(self) -> str: @@ -112,11 +113,11 @@ def config(self) -> SolidityConfig: return cast(SolidityConfig, self.config_manager.get_config(self.name)) @property - def libraries(self) -> Dict[str, Dict[str, AddressType]]: + def libraries(self) -> dict[str, dict[str, AddressType]]: return self._libraries @cached_property - def available_versions(self) -> List[Version]: + def available_versions(self) -> list[Version]: # NOTE: Package version should already be included in available versions try: return get_installable_solc_versions() @@ -126,7 +127,7 @@ def available_versions(self) -> List[Version]: return [] @property - def installed_versions(self) -> List[Version]: + def installed_versions(self) -> list[Version]: """ Returns a lis of installed version WITHOUT their commit hashes. @@ -219,7 +220,7 @@ def add_library(self, *contracts: ContractInstance): self._contracts_needing_libraries = set() - def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: + def get_versions(self, all_paths: Iterable[Path]) -> set[str]: versions = set() for path in all_paths: # Make sure we have the compiler available to compile this @@ -229,7 +230,7 @@ def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: return versions - def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, str]: + def get_import_remapping(self, base_path: Optional[Path] = None) -> dict[str, str]: """ Config remappings like ``'@import_name=path/to/dependency'`` parsed here as ``{'@import_name': 'path/to/dependency'}``. @@ -407,17 +408,17 @@ def _add_dependencies( ) def get_compiler_settings( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Dict]: + self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + ) -> dict[Version, dict]: base_path = base_path or self.config_manager.contracts_folder files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path) if not files_by_solc_version: return {} import_remappings = self.get_import_remapping(base_path=base_path) - settings: Dict = {} + settings: dict = {} for solc_version, sources in files_by_solc_version.items(): - version_settings: Dict[str, Union[Any, List[Any]]] = { + version_settings: dict[str, Union[Any, list[Any]]] = { "optimizer": {"enabled": self.settings.optimize, "runs": DEFAULT_OPTIMIZATION_RUNS}, "outputSelection": { str(get_relative_path(p, base_path)): {"*": OUTPUT_SELECTION, "": ["ast"]} @@ -447,8 +448,8 @@ def get_compiler_settings( return settings def _get_used_remappings( - self, sources, remappings: Dict[str, str], base_path: Optional[Path] = None - ) -> Dict[str, str]: + self, sources, remappings: dict[str, str], base_path: Optional[Path] = None + ) -> dict[str, str]: base_path = base_path or self.project_manager.contracts_folder remappings = remappings or self.get_import_remapping(base_path=base_path) if not remappings: @@ -473,8 +474,8 @@ def _get_used_remappings( } def get_standard_input_json( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Dict]: + self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + ) -> dict[Version, dict]: base_path = base_path or self.config_manager.contracts_folder files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path) settings = self.get_compiler_settings(contract_filepaths, base_path) @@ -510,18 +511,18 @@ def get_standard_input_json( return input_jsons def compile( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> List[ContractType]: + self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + ) -> Iterator[ContractType]: base_path = base_path or self.config_manager.contracts_folder - contract_versions: Dict[str, Version] = {} - contract_types: List[ContractType] = [] + contract_versions: dict[str, Version] = {} + contract_types: list[ContractType] = [] input_jsons = self.get_standard_input_json(contract_filepaths, base_path=base_path) for solc_version, input_json in input_jsons.items(): logger.info(f"Compiling using Solidity compiler '{solc_version}'.") cleaned_version = Version(solc_version.base_version) solc_binary = get_executable(version=cleaned_version) - arguments: Dict = {"solc_binary": solc_binary, "solc_version": cleaned_version} + arguments: dict = {"solc_binary": solc_binary, "solc_version": cleaned_version} if solc_version >= Version("0.6.9"): arguments["base_path"] = base_path @@ -543,7 +544,7 @@ def compile( if not contracts: continue - input_contract_names: List[str] = [] + input_contract_names: list[str] = [] for source_id, contracts_out in contracts.items(): for name, _ in contracts_out.items(): # Filter source files that the user did not ask for, such as @@ -606,7 +607,7 @@ def compile( contract_versions[contract_name] = solc_version # Output compiler data used. - compilers_used: Dict[Version, Compiler] = {} + compilers_used: dict[Version, Compiler] = {} for ct in contract_types: if not ct.name: # Won't happen, but just for mypy. @@ -632,7 +633,9 @@ def compile( compilers_ls = list(compilers_used.values()) self.project_manager.local_project.add_compiler_data(compilers_ls) - return contract_types + # Yield contract-types afterward to ensure we yield only the latest types. + # This avoids collision errors for shared imported contracts across versions. + yield from contract_types def compile_code( self, @@ -692,14 +695,14 @@ def compile_code( def _get_unmapped_imports( self, - contract_filepaths: Sequence[Path], + contract_filepaths: Iterable[Path], base_path: Optional[Path] = None, - ) -> Dict[str, List[Tuple[str, str]]]: + ) -> dict[str, list[tuple[str, str]]]: contracts_path = base_path or self.config_manager.contracts_folder import_remapping = self.get_import_remapping(base_path=contracts_path) contract_filepaths_set = verify_contract_filepaths(contract_filepaths) - imports_dict: Dict[str, List[Tuple[str, str]]] = {} + imports_dict: dict[str, list[tuple[str, str]]] = {} for src_path, import_strs in get_import_lines(contract_filepaths_set).items(): import_list = [] for import_str in import_strs: @@ -722,13 +725,13 @@ def _get_unmapped_imports( def get_imports( self, - contract_filepaths: Sequence[Path], + contract_filepaths: Iterable[Path], base_path: Optional[Path] = None, - ) -> Dict[str, List[str]]: + ) -> dict[str, list[str]]: contracts_path = base_path or self.config_manager.contracts_folder - def build_map(paths: Set[Path], prev: Optional[Dict] = None) -> Dict[str, List[str]]: - result: Dict[str, List[str]] = prev or {} + def build_map(paths: set[Path], prev: Optional[dict] = None) -> dict[str, list[str]]: + result: dict[str, list[str]] = prev or {} for src_path, import_strs in get_import_lines(paths).items(): source_id = str(get_relative_path(src_path, contracts_path)) @@ -754,13 +757,13 @@ def build_map(paths: Set[Path], prev: Optional[Dict] = None) -> Dict[str, List[s def get_version_map( self, - contract_filepaths: Union[Path, Sequence[Path]], + contract_filepaths: Union[Path, Iterable[Path]], base_path: Optional[Path] = None, - ) -> Dict[Version, Set[Path]]: + ) -> dict[Version, set[Path]]: # Ensure `.cache` folder is built before getting version map. self.get_import_remapping(base_path=base_path) - if not isinstance(contract_filepaths, Sequence): + if not isinstance(contract_filepaths, Iterable): contract_filepaths = [contract_filepaths] base_path = base_path or self.project_manager.contracts_folder @@ -798,7 +801,7 @@ def get_version_map( install_solc(latest, show_progress=True) # Adjust best-versions based on imports. - files_by_solc_version: Dict[Version, Set[Path]] = {} + files_by_solc_version: dict[Version, set[Path]] = {} for source_file_path in source_paths_to_get: solc_version = self._get_best_version(source_file_path, source_by_pragma_spec) imported_source_paths = self._get_imported_source_paths( @@ -859,9 +862,9 @@ def _get_imported_source_paths( self, path: Path, base_path: Path, - imports: Dict, - source_ids_checked: Optional[List[str]] = None, - ) -> Set[Path]: + imports: dict, + source_ids_checked: Optional[list[str]] = None, + ) -> set[Path]: source_ids_checked = source_ids_checked or [] source_identifier = str(get_relative_path(path, base_path)) if source_identifier in source_ids_checked: @@ -910,7 +913,7 @@ def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]: return pragma_spec - def _get_best_version(self, path: Path, source_by_pragma_spec: Dict) -> Version: + def _get_best_version(self, path: Path, source_by_pragma_spec: dict) -> Version: compiler_version: Optional[Version] = None if pragma_spec := source_by_pragma_spec.get(path): if selected := select_version(pragma_spec, self.installed_versions): @@ -1037,7 +1040,7 @@ def get_first_version_pragma(source: str) -> str: return "" -def get_licenses(src: str) -> List[Tuple[str, str]]: +def get_licenses(src: str) -> list[tuple[str, str]]: return LICENSES_PATTERN.findall(src) @@ -1084,7 +1087,7 @@ def process_licenses(contract: str) -> str: return contract_with_single_license -def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]: +def _get_sol_panic(revert_message: str) -> Optional[type[RuntimeErrorUnion]]: if revert_message.startswith(RUNTIME_ERROR_CODE_PREFIX): # ape-geth (style) plugins show the hex with the Panic ABI prefix. error_type_val = int( @@ -1101,7 +1104,7 @@ def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]: def _import_str_to_source_id( - _import_str: str, source_path: Path, base_path: Path, import_remapping: Dict[str, str] + _import_str: str, source_path: Path, base_path: Path, import_remapping: dict[str, str] ) -> str: quote = '"' if '"' in _import_str else "'" @@ -1118,18 +1121,18 @@ def _import_str_to_source_id( source_id_value = str(get_relative_path(path, base_path)) # Get all matches. - matches: List[Tuple[str, str]] = [] + import_matches: list[tuple[str, str]] = [] for key, value in import_remapping.items(): if key not in source_id_value: continue - matches.append((key, value)) + import_matches.append((key, value)) - if not matches: + if not import_matches: return source_id_value # Convert remapping list back to source using longest match (most exact). - key, value = max(matches, key=lambda x: len(x[0])) + key, value = max(import_matches, key=lambda x: len(x[0])) sections = [s for s in source_id_value.split(key) if s] depth = len(sections) - 1 source_id_value = "" @@ -1147,5 +1150,5 @@ def _import_str_to_source_id( return source_id_value -def _try_max(ls: List[Any]): +def _try_max(ls: list[Any]): return max(ls) if ls else None diff --git a/ape_solidity/exceptions.py b/ape_solidity/exceptions.py index f059a33..5a11aae 100644 --- a/ape_solidity/exceptions.py +++ b/ape_solidity/exceptions.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Dict, Type, Union +from typing import Union from ape.exceptions import CompilerError, ConfigError, ContractLogicError from ape.logging import LogLevel, logger @@ -169,7 +169,7 @@ def __init__(self, **kwargs): PopOnEmptyArrayError, ZeroInitializedVariableError, ] -RUNTIME_ERROR_MAP: Dict[RuntimeErrorType, Type[RuntimeErrorUnion]] = { +RUNTIME_ERROR_MAP: dict[RuntimeErrorType, type[RuntimeErrorUnion]] = { RuntimeErrorType.ASSERTION_ERROR: SolidityAssertionError, RuntimeErrorType.ARITHMETIC_UNDER_OR_OVERFLOW: SolidityArithmeticError, RuntimeErrorType.DIVISION_BY_ZERO_ERROR: DivisionByZeroError, diff --git a/pyproject.toml b/pyproject.toml index 5e7fecd..25c0ffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ write_to = "ape_solidity/version.py" [tool.black] line-length = 100 -target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +target-version = ['py39', 'py310', 'py311', 'py312'] include = '\.pyi?$' [tool.pytest.ini_options] diff --git a/setup.py b/setup.py index df01162..ddb94e3 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ "packaging", # Use the version ape requires "requests", ], - python_requires=">=3.8,<4", + python_requires=">=3.9,<4", extras_require=extras_require, py_modules=["ape_solidity"], license="Apache-2.0", @@ -91,7 +91,6 @@ "Operating System :: MacOS", "Operating System :: POSIX", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0a7bc3d..83a9d67 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -53,7 +53,11 @@ def test_compile_performance(benchmark, compiler, project): See https://pytest-benchmark.readthedocs.io/en/latest/ """ source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = benchmark.pedantic(compiler.compile, args=([source_path],), rounds=1) + result = benchmark.pedantic( + lambda *args, **kwargs: list(compiler.compile(*args, **kwargs)), + args=([source_path],), + rounds=1, + ) assert len(result) > 0 @@ -69,13 +73,13 @@ def test_compile_when_offline(project, compiler, mocker): # Using a non-specific contract - doesn't matter too much which one. source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path]) + result = list(compiler.compile([source_path])) assert len(result) > 0, "Nothing got compiled." def test_compile_multiple_definitions_in_source(project, compiler): source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path]) + result = list(compiler.compile([source_path])) assert len(result) == 2 assert [r.name for r in result] == ["IMultipleDefinitions", "MultipleDefinitions"] assert all(r.source_id == "MultipleDefinitions.sol" for r in result) @@ -92,7 +96,7 @@ def test_compile_specific_order(project, compiler): project.contracts_folder / "OlderVersion.sol", project.contracts_folder / "Imports.sol", ] - compiler.compile(ordered_files) + list(compiler.compile(ordered_files)) def test_compile_missing_version(project, compiler, temp_solcx_path): @@ -104,7 +108,7 @@ def test_compile_missing_version(project, compiler, temp_solcx_path): compilers installed. """ assert not solcx.get_installed_solc_versions() - contract_types = compiler.compile([project.contracts_folder / "MissingPragma.sol"]) + contract_types = list(compiler.compile([project.contracts_folder / "MissingPragma.sol"])) assert len(contract_types) == 1 installed_versions = solcx.get_installed_solc_versions() assert len(installed_versions) == 1 @@ -120,14 +124,14 @@ def test_compile_contract_with_different_name_than_file(project): def test_compile_only_returns_contract_types_for_inputs(compiler, project): # The compiler has to compile multiple files for 'Imports.sol' (it imports stuff). # However - it should only return a single contract type in this case. - contract_types = compiler.compile([project.contracts_folder / "Imports.sol"]) + contract_types = list(compiler.compile([project.contracts_folder / "Imports.sol"])) assert len(contract_types) == 1 assert contract_types[0].name == "Imports" def test_compile_vyper_contract(compiler, vyper_source_path): with raises_because_not_sol: - compiler.compile([vyper_source_path]) + list(compiler.compile([vyper_source_path])) def test_compile_just_a_struct(compiler, project): @@ -135,7 +139,7 @@ def test_compile_just_a_struct(compiler, project): Before, you would get a nasty index error, even though this is valid Solidity. The fix involved using nicer access to "contracts" in the standard output JSON. """ - contract_types = compiler.compile([project.contracts_folder / "JustAStruct.sol"]) + contract_types = list(compiler.compile([project.contracts_folder / "JustAStruct.sol"])) assert len(contract_types) == 0 @@ -477,7 +481,7 @@ def test_evm_version(compiler): def test_source_map(project, compiler): source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path])[-1] + result = list(compiler.compile([source_path]))[-1] assert result.sourcemap.root == "124:87:0:-:0;;;;;;;;;;;;;;;;;;;" @@ -494,7 +498,7 @@ def test_add_library(project, account, compiler, connection): def test_enrich_error_when_custom(compiler, project, owner, not_owner, connection): - compiler.compile((project.contracts_folder / "HasError.sol",)) + list(compiler.compile((project.contracts_folder / "HasError.sol",))) # Deploy so Ape know about contract type. contract = owner.deploy(project.HasError, 1) @@ -527,7 +531,7 @@ def test_enrich_error_when_builtin(project, owner, connection): # TODO: Not yet used and super slow. # def test_ast(project, compiler): # source_path = project.contracts_folder / "MultipleDefinitions.sol" -# actual = compiler.compile([source_path])[-1].ast +# actual = list(compiler.compile([source_path]))[-1].ast # fn_node = actual.children[1].children[0] # assert actual.ast_type == "SourceUnit" # assert fn_node.classification == ASTClassification.FUNCTION @@ -583,13 +587,13 @@ def test_via_ir(project, compiler): source_path.write_text(source_code) try: - compiler.compile([source_path]) + list(compiler.compile([source_path])) except Exception as e: assert "Stack too deep" in str(e) compiler.config.via_ir = True - compiler.compile([source_path]) + list(compiler.compile([source_path])) # delete source code file source_path.unlink()