diff --git a/.clang-tidy b/.clang-tidy index b876ede7..bb139f4c 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,7 +1,9 @@ --- -Checks: 'clang-diagnostic-*,clang-analyzer-*,cppcoreguidelines-*,modernize-*,bugprone-*,concurrency-*,performance-*,portability-*,-modernize-use-nodiscard,-modernize-use-trailing-return-type,-cppcoreguidelines-special-member-functions' +Checks: 'clang-diagnostic-*,clang-analyzer-*,cppcoreguidelines-*,modernize-*,bugprone-*,concurrency-*,performance-*,portability-*,-modernize-use-nodiscard,-modernize-use-trailing-return-type,-cppcoreguidelines-special-member-functions,-bugprone-easily-swappable-parameters,-bugprone-assignment-in-if-condition' WarningsAsErrors: false HeaderFilterRegex: '(build/.+)|(codon/util/.+)' AnalyzeTemporaryDtors: false FormatStyle: llvm -... +CheckOptions: + - key: cppcoreguidelines-macro-usage.CheckCapsOnly + value: '1' diff --git a/.gitignore b/.gitignore index 07dda88e..50929bc7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,10 +14,8 @@ *.so *.dylib *.pyc -build/ -install/ -install_*/ -install-*/ +build*/ +install*/ extra/python/src/jit.cpp extra/jupyter/build/ @@ -68,3 +66,4 @@ jit/codon/version.py temp/ playground/ scratch*.* +/_* diff --git a/CMakeLists.txt b/CMakeLists.txt index 8548d12b..ccd9843d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,210 +183,204 @@ add_custom_command( include_directories(${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) set(CODON_HPPFILES - codon/compiler/compiler.h - codon/compiler/debug_listener.h - codon/compiler/engine.h - codon/compiler/error.h - codon/compiler/jit.h - codon/compiler/memory_manager.h - codon/dsl/dsl.h - codon/dsl/plugins.h - codon/parser/ast.h - codon/parser/ast/expr.h - codon/parser/ast/stmt.h - codon/parser/ast/types.h - codon/parser/ast/types/type.h - codon/parser/ast/types/link.h - codon/parser/ast/types/class.h - codon/parser/ast/types/function.h - codon/parser/ast/types/union.h - codon/parser/ast/types/static.h - codon/parser/ast/types/traits.h - codon/parser/cache.h - codon/parser/common.h - codon/parser/ctx.h - codon/parser/peg/peg.h - codon/parser/peg/rules.h - codon/parser/visitors/doc/doc.h - codon/parser/visitors/format/format.h - codon/parser/visitors/simplify/simplify.h - codon/parser/visitors/simplify/ctx.h - codon/parser/visitors/translate/translate.h - codon/parser/visitors/translate/translate_ctx.h - codon/parser/visitors/typecheck/typecheck.h - codon/parser/visitors/typecheck/ctx.h - codon/parser/visitors/visitor.h - codon/cir/analyze/analysis.h - codon/cir/analyze/dataflow/capture.h - codon/cir/analyze/dataflow/cfg.h - codon/cir/analyze/dataflow/dominator.h - codon/cir/analyze/dataflow/reaching.h - codon/cir/analyze/module/global_vars.h - codon/cir/analyze/module/side_effect.h - codon/cir/attribute.h - codon/cir/base.h - codon/cir/const.h - codon/cir/dsl/codegen.h - codon/cir/dsl/nodes.h - codon/cir/flow.h - codon/cir/func.h - codon/cir/instr.h - codon/cir/llvm/gpu.h - codon/cir/llvm/llvisitor.h - codon/cir/llvm/llvm.h - codon/cir/llvm/native/native.h - codon/cir/llvm/native/targets/aarch64.h - codon/cir/llvm/native/targets/target.h - codon/cir/llvm/native/targets/x86.h - codon/cir/llvm/optimize.h - codon/cir/module.h - codon/cir/pyextension.h - codon/cir/cir.h - codon/cir/transform/cleanup/canonical.h - codon/cir/transform/cleanup/dead_code.h - codon/cir/transform/cleanup/global_demote.h - codon/cir/transform/cleanup/replacer.h - codon/cir/transform/folding/const_fold.h - codon/cir/transform/folding/const_prop.h - codon/cir/transform/folding/folding.h - codon/cir/transform/folding/rule.h - codon/cir/transform/lowering/imperative.h - codon/cir/transform/lowering/pipeline.h - codon/cir/transform/numpy/numpy.h - codon/cir/transform/manager.h - codon/cir/transform/parallel/openmp.h - codon/cir/transform/parallel/schedule.h - codon/cir/transform/pass.h - codon/cir/transform/pythonic/dict.h - codon/cir/transform/pythonic/generator.h - codon/cir/transform/pythonic/io.h - codon/cir/transform/pythonic/list.h - codon/cir/transform/pythonic/str.h - codon/cir/transform/rewrite.h - codon/cir/types/types.h - codon/cir/util/cloning.h - codon/cir/util/context.h - codon/cir/util/format.h - codon/cir/util/inlining.h - codon/cir/util/irtools.h - codon/cir/util/iterators.h - codon/cir/util/matching.h - codon/cir/util/operator.h - codon/cir/util/outlining.h - codon/cir/util/packs.h - codon/cir/util/side_effect.h - codon/cir/util/visitor.h - codon/cir/value.h - codon/cir/var.h - codon/util/common.h - codon/compiler/jit_extern.h) + codon/compiler/compiler.h + codon/compiler/debug_listener.h + codon/compiler/engine.h + codon/compiler/error.h + codon/compiler/jit.h + codon/compiler/memory_manager.h + codon/dsl/dsl.h + codon/dsl/plugins.h + codon/parser/ast.h + codon/parser/match.h + codon/parser/ast/node.h + codon/parser/ast/expr.h + codon/parser/ast/stmt.h + codon/parser/ast/types.h + codon/parser/ast/attr.h + codon/parser/ast/types/type.h + codon/parser/ast/types/link.h + codon/parser/ast/types/class.h + codon/parser/ast/types/function.h + codon/parser/ast/types/union.h + codon/parser/ast/types/static.h + codon/parser/ast/types/traits.h + codon/parser/cache.h + codon/parser/common.h + codon/parser/ctx.h + codon/parser/peg/peg.h + codon/parser/peg/rules.h + codon/parser/visitors/doc/doc.h + codon/parser/visitors/format/format.h + codon/parser/visitors/scoping/scoping.h + codon/parser/visitors/translate/translate.h + codon/parser/visitors/translate/translate_ctx.h + codon/parser/visitors/typecheck/typecheck.h + codon/parser/visitors/typecheck/ctx.h + codon/parser/visitors/visitor.h + codon/cir/analyze/analysis.h + codon/cir/analyze/dataflow/capture.h + codon/cir/analyze/dataflow/cfg.h + codon/cir/analyze/dataflow/dominator.h + codon/cir/analyze/dataflow/reaching.h + codon/cir/analyze/module/global_vars.h + codon/cir/analyze/module/side_effect.h + codon/cir/attribute.h + codon/cir/base.h + codon/cir/const.h + codon/cir/dsl/codegen.h + codon/cir/dsl/nodes.h + codon/cir/flow.h + codon/cir/func.h + codon/cir/instr.h + codon/cir/llvm/gpu.h + codon/cir/llvm/llvisitor.h + codon/cir/llvm/llvm.h + codon/cir/llvm/optimize.h + codon/cir/module.h + codon/cir/pyextension.h + codon/cir/cir.h + codon/cir/transform/cleanup/canonical.h + codon/cir/transform/cleanup/dead_code.h + codon/cir/transform/cleanup/global_demote.h + codon/cir/transform/cleanup/replacer.h + codon/cir/transform/folding/const_fold.h + codon/cir/transform/folding/const_prop.h + codon/cir/transform/folding/folding.h + codon/cir/transform/folding/rule.h + codon/cir/transform/lowering/imperative.h + codon/cir/transform/lowering/pipeline.h + codon/cir/transform/manager.h + codon/cir/transform/parallel/openmp.h + codon/cir/transform/parallel/schedule.h + codon/cir/transform/pass.h + codon/cir/transform/pythonic/dict.h + codon/cir/transform/pythonic/generator.h + codon/cir/transform/pythonic/io.h + codon/cir/transform/pythonic/list.h + codon/cir/transform/pythonic/str.h + codon/cir/transform/rewrite.h + codon/cir/types/types.h + codon/cir/util/cloning.h + codon/cir/util/context.h + codon/cir/util/format.h + codon/cir/util/inlining.h + codon/cir/util/irtools.h + codon/cir/util/iterators.h + codon/cir/util/matching.h + codon/cir/util/operator.h + codon/cir/util/outlining.h + codon/cir/util/packs.h + codon/cir/util/side_effect.h + codon/cir/util/visitor.h + codon/cir/value.h + codon/cir/llvm/native/native.h + codon/cir/llvm/native/targets/aarch64.h + codon/cir/llvm/native/targets/target.h + codon/cir/llvm/native/targets/x86.h + codon/cir/transform/numpy/numpy.h + codon/cir/var.h + codon/util/common.h + codon/util/serialize.h + codon/compiler/jit_extern.h) set(CODON_CPPFILES - codon/compiler/compiler.cpp - codon/compiler/debug_listener.cpp - codon/compiler/engine.cpp - codon/compiler/error.cpp - codon/compiler/jit.cpp - codon/compiler/memory_manager.cpp - codon/dsl/plugins.cpp - codon/parser/ast/expr.cpp - codon/parser/ast/stmt.cpp - codon/parser/ast/types/type.cpp - codon/parser/ast/types/link.cpp - codon/parser/ast/types/class.cpp - codon/parser/ast/types/function.cpp - codon/parser/ast/types/union.cpp - codon/parser/ast/types/static.cpp - codon/parser/ast/types/traits.cpp - codon/parser/cache.cpp - codon/parser/common.cpp - codon/parser/peg/peg.cpp - codon/parser/visitors/doc/doc.cpp - codon/parser/visitors/format/format.cpp - codon/parser/visitors/simplify/simplify.cpp - codon/parser/visitors/simplify/ctx.cpp - codon/parser/visitors/simplify/assign.cpp - codon/parser/visitors/simplify/basic.cpp - codon/parser/visitors/simplify/call.cpp - codon/parser/visitors/simplify/class.cpp - codon/parser/visitors/simplify/collections.cpp - codon/parser/visitors/simplify/cond.cpp - codon/parser/visitors/simplify/function.cpp - codon/parser/visitors/simplify/access.cpp - codon/parser/visitors/simplify/import.cpp - codon/parser/visitors/simplify/loops.cpp - codon/parser/visitors/simplify/op.cpp - codon/parser/visitors/simplify/error.cpp - codon/parser/visitors/translate/translate.cpp - codon/parser/visitors/translate/translate_ctx.cpp - codon/parser/visitors/typecheck/typecheck.cpp - codon/parser/visitors/typecheck/infer.cpp - codon/parser/visitors/typecheck/ctx.cpp - codon/parser/visitors/typecheck/assign.cpp - codon/parser/visitors/typecheck/basic.cpp - codon/parser/visitors/typecheck/call.cpp - codon/parser/visitors/typecheck/class.cpp - codon/parser/visitors/typecheck/collections.cpp - codon/parser/visitors/typecheck/cond.cpp - codon/parser/visitors/typecheck/function.cpp - codon/parser/visitors/typecheck/access.cpp - codon/parser/visitors/typecheck/loops.cpp - codon/parser/visitors/typecheck/op.cpp - codon/parser/visitors/typecheck/error.cpp - codon/parser/visitors/visitor.cpp - codon/cir/attribute.cpp - codon/cir/analyze/analysis.cpp - codon/cir/analyze/dataflow/capture.cpp - codon/cir/analyze/dataflow/cfg.cpp - codon/cir/analyze/dataflow/dominator.cpp - codon/cir/analyze/dataflow/reaching.cpp - codon/cir/analyze/module/global_vars.cpp - codon/cir/analyze/module/side_effect.cpp - codon/cir/base.cpp - codon/cir/const.cpp - codon/cir/dsl/nodes.cpp - codon/cir/flow.cpp - codon/cir/func.cpp - codon/cir/instr.cpp - codon/cir/llvm/gpu.cpp - codon/cir/llvm/llvisitor.cpp - codon/cir/llvm/native/native.cpp - codon/cir/llvm/native/targets/aarch64.cpp - codon/cir/llvm/native/targets/x86.cpp - codon/cir/llvm/optimize.cpp - codon/cir/module.cpp - codon/cir/transform/cleanup/canonical.cpp - codon/cir/transform/cleanup/dead_code.cpp - codon/cir/transform/cleanup/global_demote.cpp - codon/cir/transform/cleanup/replacer.cpp - codon/cir/transform/folding/const_fold.cpp - codon/cir/transform/folding/const_prop.cpp - codon/cir/transform/folding/folding.cpp - codon/cir/transform/lowering/imperative.cpp - codon/cir/transform/lowering/pipeline.cpp - codon/cir/transform/numpy/expr.cpp - codon/cir/transform/numpy/forward.cpp - codon/cir/transform/numpy/numpy.cpp - codon/cir/transform/manager.cpp - codon/cir/transform/parallel/openmp.cpp - codon/cir/transform/parallel/schedule.cpp - codon/cir/transform/pass.cpp - codon/cir/transform/pythonic/dict.cpp - codon/cir/transform/pythonic/generator.cpp - codon/cir/transform/pythonic/io.cpp - codon/cir/transform/pythonic/list.cpp - codon/cir/transform/pythonic/str.cpp - codon/cir/types/types.cpp - codon/cir/util/cloning.cpp - codon/cir/util/format.cpp - codon/cir/util/inlining.cpp - codon/cir/util/irtools.cpp - codon/cir/util/matching.cpp - codon/cir/util/outlining.cpp - codon/cir/util/side_effect.cpp - codon/cir/util/visitor.cpp - codon/cir/value.cpp - codon/cir/var.cpp - codon/util/common.cpp) + codon/compiler/compiler.cpp + codon/compiler/debug_listener.cpp + codon/compiler/engine.cpp + codon/compiler/error.cpp + codon/compiler/jit.cpp + codon/compiler/memory_manager.cpp + codon/dsl/plugins.cpp + codon/parser/ast/expr.cpp + codon/parser/ast/attr.cpp + codon/parser/ast/stmt.cpp + codon/parser/ast/types/type.cpp + codon/parser/ast/types/link.cpp + codon/parser/ast/types/class.cpp + codon/parser/ast/types/function.cpp + codon/parser/ast/types/union.cpp + codon/parser/ast/types/static.cpp + codon/parser/ast/types/traits.cpp + codon/parser/cache.cpp + codon/parser/match.cpp + codon/parser/common.cpp + codon/parser/peg/peg.cpp + codon/parser/visitors/doc/doc.cpp + codon/parser/visitors/format/format.cpp + codon/parser/visitors/scoping/scoping.cpp + codon/parser/visitors/translate/translate.cpp + codon/parser/visitors/translate/translate_ctx.cpp + codon/parser/visitors/typecheck/typecheck.cpp + codon/parser/visitors/typecheck/infer.cpp + codon/parser/visitors/typecheck/ctx.cpp + codon/parser/visitors/typecheck/assign.cpp + codon/parser/visitors/typecheck/basic.cpp + codon/parser/visitors/typecheck/call.cpp + codon/parser/visitors/typecheck/class.cpp + codon/parser/visitors/typecheck/collections.cpp + codon/parser/visitors/typecheck/cond.cpp + codon/parser/visitors/typecheck/function.cpp + codon/parser/visitors/typecheck/access.cpp + codon/parser/visitors/typecheck/import.cpp + codon/parser/visitors/typecheck/loops.cpp + codon/parser/visitors/typecheck/op.cpp + codon/parser/visitors/typecheck/error.cpp + codon/parser/visitors/typecheck/special.cpp + codon/parser/visitors/visitor.cpp + codon/cir/attribute.cpp + codon/cir/analyze/analysis.cpp + codon/cir/analyze/dataflow/capture.cpp + codon/cir/analyze/dataflow/cfg.cpp + codon/cir/analyze/dataflow/dominator.cpp + codon/cir/analyze/dataflow/reaching.cpp + codon/cir/analyze/module/global_vars.cpp + codon/cir/analyze/module/side_effect.cpp + codon/cir/base.cpp + codon/cir/const.cpp + codon/cir/dsl/nodes.cpp + codon/cir/flow.cpp + codon/cir/func.cpp + codon/cir/instr.cpp + codon/cir/llvm/gpu.cpp + codon/cir/llvm/llvisitor.cpp + codon/cir/llvm/optimize.cpp + codon/cir/module.cpp + codon/cir/transform/cleanup/canonical.cpp + codon/cir/transform/cleanup/dead_code.cpp + codon/cir/transform/cleanup/global_demote.cpp + codon/cir/transform/cleanup/replacer.cpp + codon/cir/transform/folding/const_fold.cpp + codon/cir/transform/folding/const_prop.cpp + codon/cir/transform/folding/folding.cpp + codon/cir/transform/lowering/imperative.cpp + codon/cir/transform/lowering/pipeline.cpp + codon/cir/transform/manager.cpp + codon/cir/transform/parallel/openmp.cpp + codon/cir/transform/parallel/schedule.cpp + codon/cir/transform/pass.cpp + codon/cir/transform/pythonic/dict.cpp + codon/cir/transform/pythonic/generator.cpp + codon/cir/transform/pythonic/io.cpp + codon/cir/transform/pythonic/list.cpp + codon/cir/transform/pythonic/str.cpp + codon/cir/types/types.cpp + codon/cir/util/cloning.cpp + codon/cir/util/format.cpp + codon/cir/util/inlining.cpp + codon/cir/util/irtools.cpp + codon/cir/util/matching.cpp + codon/cir/util/outlining.cpp + codon/cir/util/side_effect.cpp + codon/cir/util/visitor.cpp + codon/cir/value.cpp + codon/cir/var.cpp + codon/cir/llvm/native/native.cpp + codon/cir/llvm/native/targets/aarch64.cpp + codon/cir/llvm/native/targets/x86.cpp + codon/cir/transform/numpy/expr.cpp + codon/cir/transform/numpy/forward.cpp + codon/cir/transform/numpy/numpy.cpp + codon/util/common.cpp) add_library(codonc SHARED ${CODON_HPPFILES}) target_include_directories(codonc PRIVATE ${peglib_SOURCE_DIR} ${toml_SOURCE_DIR}/include @@ -435,9 +429,9 @@ llvm_map_components_to_libnames( Vectorize Passes) if(APPLE) - target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt) + target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt tser dl codonrt) else() - target_link_libraries(codonc PRIVATE ${STATIC_LIBCPP} ${LLVM_LIBS} fmt dl codonrt) + target_link_libraries(codonc PRIVATE ${STATIC_LIBCPP} ${LLVM_LIBS} fmt tser dl codonrt) endif() # Gather headers @@ -482,7 +476,7 @@ add_dependencies(libs codonrt codonc) # Codon command-line tool add_executable(codon codon/app/main.cpp) -target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt codonc codon_jupyter Threads::Threads) +target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt tser codonc codon_jupyter Threads::Threads) # Codon test Download and unpack googletest at configure time include(FetchContent) @@ -512,9 +506,8 @@ set(CODON_TEST_CPPFILES test/cir/var.cpp test/types.cpp) add_executable(codon_test ${CODON_TEST_CPPFILES}) -target_include_directories(codon_test PRIVATE test/cir - "${gc_SOURCE_DIR}/include") -target_link_libraries(codon_test fmt codonc codonrt gtest_main) +target_include_directories(codon_test PRIVATE test/cir "${gc_SOURCE_DIR}/include") +target_link_libraries(codon_test fmt tser codonc codonrt gtest_main) target_compile_definitions(codon_test PRIVATE TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test") diff --git a/cmake/deps.cmake b/cmake/deps.cmake index a5cb67e9..5373088c 100644 --- a/cmake/deps.cmake +++ b/cmake/deps.cmake @@ -12,6 +12,11 @@ CPMAddPackage( GIT_TAG codon OPTIONS "BUILD_TESTS OFF") +CPMAddPackage( + NAME tser + GITHUB_REPOSITORY "KonanM/tser" + GIT_TAG v1.2) + CPMAddPackage( NAME fmt GITHUB_REPOSITORY "fmtlib/fmt" diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 54419810..67c839c1 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -64,16 +64,18 @@ std::string makeOutputFilename(const std::string &filename, void display(const codon::error::ParserErrorInfo &e) { using codon::MessageGroupPos; - for (auto &group : e) { + for (auto &group : e.getErrors()) { + int i = 0; for (auto &msg : group) { MessageGroupPos pos = MessageGroupPos::NONE; - if (&msg == &group.front()) { + if (i == 0) { pos = MessageGroupPos::HEAD; - } else if (&msg == &group.back()) { + } else if (i == group.size() - 1) { pos = MessageGroupPos::LAST; } else { pos = MessageGroupPos::MID; } + i++; codon::compilationError(msg.getMessage(), msg.getFile(), msg.getLine(), msg.getColumn(), msg.getLength(), msg.getErrorCode(), /*terminate=*/false, pos); diff --git a/codon/cir/analyze/dataflow/capture.cpp b/codon/cir/analyze/dataflow/capture.cpp index 12d42820..8079c164 100644 --- a/codon/cir/analyze/dataflow/capture.cpp +++ b/codon/cir/analyze/dataflow/capture.cpp @@ -176,8 +176,9 @@ struct DerivedSet { if (!shouldTrack(v)) return; - if (v->isGlobal()) + if (v->isGlobal()) { setExternCaptured(); + } auto id = v->getId(); if (shouldArgCapture && root && id != root->getId()) { @@ -689,7 +690,6 @@ std::vector CaptureContext::get(const Func *func) { if (isA(func)) { bool isTupleNew = func->getUnmangledName() == "__new__" && isA(util::getReturnType(func)); - bool isPromise = func->getUnmangledName() == "__promise__" && std::distance(func->arg_begin(), func->arg_end()) == 1 && isA(func->arg_front()->getType()); diff --git a/codon/cir/attribute.cpp b/codon/cir/attribute.cpp index 6dd7c096..cca57969 100644 --- a/codon/cir/attribute.cpp +++ b/codon/cir/attribute.cpp @@ -10,6 +10,17 @@ namespace codon { namespace ir { +const std::string StringValueAttribute::AttributeName = "svAttribute"; + +const std::string IntValueAttribute::AttributeName = "i64Attribute"; + +const std::string StringListAttribute::AttributeName = "slAttribute"; + +std::ostream &StringListAttribute::doFormat(std::ostream &os) const { + fmt::print(os, FMT_STRING("{}"), fmt::join(values.begin(), values.end(), ",")); + return os; +} + const std::string KeyValueAttribute::AttributeName = "kvAttribute"; bool KeyValueAttribute::has(const std::string &key) const { @@ -39,6 +50,22 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const { return os; } +const std::string PythonWrapperAttribute::AttributeName = "pythonWrapperAttribute"; + +std::unique_ptr PythonWrapperAttribute::clone(util::CloneVisitor &cv) const { + return std::make_unique(cast(cv.clone(original))); +} + +std::unique_ptr +PythonWrapperAttribute::forceClone(util::CloneVisitor &cv) const { + return std::make_unique(cv.forceClone(original)); +} + +std::ostream &PythonWrapperAttribute::doFormat(std::ostream &os) const { + fmt::print(os, FMT_STRING("(pywrap {})"), original->referenceString()); + return os; +} + const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; const std::string DocstringAttribute::AttributeName = "docstringAttribute"; @@ -180,4 +207,13 @@ std::ostream &PartialFunctionAttribute::doFormat(std::ostream &os) const { } } // namespace ir + +std::map> +clone(const std::map> &t) { + std::map> r; + for (auto &[k, v] : t) + r[k] = v->clone(); + return r; +} + } // namespace codon diff --git a/codon/cir/attribute.h b/codon/cir/attribute.h index 858ec7b3..f701e71e 100644 --- a/codon/cir/attribute.h +++ b/codon/cir/attribute.h @@ -33,7 +33,14 @@ struct Attribute { } /// @return a clone of the attribute - virtual std::unique_ptr clone(util::CloneVisitor &cv) const = 0; + virtual std::unique_ptr clone() const { + return std::make_unique(); + } + + /// @return a clone of the attribute + virtual std::unique_ptr clone(util::CloneVisitor &cv) const { + return clone(); + } /// @return a clone of the attribute virtual std::unique_ptr forceClone(util::CloneVisitor &cv) const { @@ -41,7 +48,7 @@ struct Attribute { } private: - virtual std::ostream &doFormat(std::ostream &os) const = 0; + virtual std::ostream &doFormat(std::ostream &os) const { return os; } }; /// Attribute containing SrcInfo @@ -56,7 +63,7 @@ struct SrcInfoAttribute : public Attribute { /// @param info the source info explicit SrcInfoAttribute(codon::SrcInfo info) : info(std::move(info)) {} - std::unique_ptr clone(util::CloneVisitor &cv) const override { + std::unique_ptr clone() const override { return std::make_unique(*this); } @@ -64,6 +71,24 @@ struct SrcInfoAttribute : public Attribute { std::ostream &doFormat(std::ostream &os) const override { return os << info; } }; +/// Attribute containing docstring from source +struct StringValueAttribute : public Attribute { + static const std::string AttributeName; + + std::string value; + + StringValueAttribute() = default; + /// Constructs a StringValueAttribute. + explicit StringValueAttribute(const std::string &value) : value(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + std::ostream &doFormat(std::ostream &os) const override { return os << value; } +}; + /// Attribute containing docstring from source struct DocstringAttribute : public Attribute { static const std::string AttributeName; @@ -76,7 +101,7 @@ struct DocstringAttribute : public Attribute { /// @param docstring the docstring explicit DocstringAttribute(const std::string &docstring) : docstring(docstring) {} - std::unique_ptr clone(util::CloneVisitor &cv) const override { + std::unique_ptr clone() const override { return std::make_unique(*this); } @@ -106,7 +131,7 @@ struct KeyValueAttribute : public Attribute { /// string if none std::string get(const std::string &key) const; - std::unique_ptr clone(util::CloneVisitor &cv) const override { + std::unique_ptr clone() const override { return std::make_unique(*this); } @@ -114,6 +139,27 @@ struct KeyValueAttribute : public Attribute { std::ostream &doFormat(std::ostream &os) const override; }; +/// Attribute containing function information +struct StringListAttribute : public Attribute { + static const std::string AttributeName; + + /// attributes map + std::vector values; + + StringListAttribute() = default; + /// Constructs a StringListAttribute. + /// @param attributes the map of attributes + explicit StringListAttribute(std::vector values) + : values(std::move(values)) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + /// Attribute containing type member information struct MemberAttribute : public Attribute { static const std::string AttributeName; @@ -127,7 +173,7 @@ struct MemberAttribute : public Attribute { explicit MemberAttribute(std::map memberSrcInfo) : memberSrcInfo(std::move(memberSrcInfo)) {} - std::unique_ptr clone(util::CloneVisitor &cv) const override { + std::unique_ptr clone() const override { return std::make_unique(*this); } @@ -135,6 +181,30 @@ struct MemberAttribute : public Attribute { std::ostream &doFormat(std::ostream &os) const override; }; +/// Attribute used to mark Python wrappers of Codon functions +struct PythonWrapperAttribute : public Attribute { + static const std::string AttributeName; + + /// the function being wrapped + Func *original; + + /// Constructs a PythonWrapperAttribute. + /// @param original the function being wrapped + explicit PythonWrapperAttribute(Func *original) : original(original) {} + + bool needsClone() const override { return false; } + + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + /// Attribute attached to IR structures corresponding to tuple literals struct TupleLiteralAttribute : public Attribute { static const std::string AttributeName; @@ -145,6 +215,10 @@ struct TupleLiteralAttribute : public Attribute { explicit TupleLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; @@ -170,6 +244,10 @@ struct ListLiteralAttribute : public Attribute { explicit ListLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; @@ -187,6 +265,10 @@ struct SetLiteralAttribute : public Attribute { explicit SetLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; @@ -211,6 +293,10 @@ struct DictLiteralAttribute : public Attribute { explicit DictLiteralAttribute(std::vector elements) : elements(std::move(elements)) {} + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; @@ -232,6 +318,10 @@ struct PartialFunctionAttribute : public Attribute { PartialFunctionAttribute(const std::string &name, std::vector args) : name(name), args(std::move(args)) {} + std::unique_ptr clone() const override { + seqassertn(false, "cannot operate without CloneVisitor"); + return nullptr; + } std::unique_ptr clone(util::CloneVisitor &cv) const override; std::unique_ptr forceClone(util::CloneVisitor &cv) const override; @@ -239,7 +329,28 @@ struct PartialFunctionAttribute : public Attribute { std::ostream &doFormat(std::ostream &os) const override; }; +struct IntValueAttribute : public Attribute { + static const std::string AttributeName; + + int64_t value; + + IntValueAttribute() = default; + /// Constructs a IntValueAttribute. + explicit IntValueAttribute(int64_t value) : value(value) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + std::ostream &doFormat(std::ostream &os) const override { return os << value; } +}; + } // namespace ir + +std::map> +clone(const std::map> &t); + } // namespace codon template <> struct fmt::formatter : fmt::ostream_formatter {}; diff --git a/codon/cir/base.cpp b/codon/cir/base.cpp index 4b0f9233..ac4acf4c 100644 --- a/codon/cir/base.cpp +++ b/codon/cir/base.cpp @@ -20,6 +20,10 @@ std::ostream &operator<<(std::ostream &os, const Node &other) { return util::format(os, &other); } +Node::Node(const Node &n) + : name(n.name), module(n.module), replacement(n.replacement), + attributes(codon::clone(n.attributes)) {} + int Node::replaceUsedValue(Value *old, Value *newValue) { return replaceUsedValue(old->getId(), newValue); } diff --git a/codon/cir/base.h b/codon/cir/base.h index 69333c91..71d35d39 100644 --- a/codon/cir/base.h +++ b/codon/cir/base.h @@ -48,13 +48,15 @@ class Node { private: /// the node's name std::string name; - /// key-value attribute store - std::map> attributes; /// the module Module *module = nullptr; /// a replacement, if set Node *replacement = nullptr; +protected: + /// key-value attribute store + std::map> attributes; + public: // RTTI is implemented using a port of LLVM's Extensible RTTI // For more details, see @@ -64,10 +66,15 @@ class Node { /// Constructs a node. /// @param name the node's name explicit Node(std::string name = "") : name(std::move(name)) {} + /// Constructs a node. + /// @param name the node's name + explicit Node(const Node &n); /// See LLVM documentation. static const void *nodeId() { return &NodeId; } /// See LLVM documentation. + virtual const void *dynamicNodeId() const = 0; + /// See LLVM documentation. virtual bool isConvertible(const void *other) const { if (hasReplacement()) return getActual()->isConvertible(other); @@ -94,10 +101,10 @@ class Node { /// Accepts visitors. /// @param v the visitor - virtual void accept(util::Visitor &v) = 0; + virtual void accept(util::Visitor &v) {} /// Accepts visitors. /// @param v the visitor - virtual void accept(util::ConstVisitor &v) const = 0; + virtual void accept(util::ConstVisitor &v) const {} /// Sets an attribute /// @param the attribute key @@ -150,6 +157,16 @@ class Node { return static_cast( getAttribute(AttributeType::AttributeName)); } + template + AttributeType *getAttribute(const std::string &key) { + return static_cast(getAttribute(key)); + } + template + const AttributeType *getAttribute(const std::string &key) const { + return static_cast(getAttribute(key)); + } + void eraseAttribute(const std::string &key) { attributes.erase(key); } + void cloneAttributesFrom(Node *n) { attributes = codon::clone(n->attributes); } /// @return iterator to the first attribute auto attributes_begin() { @@ -186,6 +203,12 @@ class Node { /// @param m the new module void setModule(Module *m) { getActual()->module = m; } + /// Convert a node to a string expression. + // virtual std::string toString(int) const = 0; + // virtual std::string toString() const { return toString(-1); } + // friend std::ostream &operator<<(std::ostream &os, const Node &a) { + // return out << expr.toString(); + // } friend std::ostream &operator<<(std::ostream &os, const Node &a); bool hasReplacement() const { return replacement != nullptr; } @@ -252,21 +275,23 @@ template class AcceptorExtend : public Paren /// See LLVM documentation. static const void *nodeId() { return &Derived::NodeId; } /// See LLVM documentation. - virtual bool isConvertible(const void *other) const { + const void *dynamicNodeId() const override { return &Derived::NodeId; } + /// See LLVM documentation. + virtual bool isConvertible(const void *other) const override { if (Node::hasReplacement()) return Node::getActual()->isConvertible(other); return other == nodeId() || Parent::isConvertible(other); } - void accept(util::Visitor &v) { + void accept(util::Visitor &v) override { if (Node::hasReplacement()) Node::getActual()->accept(v); else v.visit(static_cast(this)); } - void accept(util::ConstVisitor &v) const { + void accept(util::ConstVisitor &v) const override { if (Node::hasReplacement()) Node::getActual()->accept(v); else diff --git a/codon/cir/instr.cpp b/codon/cir/instr.cpp index 6c97ecf6..cb125259 100644 --- a/codon/cir/instr.cpp +++ b/codon/cir/instr.cpp @@ -27,6 +27,9 @@ types::Type *Instr::doGetType() const { return getModule()->getNoneType(); } const char AssignInstr::NodeId = 0; +AssignInstr::AssignInstr(Var *lhs, Value *rhs, std::string name) + : AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {} + int AssignInstr::doReplaceUsedValue(id_t id, Value *newValue) { if (rhs->getId() == id) { rhs = newValue; diff --git a/codon/cir/instr.h b/codon/cir/instr.h index 2cb258ed..37beed33 100644 --- a/codon/cir/instr.h +++ b/codon/cir/instr.h @@ -42,8 +42,7 @@ class AssignInstr : public AcceptorExtend { /// @param rhs the right-hand side /// @param field the field being set, may be empty /// @param name the instruction's name - AssignInstr(Var *lhs, Value *rhs, std::string name = "") - : AcceptorExtend(std::move(name)), lhs(lhs), rhs(rhs) {} + AssignInstr(Var *lhs, Value *rhs, std::string name = ""); /// @return the left-hand side Var *getLhs() { return lhs; } diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index c1cc83b3..e87a6aea 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -22,10 +22,10 @@ namespace codon { namespace ir { namespace { -const std::string EXPORT_ATTR = "std.internal.attributes.export"; -const std::string INLINE_ATTR = "std.internal.attributes.inline"; -const std::string NOINLINE_ATTR = "std.internal.attributes.noinline"; -const std::string GPU_KERNEL_ATTR = "std.gpu.kernel"; +const std::string EXPORT_ATTR = "std.internal.attributes.export.0:0"; +const std::string INLINE_ATTR = "std.internal.attributes.inline.0:0"; +const std::string NOINLINE_ATTR = "std.internal.attributes.noinline.0:0"; +const std::string GPU_KERNEL_ATTR = "std.gpu.kernel.0:0"; const std::string MAIN_UNCLASH = ".main.unclash"; const std::string MAIN_CTOR = ".main.ctor"; @@ -1800,7 +1800,8 @@ void LLVMVisitor::visit(const InternalFunc *x) { else if (internalFuncMatchesIgnoreArgs("__new__", x)) { auto *recordType = cast(cast(x->getType())->getReturnType()); seqassertn(args.size() == std::distance(recordType->begin(), recordType->end()), - "args size does not match"); + "args size does not match: {} vs {}", args.size(), + std::distance(recordType->begin(), recordType->end())); result = llvm::UndefValue::get(getLLVMType(recordType)); for (auto i = 0; i < args.size(); i++) { result = B->CreateInsertValue(result, args[i], i); diff --git a/codon/cir/module.cpp b/codon/cir/module.cpp index 1b274fb2..d2100e0c 100644 --- a/codon/cir/module.cpp +++ b/codon/cir/module.cpp @@ -7,6 +7,7 @@ #include "codon/cir/func.h" #include "codon/parser/cache.h" +#include "codon/parser/visitors/typecheck/typecheck.h" namespace codon { namespace ir { @@ -16,22 +17,27 @@ translateGenerics(codon::ast::Cache *cache, std::vector &generic std::vector ret; for (auto &g : generics) { seqassertn(g.isStatic() || g.getTypeValue(), "generic must be static or a type"); - ret.push_back(std::make_shared( - g.isStatic() - ? std::make_shared(cache, g.getStaticValue()) - : (g.isStaticStr() ? std::make_shared( - cache, g.getStaticStringValue()) - : g.getTypeValue()->getAstType()))); + if (g.isStaticStr()) + ret.push_back(std::make_shared( + std::make_shared( + cache, g.getStaticStringValue()))); + else if (g.isStatic()) + ret.push_back(std::make_shared( + std::make_shared(cache, + g.getStaticValue()))); + else + ret.push_back(std::make_shared( + g.getTypeValue()->getAstType())); } return ret; } -std::vector +std::vector generateDummyNames(std::vector &types) { - std::vector ret; + std::vector ret; for (auto *t : types) { seqassertn(t->getAstType(), "{} must have an ast type", *t); - ret.emplace_back(t->getAstType()); + ret.emplace_back(t->getAstType().get()); } return ret; } @@ -46,8 +52,11 @@ translateArgs(codon::ast::Cache *cache, std::vector &types) { if (auto f = t->getAstType()->getFunc()) { auto *irType = cast(t); std::vector mask(std::distance(irType->begin(), irType->end()), 0); - ret.push_back(std::make_shared( - t->getAstType()->getRecord(), f, mask)); + // ast::TypecheckVisitor tv(cache->typeCtx); + // auto Expr * = tv.generatePartialCall(mask, f.get()); + // tv.transform(Expr *); + // ret.push_back(Expr *->type); + ret.push_back(t->getAstType()); } else { ret.push_back(t->getAstType()); } @@ -157,16 +166,17 @@ Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodN auto cls = std::const_pointer_cast(parent->getAstType())->getClass(); - auto method = cache->findMethod(cls.get(), methodName, generateDummyNames(args)); + auto method = cache->findMethod(cls, methodName, generateDummyNames(args)); if (!method) return nullptr; try { return cache->realizeFunction(method, translateArgs(cache, args), translateGenerics(cache, generics), cls); } catch (const exc::ParserException &e) { - for (int i = 0; i < e.messages.size(); i++) - LOG_IR("getOrRealizeMethod parser error at {}: {}", e.locations[i], - e.messages[i]); + for (auto &trace : e.getErrors()) + for (auto &msg : trace) + LOG_IR("getOrRealizeMethod parser error at {}: {}", msg.getSrcInfo(), + msg.getMessage()); return nullptr; } } @@ -178,6 +188,8 @@ Func *Module::getOrRealizeFunc(const std::string &funcName, auto fqName = module.empty() ? funcName : fmt::format(FMT_STRING("{}.{}"), module, funcName); auto func = cache->findFunction(fqName); + if (!func) + func = cache->findFunction(fqName + ".0:0"); if (!func) return nullptr; auto arg = translateArgs(cache, args); @@ -185,8 +197,10 @@ Func *Module::getOrRealizeFunc(const std::string &funcName, try { return cache->realizeFunction(func, arg, gens); } catch (const exc::ParserException &e) { - for (int i = 0; i < e.messages.size(); i++) - LOG_IR("getOrRealizeFunc parser error at {}: {}", e.locations[i], e.messages[i]); + for (auto &trace : e.getErrors()) + for (auto &msg : trace) + LOG("getOrRealizeFunc parser error at {}: {}", msg.getSrcInfo(), + msg.getMessage()); return nullptr; } } @@ -202,8 +216,10 @@ types::Type *Module::getOrRealizeType(const std::string &typeName, try { return cache->realizeType(type, translateGenerics(cache, generics)); } catch (const exc::ParserException &e) { - for (int i = 0; i < e.messages.size(); i++) - LOG_IR("getOrRealizeType parser error at {}: {}", e.locations[i], e.messages[i]); + for (auto &trace : e.getErrors()) + for (auto &msg : trace) + LOG_IR("getOrRealizeType parser error at {}: {}", msg.getSrcInfo(), + msg.getMessage()); return nullptr; } } diff --git a/codon/cir/transform/cleanup/canonical.cpp b/codon/cir/transform/cleanup/canonical.cpp index ca29b6fc..431c4723 100644 --- a/codon/cir/transform/cleanup/canonical.cpp +++ b/codon/cir/transform/cleanup/canonical.cpp @@ -61,15 +61,15 @@ NodeRanker::Rank getRank(Node *node) { } bool isCommutativeOp(Func *fn) { - return fn && util::hasAttribute(fn, "std.internal.attributes.commutative"); + return fn && util::hasAttribute(fn, "std.internal.attributes.commutative.0:0"); } bool isAssociativeOp(Func *fn) { - return fn && util::hasAttribute(fn, "std.internal.attributes.associative"); + return fn && util::hasAttribute(fn, "std.internal.attributes.associative.0:0"); } bool isDistributiveOp(Func *fn) { - return fn && util::hasAttribute(fn, "std.internal.attributes.distributive"); + return fn && util::hasAttribute(fn, "std.internal.attributes.distributive.0:0"); } bool isInequalityOp(Func *fn) { diff --git a/codon/cir/transform/lowering/imperative.cpp b/codon/cir/transform/lowering/imperative.cpp index 08beb004..76f709b8 100644 --- a/codon/cir/transform/lowering/imperative.cpp +++ b/codon/cir/transform/lowering/imperative.cpp @@ -32,7 +32,7 @@ CallInstr *getRangeIter(Value *iter) { if (!newRangeFunc || newRangeFunc->getUnmangledName() != Module::NEW_MAGIC_NAME) return nullptr; auto *parentType = newRangeFunc->getParentType(); - auto *rangeType = M->getOrRealizeType("range", {}, "std.internal.types.range"); + auto *rangeType = M->getOrRealizeType("range.0", {}, "std.internal.types.range"); if (!parentType || !rangeType || parentType->getName() != rangeType->getName()) return nullptr; @@ -50,7 +50,7 @@ Value *getListIter(Value *iter) { return nullptr; auto *list = iterCall->front(); - if (list->getType()->getName().rfind("std.internal.types.ptr.List[", 0) != 0) + if (list->getType()->getName().rfind("std.internal.types.array.List.0[", 0) != 0) return nullptr; return list; diff --git a/codon/cir/transform/numpy/numpy.cpp b/codon/cir/transform/numpy/numpy.cpp index 3f9ec663..a77d79b4 100644 --- a/codon/cir/transform/numpy/numpy.cpp +++ b/codon/cir/transform/numpy/numpy.cpp @@ -63,8 +63,8 @@ NumPyPrimitiveTypes::NumPyPrimitiveTypes(Module *M) u32(M->getIntNType(32, false)), i64(M->getIntType()), u64(M->getIntNType(64, false)), f16(M->getFloat16Type()), f32(M->getFloat32Type()), f64(M->getFloatType()), - c64(M->getType("std.internal.types.complex.complex64")), - c128(M->getType("std.internal.types.complex.complex")) {} + c64(M->getType("std.internal.types.complex.complex64.0")), + c128(M->getType("std.internal.types.complex.complex.0")) {} NumPyType::NumPyType(Type dtype, int64_t ndim) : dtype(dtype), ndim(ndim) { seqassertn(ndim >= 0, "ndim must be non-negative"); diff --git a/codon/cir/transform/parallel/openmp.cpp b/codon/cir/transform/parallel/openmp.cpp index 898dd763..bdb6ca9a 100644 --- a/codon/cir/transform/parallel/openmp.cpp +++ b/codon/cir/transform/parallel/openmp.cpp @@ -66,7 +66,7 @@ struct ReductionLocks { Var *critLock = nullptr; // lock used in reduction critical sections Var *createLock(Module *M) { - auto *lockType = M->getOrRealizeType("Lock", {}, ompModule); + auto *lockType = M->getOrRealizeType("Lock.0", {}, ompModule); seqassertn(lockType, "openmp.Lock type not found"); auto *var = M->Nr(lockType, /*global=*/true); static int counter = 1; @@ -980,7 +980,7 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer { auto *init = ptrFromFunc(makeTaskRedInitFunc(reduction)); auto *comb = ptrFromFunc(makeTaskRedCombFunc(reduction)); - auto *taskRedInputType = M->getOrRealizeType("TaskReductionInput", {}, ompModule); + auto *taskRedInputType = M->getOrRealizeType("TaskReductionInput.0", {}, ompModule); seqassertn(taskRedInputType, "could not find 'TaskReductionInput' type"); auto *result = taskRedInputType->construct({shar, orig, size, init, comb}); seqassertn(result, "bad construction of 'TaskReductionInput' type"); @@ -990,6 +990,7 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer { void handle(VarValue *v) override { auto *M = v->getModule(); auto *func = util::getFunc(v); + if (func && func->getUnmangledName() == "_routine_stub") { std::vector reduceArgs; unsigned sharedsNext = 0; @@ -1046,9 +1047,11 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer { // add task reduction inputs auto *taskRedInitSeries = M->Nr(); - auto *taskRedInputType = M->getOrRealizeType("TaskReductionInput", {}, ompModule); + auto *taskRedInputType = + M->getOrRealizeType("TaskReductionInput.0", {}, ompModule); seqassertn(taskRedInputType, "could not find 'TaskReductionInput' type"); - auto *irArrayType = M->getOrRealizeType("TaskReductionInputArray", {}, ompModule); + auto *irArrayType = + M->getOrRealizeType("TaskReductionInputArray.0", {}, ompModule); seqassertn(irArrayType, "could not find 'TaskReductionInputArray' type"); auto *taskRedInputsArray = util::makeVar( M->Nr(irArrayType, numRed), taskRedInitSeries, parent); diff --git a/codon/cir/transform/pythonic/list.cpp b/codon/cir/transform/pythonic/list.cpp index 7dc4bfad..42d5ff99 100644 --- a/codon/cir/transform/pythonic/list.cpp +++ b/codon/cir/transform/pythonic/list.cpp @@ -13,8 +13,8 @@ namespace transform { namespace pythonic { namespace { -static const std::string LIST = "std.internal.types.ptr.List"; -static const std::string SLICE = "std.internal.types.slice.Slice[int,int,int]"; +static const std::string LIST = "std.internal.types.array.List.0"; +static const std::string SLICE = "std.internal.types.slice.Slice.0[int,int,int]"; bool isList(Value *v) { return v->getType()->getName().rfind(LIST + "[", 0) == 0; } bool isSlice(Value *v) { return v->getType()->getName() == SLICE; } diff --git a/codon/cir/types/types.cpp b/codon/cir/types/types.cpp index 72b26fee..902045b5 100644 --- a/codon/cir/types/types.cpp +++ b/codon/cir/types/types.cpp @@ -39,14 +39,13 @@ std::vector Type::doGetGenerics() const { ret.emplace_back( getModule()->getCache()->realizeType(cls, extractTypes(cls->generics))); else { - switch (g.type->getStatic()->expr->staticValue.type) { - case ast::StaticValue::INT: - ret.emplace_back(g.type->getStatic()->expr->staticValue.getInt()); - break; - case ast::StaticValue::STRING: - ret.emplace_back(g.type->getStatic()->expr->staticValue.getString()); - break; - default: + if (auto ai = g.type->getIntStatic()) { + ret.emplace_back(ai->value); + } else if (auto ai = g.type->getBoolStatic()) { + ret.emplace_back(int(ai->value)); + } else if (auto as = g.type->getStrStatic()) { + ret.emplace_back(as->value); + } else { seqassertn(false, "IR only supports int or str statics [{}]", g.type->getSrcInfo()); } @@ -172,9 +171,9 @@ std::vector FuncType::doGetGenerics() const { ret.emplace_back( getModule()->getCache()->realizeType(cls, extractTypes(cls->generics))); else { - seqassertn(g.type->getStatic()->expr->staticValue.type == ast::StaticValue::INT, - "IR only supports int statics [{}]", getSrcInfo()); - ret.emplace_back(g.type->getStatic()->expr->staticValue.getInt()); + seqassertn(g.type->getIntStatic(), "IR only supports int statics [{}]", + getSrcInfo()); + ret.emplace_back(g.type->getIntStatic()->value); } } diff --git a/codon/cir/util/side_effect.cpp b/codon/cir/util/side_effect.cpp index a85e12e8..d019061b 100644 --- a/codon/cir/util/side_effect.cpp +++ b/codon/cir/util/side_effect.cpp @@ -6,12 +6,12 @@ namespace codon { namespace ir { namespace util { -const std::string NON_PURE_ATTR = "std.internal.attributes.nonpure"; -const std::string PURE_ATTR = "std.internal.attributes.pure"; -const std::string NO_SIDE_EFFECT_ATTR = "std.internal.attributes.no_side_effect"; -const std::string NO_CAPTURE_ATTR = "std.internal.attributes.nocapture"; -const std::string DERIVES_ATTR = "std.internal.attributes.derives"; -const std::string SELF_CAPTURES_ATTR = "std.internal.attributes.self_captures"; +const std::string NON_PURE_ATTR = "std.internal.attributes.nonpure.0:0"; +const std::string PURE_ATTR = "pure:0"; +const std::string NO_SIDE_EFFECT_ATTR = "std.internal.attributes.no_side_effect.0:0"; +const std::string NO_CAPTURE_ATTR = "std.internal.attributes.nocapture.0:0"; +const std::string DERIVES_ATTR = "derives:0"; +const std::string SELF_CAPTURES_ATTR = "std.internal.attributes.self_captures.0:0"; } // namespace util } // namespace ir diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 7b2e0aa2..e447f0c6 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -7,7 +7,6 @@ #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/doc/doc.h" #include "codon/parser/visitors/format/format.h" -#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/typecheck.h" @@ -74,30 +73,31 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, input = file; std::string abspath = (file != "-") ? ast::getAbsolutePath(file) : file; try { - ast::StmtPtr codeStmt = isCode - ? ast::parseCode(cache.get(), abspath, code, startLine) - : ast::parseFile(cache.get(), abspath); + auto nodeOrErr = isCode ? ast::parseCode(cache.get(), abspath, code, startLine) + : ast::parseFile(cache.get(), abspath); + if (!nodeOrErr) + throw exc::ParserException(nodeOrErr.takeError()); + auto codeStmt = *nodeOrErr; cache->module0 = file; - Timer t2("simplify"); + Timer t2("typecheck"); t2.logged = true; - auto transformed = - ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), abspath, defines, - getEarlyDefines(), (testFlags > 1)); + auto typechecked = ast::TypecheckVisitor::apply( + cache.get(), codeStmt, abspath, defines, getEarlyDefines(), (testFlags > 1)); LOG_TIME("[T] parse = {:.1f}", totalPeg); - LOG_TIME("[T] simplify = {:.1f}", t2.elapsed() - totalPeg); - - if (codon::getLogger().flags & codon::Logger::FLAG_USER) { - auto fo = fopen("_dump_simplify.sexp", "w"); - fmt::print(fo, "{}\n", transformed->toString(0)); - fclose(fo); + LOG_TIME("[T] typecheck = {:.1f}", t2.elapsed() - totalPeg); + + std::vector> q(cache->_timings.begin(), + cache->_timings.end()); + sort(q.begin(), q.end(), + [](const auto &a, const auto &b) { return b.second < a.second; }); + double s = 0; + for (auto &[k, v] : q) { + s += v; + LOG_TIME(" [->] {:60} = {:10.2f} / {:10.2f}", k, v, s); } - Timer t3("typecheck"); - auto typechecked = - ast::TypecheckVisitor::apply(cache.get(), std::move(transformed)); - t3.log(); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { auto fo = fopen("_dump_typecheck.sexp", "w"); fmt::print(fo, "{}\n", typechecked->toString(0)); @@ -106,30 +106,18 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, fmt::print(fo, "{}\n", r.second->ast->toString(0)); } fclose(fo); + + fo = fopen("_dump_typecheck.htm", "w"); + auto s = ast::FormatVisitor::apply(typechecked, cache.get(), true); + fmt::print(fo, "{}\n", s); + fclose(fo); } Timer t4("translate"); ast::TranslateVisitor::apply(cache.get(), std::move(typechecked)); t4.log(); } catch (const exc::ParserException &exc) { - std::vector messages; - if (exc.messages.empty()) { - const int MAX_ERRORS = 5; - int ei = 0; - for (auto &e : cache->errors) { - for (unsigned i = 0; i < e.messages.size(); i++) { - if (!e.messages[i].empty()) - messages.emplace_back(e.messages[i], e.locations[i].file, - e.locations[i].line, e.locations[i].col, - e.locations[i].len, e.errorCode); - } - if (ei++ > MAX_ERRORS) - break; - } - return llvm::make_error(messages); - } else { - return llvm::make_error(exc); - } + return llvm::make_error(exc.getErrors()); } module->setSrcInfo({abspath, 0, 0, 0}); if (codon::getLogger().flags & codon::Logger::FLAG_USER) { @@ -178,8 +166,8 @@ llvm::Expected Compiler::docgen(const std::vector &fil try { auto j = ast::DocVisitor::apply(argv0, files); return j->toString(); - } catch (exc::ParserException &e) { - return llvm::make_error(e); + } catch (exc::ParserException &exc) { + return llvm::make_error(exc.getErrors()); } } diff --git a/codon/compiler/error.cpp b/codon/compiler/error.cpp index 969939cc..c3f927da 100644 --- a/codon/compiler/error.cpp +++ b/codon/compiler/error.cpp @@ -3,6 +3,27 @@ #include "error.h" namespace codon { + +SrcInfo::SrcInfo(std::string file, int line, int col, int len) + : file(std::move(file)), line(line), col(col), len(len), id(0) { + if (this->file.empty() && line != 0) + line++; + static int nextId = 0; + id = nextId++; +}; + +SrcInfo::SrcInfo() : SrcInfo("", 0, 0, 0) {} + +bool SrcInfo::operator==(const SrcInfo &src) const { return id == src.id; } + +bool SrcInfo::operator<(const SrcInfo &src) const { + return std::tie(file, line, col) < std::tie(src.file, src.line, src.col); +} + +bool SrcInfo::operator<=(const SrcInfo &src) const { + return std::tie(file, line, col) <= std::tie(src.file, src.line, src.col); +} + namespace error { char ParserErrorInfo::ID = 0; @@ -13,15 +34,16 @@ char PluginErrorInfo::ID = 0; char IOErrorInfo::ID = 0; -void raise_error(const char *format) { throw exc::ParserException(format); } +void E(llvm::Error &&error) { throw exc::ParserException(std::move(error)); } -void raise_error(int e, const ::codon::SrcInfo &info, const char *format) { - throw exc::ParserException(e, format, info); -} +} // namespace error -void raise_error(int e, const ::codon::SrcInfo &info, const std::string &format) { - throw exc::ParserException(e, format, info); +namespace exc { +ParserException::ParserException(llvm::Error &&e) noexcept : std::runtime_error("") { + llvm::handleAllErrors(std::move(e), [this](const error::ParserErrorInfo &e) { + errors = e.getErrors(); + }); } -} // namespace error +} // namespace exc } // namespace codon diff --git a/codon/compiler/error.h b/codon/compiler/error.h index 7ae43ac2..6b7be678 100644 --- a/codon/compiler/error.h +++ b/codon/compiler/error.h @@ -12,159 +12,13 @@ namespace codon { namespace error { -class Message { -private: - std::string msg; - std::string file; - int line = 0; - int col = 0; - int len = 0; - int errorCode = -1; - -public: - explicit Message(const std::string &msg, const std::string &file = "", int line = 0, - int col = 0, int len = 0, int errorCode = -1) - : msg(msg), file(file), line(line), col(col), len(len), errorCode(-1) {} - - std::string getMessage() const { return msg; } - std::string getFile() const { return file; } - int getLine() const { return line; } - int getColumn() const { return col; } - int getLength() const { return len; } - int getErrorCode() const { return errorCode; } - - void log(llvm::raw_ostream &out) const { - if (!getFile().empty()) { - out << getFile(); - if (getLine() != 0) { - out << ":" << getLine(); - if (getColumn() != 0) { - out << ":" << getColumn(); - } - } - out << ": "; - } - out << getMessage(); - } -}; - -class ParserErrorInfo : public llvm::ErrorInfo { -private: - std::vector> messages; - -public: - explicit ParserErrorInfo(const std::vector &m) : messages() { - for (auto &msg : m) { - messages.push_back({msg}); - } - } - explicit ParserErrorInfo(const exc::ParserException &e) : messages() { - std::vector group; - for (unsigned i = 0; i < e.messages.size(); i++) { - if (!e.messages[i].empty()) - group.emplace_back(e.messages[i], e.locations[i].file, e.locations[i].line, - e.locations[i].col, e.locations[i].len); - } - messages.push_back(group); - } - - auto begin() { return messages.begin(); } - auto end() { return messages.end(); } - auto begin() const { return messages.begin(); } - auto end() const { return messages.end(); } - - void log(llvm::raw_ostream &out) const override { - for (auto &group : messages) { - for (auto &msg : group) { - msg.log(out); - out << "\n"; - } - } - } - - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); - } - - static char ID; -}; - -class RuntimeErrorInfo : public llvm::ErrorInfo { -private: - std::string output; - std::string type; - Message message; - std::vector backtrace; - -public: - RuntimeErrorInfo(const std::string &output, const std::string &type, - const std::string &msg, const std::string &file = "", int line = 0, - int col = 0, std::vector backtrace = {}) - : output(output), type(type), message(msg, file, line, col), - backtrace(std::move(backtrace)) {} - - std::string getOutput() const { return output; } - std::string getType() const { return type; } - std::string getMessage() const { return message.getMessage(); } - std::string getFile() const { return message.getFile(); } - int getLine() const { return message.getLine(); } - int getColumn() const { return message.getColumn(); } - std::vector getBacktrace() const { return backtrace; } - - void log(llvm::raw_ostream &out) const override { - out << type << ": "; - message.log(out); - } - - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); - } - - static char ID; -}; - -class PluginErrorInfo : public llvm::ErrorInfo { -private: - std::string message; - -public: - explicit PluginErrorInfo(const std::string &message) : message(message) {} - - std::string getMessage() const { return message; } - - void log(llvm::raw_ostream &out) const override { out << message; } - - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); - } - - static char ID; -}; - -class IOErrorInfo : public llvm::ErrorInfo { -private: - std::string message; - -public: - explicit IOErrorInfo(const std::string &message) : message(message) {} - - std::string getMessage() const { return message; } - - void log(llvm::raw_ostream &out) const override { out << message; } - - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); - } - - static char ID; -}; - enum Error { CALL_NAME_ORDER, CALL_NAME_STAR, CALL_ELLIPSIS, IMPORT_IDENTIFIER, IMPORT_FN, + IMPORT_STAR, FN_LLVM, FN_LAST_KWARG, FN_MULTIPLE_ARGS, @@ -252,6 +106,7 @@ enum Error { SLICE_STEP_ZERO, OP_NO_MAGIC, INST_CALLABLE_STATIC, + CATCH_EXCEPTION_TYPE, TYPE_CANNOT_REALIZE_ATTR, TYPE_UNIFY, TYPE_FAILED, @@ -260,6 +115,116 @@ enum Error { __END__ }; +class ParserErrorInfo : public llvm::ErrorInfo { +private: + ParserErrors errors; + +public: + static char ID; + +public: + explicit ParserErrorInfo(const ErrorMessage &msg) : errors(msg) {} + explicit ParserErrorInfo(const std::vector &msgs) : errors(msgs) {} + explicit ParserErrorInfo(const ParserErrors &errors) : errors(errors) {} + + template + ParserErrorInfo(error::Error e, const codon::SrcInfo &o = codon::SrcInfo(), const TA &...args) { + auto msg = Emsg(e, args...); + errors = ParserErrors(ErrorMessage(msg, o, (int)e)); + } + + const ParserErrors &getErrors() const { return errors; } + ParserErrors &getErrors() { return errors; } + + void log(llvm::raw_ostream &out) const override { + for (const auto &trace : errors) { + for (const auto &msg : trace.getMessages()) { + msg.log(out); + out << "\n"; + } + } + } + + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } +}; + +class RuntimeErrorInfo : public llvm::ErrorInfo { +private: + std::string output; + std::string type; + ErrorMessage message; + std::vector backtrace; + +public: + RuntimeErrorInfo(const std::string &output, const std::string &type, + const std::string &msg, const std::string &file = "", int line = 0, + int col = 0, std::vector backtrace = {}) + : output(output), type(type), message(msg, file, line, col), + backtrace(std::move(backtrace)) {} + + std::string getOutput() const { return output; } + std::string getType() const { return type; } + std::string getMessage() const { return message.getMessage(); } + std::string getFile() const { return message.getFile(); } + int getLine() const { return message.getLine(); } + int getColumn() const { return message.getColumn(); } + std::vector getBacktrace() const { return backtrace; } + + void log(llvm::raw_ostream &out) const override { + out << type << ": "; + message.log(out); + } + + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + + static char ID; +}; + +class PluginErrorInfo : public llvm::ErrorInfo { +private: + std::string message; + +public: + explicit PluginErrorInfo(const std::string &message) : message(message) {} + + std::string getMessage() const { return message; } + + void log(llvm::raw_ostream &out) const override { out << message; } + + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + + static char ID; +}; + +class IOErrorInfo : public llvm::ErrorInfo { +private: + std::string message; + +public: + explicit IOErrorInfo(const std::string &message) : message(message) {} + + std::string getMessage() const { return message; } + + void log(llvm::raw_ostream &out) const override { out << message; } + + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + + static char ID; +}; + +template std::string Eformat(const TA &...args) { return ""; } +template std::string Eformat(const char *fmt, const TA &...args) { + return fmt::format(fmt, args...); +} + template std::string Emsg(Error e, const TA &...args) { switch (e) { /// Validations @@ -274,6 +239,8 @@ template std::string Emsg(Error e, const TA &...args) { case Error::IMPORT_FN: return fmt::format( "function signatures only allowed when importing C or Python functions"); + case Error::IMPORT_STAR: + return fmt::format("import * only allowed at module level"); case Error::FN_LLVM: return fmt::format("return types required for LLVM and C functions"); case Error::FN_LAST_KWARG: @@ -336,7 +303,8 @@ template std::string Emsg(Error e, const TA &...args) { case Error::ASSIGN_INVALID: return fmt::format("cannot assign to given expression"); case Error::ASSIGN_LOCAL_REFERENCE: - return fmt::format("local variable '{}' referenced before assignment", args...); + return fmt::format("local variable '{}' referenced before assignment at {}", + args...); case Error::ASSIGN_MULTI_STAR: return fmt::format("multiple starred expressions in assignment"); case Error::INT_RANGE: @@ -365,7 +333,7 @@ template std::string Emsg(Error e, const TA &...args) { case Error::CLASS_INVALID_BIND: return fmt::format("cannot bind '{}' to class or function", args...); case Error::CLASS_NO_INHERIT: - return fmt::format("{} classes cannot inherit other classes", args...); + return fmt::format("{} classes cannot inherit {} classes", args...); case Error::CLASS_TUPLE_INHERIT: return fmt::format("reference classes cannot inherit tuple classes"); case Error::CLASS_BAD_MRO: @@ -391,8 +359,7 @@ template std::string Emsg(Error e, const TA &...args) { case Error::LOOP_DECORATOR: return fmt::format("invalid loop decorator"); case Error::BAD_STATIC_TYPE: - return fmt::format( - "expected 'int' or 'str' (only integers and strings can be static)"); + return fmt::format("expected 'int', 'bool' or 'str'"); case Error::EXPECTED_TYPE: return fmt::format("expected {} expression", args...); case Error::UNEXPECTED_TYPE: @@ -401,7 +368,7 @@ template std::string Emsg(Error e, const TA &...args) { /// Typechecking case Error::UNION_TOO_BIG: return fmt::format( - "union exceeded its maximum capacity (contains more than {} types)"); + "union exceeded its maximum capacity (contains more than {} types)", args...); case Error::DOT_NO_ATTR: return fmt::format("'{}' object has no attribute '{}'", args...); case Error::DOT_NO_ATTR_ARGS: @@ -460,6 +427,8 @@ template std::string Emsg(Error e, const TA &...args) { return fmt::format("unsupported operand type(s) for {}: '{}' and '{}'", args...); case Error::INST_CALLABLE_STATIC: return fmt::format("Callable cannot take static types"); + case Error::CATCH_EXCEPTION_TYPE: + return fmt::format("'{}' does not inherit from BaseException", args...); case Error::TYPE_CANNOT_REALIZE_ATTR: return fmt::format("type of attribute '{}' of object '{}' cannot be inferred", @@ -479,24 +448,20 @@ template std::string Emsg(Error e, const TA &...args) { return fmt::format( "maximum realization depth reached during the realization of '{}'", args...); case Error::CUSTOM: - return fmt::format("{}", args...); - + return Eformat(args...); default: assert(false); } } -/// Raise a parsing error. -void raise_error(const char *format); -/// Raise a parsing error at a source location p. -void raise_error(int e, const codon::SrcInfo &info, const char *format); -void raise_error(int e, const codon::SrcInfo &info, const std::string &format); - template void E(Error e, const codon::SrcInfo &o = codon::SrcInfo(), const TA &...args) { auto msg = Emsg(e, args...); - raise_error((int)e, o, msg); + auto err = ParserErrors(ErrorMessage(msg, o, (int)e)); + throw exc::ParserException(err); } +void E(llvm::Error &&error); + } // namespace error } // namespace codon diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index 1e1e370b..1db81879 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -8,7 +8,7 @@ #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/doc/doc.h" #include "codon/parser/visitors/format/format.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/typecheck.h" @@ -35,11 +35,8 @@ llvm::Error JIT::init() { auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); - auto transformed = - ast::SimplifyVisitor::apply(cache, std::make_shared(), - JIT_FILENAME, {}, compiler->getEarlyDefines()); - - auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed)); + auto typechecked = ast::TypecheckVisitor::apply( + cache, cache->N(), JIT_FILENAME, {}, compiler->getEarlyDefines()); ast::TranslateVisitor::apply(cache, std::move(typechecked)); cache->isJit = true; // we still need main(), so set isJit after it has been set module->setSrcInfo({JIT_FILENAME, 0, 0, 0}); @@ -85,61 +82,58 @@ llvm::Expected JIT::compile(const std::string &code, const std::string &file, int line) { auto *cache = compiler->getCache(); auto sctx = cache->imports[MAIN_IMPORT].ctx; - auto preamble = std::make_shared>(); + auto preamble = std::make_shared>(); ast::Cache bCache = *cache; - ast::SimplifyContext bSimplify = *sctx; - ast::SimplifyContext stdlibSimplify = *(cache->imports[STDLIB_IMPORT].ctx); + auto bTypecheck = *sctx; + auto stdlibTypecheck = *(cache->imports[STDLIB_IMPORT].ctx); ast::TypeContext bType = *(cache->typeCtx); ast::TranslateContext bTranslate = *(cache->codegenCtx); try { - ast::StmtPtr node = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code, - /*startLine=*/line); - auto *e = node->getSuite() ? node->getSuite()->lastInBlock() : &node; + auto nodeOrErr = ast::parseCode(cache, file.empty() ? JIT_FILENAME : file, code, + /*startLine=*/line); + if (!nodeOrErr) + throw exc::ParserException(nodeOrErr.takeError()); + auto *node = *nodeOrErr; + + ast::Stmt **e = &node; + while (auto se = ast::cast(*e)) { + if (se->empty()) + break; + e = &se->back(); + } if (e) - if (auto ex = const_cast((*e)->getExpr())) { - *e = std::make_shared(std::make_shared( - std::make_shared("_jit_display"), ex->expr->clone(), - std::make_shared(mode))); + if (auto ex = ast::cast(*e)) { + *e = cache->N(cache->N( + cache->N("_jit_display"), clone(ex->getExpr()), + cache->N(mode))); } - auto s = ast::SimplifyVisitor(sctx, preamble).transform(node); + auto tv = ast::TypecheckVisitor(sctx, preamble); + if (auto err = ast::ScopingVisitor::apply(sctx->cache, node)) + throw exc::ParserException(std::move(err)); + node = tv.transform(node); + if (!cache->errors.empty()) - throw exc::ParserException(); - auto simplified = std::make_shared(); + throw exc::ParserException(cache->errors); + auto typechecked = cache->N(); for (auto &s : *preamble) - simplified->stmts.push_back(s); - simplified->stmts.push_back(s); + typechecked->addStmt(s); + typechecked->addStmt(node); // TODO: unroll on errors... - auto *cache = compiler->getCache(); - auto typechecked = ast::TypecheckVisitor::apply(cache, simplified); - // add newly realized functions - std::vector v; + std::vector v; std::vector frs; v.push_back(typechecked); for (auto &p : cache->pendingRealizations) { v.push_back(cache->functions[p.first].ast); frs.push_back(&cache->functions[p.first].realizations[p.second]->ir); } - auto func = - ast::TranslateVisitor::apply(cache, std::make_shared(v)); + auto func = ast::TranslateVisitor::apply(cache, cache->N(v)); cache->jitCell++; return func; } catch (const exc::ParserException &exc) { - std::vector messages; - if (exc.messages.empty()) { - for (auto &e : cache->errors) { - for (unsigned i = 0; i < e.messages.size(); i++) { - if (!e.messages[i].empty()) - messages.emplace_back(e.messages[i], e.locations[i].file, - e.locations[i].line, e.locations[i].col, - e.locations[i].len, e.errorCode); - } - } - } - for (auto &f : cache->functions) for (auto &r : f.second.realizations) if (!(in(bCache.functions, f.first) && @@ -148,15 +142,12 @@ llvm::Expected JIT::compile(const std::string &code, cache->module->remove(r.second->ir); } *cache = bCache; - *(cache->imports[MAIN_IMPORT].ctx) = bSimplify; - *(cache->imports[STDLIB_IMPORT].ctx) = stdlibSimplify; + *(cache->imports[MAIN_IMPORT].ctx) = bTypecheck; + *(cache->imports[STDLIB_IMPORT].ctx) = stdlibTypecheck; *(cache->typeCtx) = bType; *(cache->codegenCtx) = bTranslate; - if (exc.messages.empty()) - return llvm::make_error(messages); - else - return llvm::make_error(exc); + return llvm::make_error(exc.getErrors()); } } diff --git a/codon/dsl/dsl.h b/codon/dsl/dsl.h index e4361c43..874f306d 100644 --- a/codon/dsl/dsl.h +++ b/codon/dsl/dsl.h @@ -39,7 +39,7 @@ class DSL { }; using KeywordCallback = - std::function; + std::function; struct ExprKeyword { std::string keyword; diff --git a/codon/parser/ast.h b/codon/parser/ast.h index b6bb1719..0b3a7f9f 100644 --- a/codon/parser/ast.h +++ b/codon/parser/ast.h @@ -2,7 +2,15 @@ #pragma once +#include +#include +#include + +#include "codon/cir/attribute.h" +#include "codon/cir/base.h" +#include "codon/parser/ast/attr.h" #include "codon/parser/ast/error.h" #include "codon/parser/ast/expr.h" +#include "codon/parser/ast/node.h" #include "codon/parser/ast/stmt.h" #include "codon/parser/ast/types.h" diff --git a/codon/parser/ast/attr.cpp b/codon/parser/ast/attr.cpp new file mode 100644 index 00000000..c5f09981 --- /dev/null +++ b/codon/parser/ast/attr.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#include "attr.h" + +namespace codon::ast { + +const std::string Attr::Module = "module"; +const std::string Attr::ParentClass = "parentClass"; +const std::string Attr::Bindings = "bindings"; + +const std::string Attr::LLVM = "llvm"; +const std::string Attr::Python = "python"; +const std::string Attr::Atomic = "atomic"; +const std::string Attr::Property = "property"; +const std::string Attr::StaticMethod = "staticmethod"; +const std::string Attr::Attribute = "__attribute__"; +const std::string Attr::C = "C"; + +const std::string Attr::Internal = "__internal__"; +const std::string Attr::HiddenFromUser = "__hidden__"; +const std::string Attr::ForceRealize = "__force__"; +const std::string Attr::RealizeWithoutSelf = + "std.internal.attributes.realize_without_self.0:0"; +const std::string Attr::ParentCallExpr = "parentCallExpr"; +const std::string Attr::TupleCall = "tupleFn"; +const std::string Attr::Validated = "validated"; +const std::string Attr::AutoGenerated = "autogenerated"; + +const std::string Attr::CVarArg = ".__vararg__"; +const std::string Attr::Method = ".__method__"; +const std::string Attr::Capture = ".__capture__"; +const std::string Attr::HasSelf = ".__hasself__"; +const std::string Attr::IsGenerator = ".__generator__"; + +const std::string Attr::Extend = "extend"; +const std::string Attr::Tuple = "tuple"; +const std::string Attr::ClassDeduce = "deduce"; +const std::string Attr::ClassNoTuple = "__notuple__"; + +const std::string Attr::Test = "std.internal.attributes.test.0:0"; +const std::string Attr::Overload = "overload:0"; +const std::string Attr::Export = "std.internal.attributes.export.0:0"; +const std::string Attr::Inline = "std.internal.attributes.inline.0:0"; +const std::string Attr::NoArgReorder = "std.internal.attributes.no_arg_reorder.0:0"; + +const std::string Attr::ClassMagic = "classMagic"; +const std::string Attr::ExprSequenceItem = "exprSequenceItem"; +const std::string Attr::ExprStarSequenceItem = "exprStarSequenceItem"; +const std::string Attr::ExprList = "exprList"; +const std::string Attr::ExprSet = "exprSet"; +const std::string Attr::ExprDict = "exprDict"; +const std::string Attr::ExprPartial = "exprPartial"; +const std::string Attr::ExprDominated = "exprDominated"; +const std::string Attr::ExprStarArgument = "exprStarArgument"; +const std::string Attr::ExprKwStarArgument = "exprKwStarArgument"; +const std::string Attr::ExprOrderedCall = "exprOrderedCall"; +const std::string Attr::ExprExternVar = "exprExternVar"; +const std::string Attr::ExprDominatedUndefCheck = "exprDominatedUndefCheck"; +const std::string Attr::ExprDominatedUsed = "exprDominatedUsed"; +const std::string Attr::ExprTime = "exprTime"; +const std::string Attr::ExprDoNotRealize = "exprDoNotRealize"; + +} // namespace codon::ast diff --git a/codon/parser/ast/attr.h b/codon/parser/ast/attr.h new file mode 100644 index 00000000..a690a7f8 --- /dev/null +++ b/codon/parser/ast/attr.h @@ -0,0 +1,68 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#pragma once + +#include + +namespace codon::ast { + +const int INDENT_SIZE = 2; + +struct Attr { + // Function attributes + const static std::string Module; + const static std::string ParentClass; + const static std::string Bindings; + // Toplevel attributes + const static std::string LLVM; + const static std::string Python; + const static std::string Atomic; + const static std::string Property; + const static std::string StaticMethod; + const static std::string Attribute; + const static std::string C; + // Internal attributes + const static std::string Internal; + const static std::string HiddenFromUser; + const static std::string ForceRealize; + const static std::string RealizeWithoutSelf; // not internal + const static std::string ParentCallExpr; + const static std::string TupleCall; + const static std::string Validated; + const static std::string AutoGenerated; + // Compiler-generated attributes + const static std::string CVarArg; + const static std::string Method; + const static std::string Capture; + const static std::string HasSelf; + const static std::string IsGenerator; + // Class attributes + const static std::string Extend; + const static std::string Tuple; + const static std::string ClassDeduce; + const static std::string ClassNoTuple; + // Standard library attributes + const static std::string Test; + const static std::string Overload; + const static std::string Export; + const static std::string Inline; + const static std::string NoArgReorder; + // Expression-related attributes + const static std::string ClassMagic; + const static std::string ExprSequenceItem; + const static std::string ExprStarSequenceItem; + const static std::string ExprList; + const static std::string ExprSet; + const static std::string ExprDict; + const static std::string ExprPartial; + const static std::string ExprDominated; + const static std::string ExprStarArgument; + const static std::string ExprKwStarArgument; + const static std::string ExprOrderedCall; + const static std::string ExprExternVar; + const static std::string ExprDominatedUndefCheck; + const static std::string ExprDominatedUsed; + const static std::string ExprTime; + const static std::string ExprDoNotRealize; +}; +} // namespace codon::ast diff --git a/codon/parser/ast/error.h b/codon/parser/ast/error.h index 1e540726..76843171 100644 --- a/codon/parser/ast/error.h +++ b/codon/parser/ast/error.h @@ -2,6 +2,7 @@ #pragma once +#include "llvm/Support/Error.h" #include #include #include @@ -23,17 +24,111 @@ struct SrcInfo { int len; int id; /// used to differentiate different instances - SrcInfo(std::string file, int line, int col, int len) - : file(std::move(file)), line(line), col(col), len(len), id(0) { - static int nextId = 0; - id = nextId++; - } + SrcInfo(); + SrcInfo(std::string file, int line, int col, int len); + bool operator==(const SrcInfo &src) const; + bool operator<(const SrcInfo &src) const; + bool operator<=(const SrcInfo &src) const; +}; + +class ErrorMessage { +private: + std::string msg; + SrcInfo loc; + int errorCode = -1; + +public: + explicit ErrorMessage(const std::string &msg, const SrcInfo &loc = SrcInfo(), + int errorCode = -1) + : msg(msg), loc(loc), errorCode(-1) {} + explicit ErrorMessage(const std::string &msg, const std::string &file = "", + int line = 0, int col = 0, int len = 0, int errorCode = -1) + : msg(msg), loc(file, line, col, len), errorCode(-1) {} - SrcInfo() : SrcInfo("", 0, 0, 0) {} + std::string getMessage() const { return msg; } + std::string getFile() const { return loc.file; } + int getLine() const { return loc.line; } + int getColumn() const { return loc.col; } + int getLength() const { return loc.len; } + int getErrorCode() const { return errorCode; } + SrcInfo getSrcInfo() const { return loc; } + void setSrcInfo(const SrcInfo &s) { loc = s; } + bool operator==(const ErrorMessage &t) const { return msg == t.msg && loc == t.loc; } - bool operator==(const SrcInfo &src) const { return id == src.id; } + void log(llvm::raw_ostream &out) const { + if (!getFile().empty()) { + out << getFile(); + if (getLine() != 0) { + out << ":" << getLine(); + if (getColumn() != 0) { + out << ":" << getColumn(); + } + } + out << ": "; + } + out << getMessage(); + } }; +struct ParserErrors { + struct Backtrace { + std::vector trace; + const std::vector &getMessages() const { return trace; } + auto begin() const { return trace.begin(); } + auto front() const { return trace.front(); } + auto front() { return trace.front(); } + auto end() const { return trace.end(); } + auto back() { return trace.back(); } + auto back() const { return trace.back(); } + auto size() const { return trace.size(); } + void addMessage(const std::string &msg, const SrcInfo &info = SrcInfo()) { + trace.emplace_back(msg, info); + } + bool operator==(const Backtrace &t) const { return trace == t.trace; } + }; + std::vector errors; + + ParserErrors() {} + ParserErrors(const ErrorMessage &msg) : errors{Backtrace{{msg}}} {} + ParserErrors(const std::string &msg, const SrcInfo &info) + : ParserErrors({msg, info}) {} + ParserErrors(const std::string &msg) : ParserErrors(msg, {}) {} + ParserErrors(const ParserErrors &e) : errors(e.errors) {} + ParserErrors(const std::vector &m) : ParserErrors() { + for (auto &msg : m) + errors.push_back(Backtrace{{msg}}); + } + + auto begin() { return errors.begin(); } + auto end() { return errors.end(); } + auto begin() const { return errors.begin(); } + auto end() const { return errors.end(); } + auto empty() const { return errors.empty(); } + auto size() const { return errors.size(); } + auto &back() { return errors.back(); } + const auto &back() const { return errors.back(); } + void append(const ParserErrors &e) { + for (auto &trace : e) + addError(trace); + } + + Backtrace getLast() { + assert(!empty() && "empty error trace"); + return errors.back(); + } + + /// Add an error message to the current backtrace + void addError(const Backtrace &trace) { + if (errors.empty() || !(errors.back() == trace)) + errors.push_back({trace}); + } + void addError(const std::vector &trace) { addError(Backtrace{trace}); } + std::string getMessage() const { + if (empty()) + return ""; + return errors.front().trace.front().getMessage(); + } +}; } // namespace codon namespace codon::exc { @@ -43,38 +138,17 @@ namespace codon::exc { * Used for parsing, transformation and type-checking errors. */ class ParserException : public std::runtime_error { -public: /// These vectors (stacks) store an error stack-trace. - std::vector locations; - std::vector messages; - int errorCode = -1; + ParserErrors errors; public: - ParserException(int errorCode, const std::string &msg, const SrcInfo &info) noexcept - : std::runtime_error(msg), errorCode(errorCode) { - messages.push_back(msg); - locations.push_back(info); - } ParserException() noexcept : std::runtime_error("") {} - ParserException(int errorCode, const std::string &msg) noexcept - : ParserException(errorCode, msg, {}) {} - explicit ParserException(const std::string &msg) noexcept - : ParserException(-1, msg, {}) {} - ParserException(const ParserException &e) noexcept - : std::runtime_error(e), locations(e.locations), messages(e.messages), - errorCode(e.errorCode) {} - - /// Add an error message to the current stack trace - void trackRealize(const std::string &msg, const SrcInfo &info) { - locations.push_back(info); - messages.push_back("during the realization of " + msg); - } + ParserException(const ParserErrors &errors) noexcept + : std::runtime_error(errors.getMessage()), errors(errors) {} + ParserException(llvm::Error &&e) noexcept; - /// Add an error message to the current stack trace - void track(const std::string &msg, const SrcInfo &info) { - locations.push_back(info); - messages.push_back(msg); - } + const ParserErrors &getErrors() const { return errors; } + ParserErrors &getErrors() { return errors; } }; } // namespace codon::exc diff --git a/codon/parser/ast/expr.cpp b/codon/parser/ast/expr.cpp index abf6b181..aa122788 100644 --- a/codon/parser/ast/expr.cpp +++ b/codon/parser/ast/expr.cpp @@ -8,8 +8,11 @@ #include #include +#include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/cache.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" #include "codon/parser/visitors/visitor.h" #define FASTFLOAT_ALLOWS_LEADING_PLUS @@ -17,141 +20,127 @@ #include "fast_float/fast_float.h" #define ACCEPT_IMPL(T, X) \ - ExprPtr T::clone() const { return std::make_shared(*this); } \ - void T::accept(X &visitor) { visitor.visit(this); } + ASTNode *T::clone(bool c) const { return cache->N(*this, c); } \ + void T::accept(X &visitor) { visitor.visit(this); } \ + const char T::NodeId = 0; using fmt::format; using namespace codon::error; +using namespace codon::matcher; namespace codon::ast { -Expr::Expr() - : type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC), - done(false), attributes(0), origExpr(nullptr) {} -void Expr::validate() const {} -types::TypePtr Expr::getType() const { return type; } -void Expr::setType(types::TypePtr t) { this->type = std::move(t); } -bool Expr::isType() const { return isTypeExpr; } -void Expr::markType() { isTypeExpr = true; } +ASTNode::ASTNode(const ASTNode &node) : Node(node), cache(node.cache) {} + +Expr::Expr() : AcceptorExtend(), type(nullptr), done(false), origExpr(nullptr) {} +Expr::Expr(const Expr &expr) + : AcceptorExtend(expr), type(expr.type), done(expr.done), origExpr(expr.origExpr) {} +Expr::Expr(const Expr &expr, bool clean) : Expr(expr) { + if (clean) { + type = nullptr; + done = false; + } +} +types::ClassType *Expr::getClassType() const { + return type ? type->getClass() : nullptr; +} std::string Expr::wrapType(const std::string &sexpr) const { auto is = sexpr; if (done) is.insert(findStar(is), "*"); - auto s = format("({}{})", is, type ? format(" #:type \"{}\"", type->toString()) : ""); - // if (hasAttr(ExprAttr::SequenceItem)) s += "%"; + auto s = format("({}{})", is, + type && !done ? format(" #:type \"{}\"", type->debugString(2)) : ""); return s; } -bool Expr::isStatic() const { return staticValue.type != StaticValue::NOT_STATIC; } -bool Expr::hasAttr(int attr) const { return (attributes & (1 << attr)); } -void Expr::setAttr(int attr) { attributes |= (1 << attr); } -std::string Expr::getTypeName() { - if (getId()) { - return getId()->value; +// llvm::Expected *Expr::operator<<(types::Type *t) { +// seqassert(type, "lhs is nullptr"); +// if ((*type) << t) { +// E(Error::TYPE_UNIFY, getSrcInfo(), type->prettyString(), t->prettyString()); +// } +// return this; +// } + +Param::Param(std::string name, Expr *type, Expr *defaultValue, int status) + : name(std::move(name)), type(type), defaultValue(defaultValue) { + if (status == 0 && + (match(getType(), MOr(M(TYPE_TYPE), M(TYPE_TYPEVAR), + M(M(TYPE_TYPEVAR), M_))) || + getStaticGeneric(getType()))) { + this->status = Generic; } else { - auto i = dynamic_cast(this); - seqassertn(i && i->typeExpr->getId(), "bad MRO"); - return i->typeExpr->getId()->value; + this->status = (status == 0 ? Value : (status == 1 ? Generic : HiddenGeneric)); } } - -StaticValue::StaticValue(StaticValue::Type t) : value(), type(t), evaluated(false) {} -StaticValue::StaticValue(int64_t i) : value(i), type(INT), evaluated(true) {} -StaticValue::StaticValue(std::string s) - : value(std::move(s)), type(STRING), evaluated(true) {} -bool StaticValue::operator==(const StaticValue &s) const { - if (type != s.type || s.evaluated != evaluated) - return false; - return !s.evaluated || value == s.value; -} -std::string StaticValue::toString() const { - if (type == StaticValue::NOT_STATIC) - return ""; - if (!evaluated) - return type == StaticValue::STRING ? "str" : "int"; - return type == StaticValue::STRING ? "'" + escape(std::get(value)) + "'" - : std::to_string(std::get(value)); -} -int64_t StaticValue::getInt() const { - seqassertn(type == StaticValue::INT, "not an int"); - return std::get(value); -} -std::string StaticValue::getString() const { - seqassertn(type == StaticValue::STRING, "not a string"); - return std::get(value); -} - -Param::Param(std::string name, ExprPtr type, ExprPtr defaultValue, int status) - : name(std::move(name)), type(std::move(type)), - defaultValue(std::move(defaultValue)) { - if (status == 0 && this->type && - (this->type->isId("type") || this->type->isId(TYPE_TYPEVAR) || - (this->type->getIndex() && this->type->getIndex()->expr->isId(TYPE_TYPEVAR)) || - getStaticGeneric(this->type.get()))) - this->status = Generic; - else - this->status = (status == 0 ? Normal : (status == 1 ? Generic : HiddenGeneric)); -} -Param::Param(const SrcInfo &info, std::string name, ExprPtr type, ExprPtr defaultValue, +Param::Param(const SrcInfo &info, std::string name, Expr *type, Expr *defaultValue, int status) : Param(name, type, defaultValue, status) { setSrcInfo(info); } -std::string Param::toString() const { - return format("({}{}{}{})", name, type ? " #:type " + type->toString() : "", - defaultValue ? " #:default " + defaultValue->toString() : "", - status != Param::Normal ? " #:generic" : ""); +std::string Param::toString(int indent) const { + return format("({}{}{}{})", name, type ? " #:type " + type->toString(indent) : "", + defaultValue ? " #:default " + defaultValue->toString(indent) : "", + !isValue() ? " #:generic" : ""); +} +Param Param::clone(bool clean) const { + return Param(name, ast::clone(type, clean), ast::clone(defaultValue, clean), status); } -Param Param::clone() const { - return Param(name, ast::clone(type), ast::clone(defaultValue), status); +std::pair Param::getNameWithStars() const { + int stars = 0; + for (; stars < name.size() && name[stars] == '*'; stars++) + ; + auto n = name.substr(stars); + return {stars, n}; } -NoneExpr::NoneExpr() : Expr() {} -std::string NoneExpr::toString() const { return wrapType("none"); } -ACCEPT_IMPL(NoneExpr, ASTVisitor); -BoolExpr::BoolExpr(bool value) : Expr(), value(value) { - staticValue = StaticValue(value); -} -std::string BoolExpr::toString() const { +NoneExpr::NoneExpr() : AcceptorExtend() {} +NoneExpr::NoneExpr(const NoneExpr &expr, bool clean) : AcceptorExtend(expr, clean) {} +std::string NoneExpr::toString(int) const { return wrapType("none"); } + +BoolExpr::BoolExpr(bool value) : AcceptorExtend(), value(value) {} +BoolExpr::BoolExpr(const BoolExpr &expr, bool clean) + : AcceptorExtend(expr, clean), value(expr.value) {} +bool BoolExpr::getValue() const { return value; } +std::string BoolExpr::toString(int) const { return wrapType(format("bool {}", int(value))); } -ACCEPT_IMPL(BoolExpr, ASTVisitor); -IntExpr::IntExpr(int64_t intValue) : Expr(), value(std::to_string(intValue)) { - this->intValue = std::make_unique(intValue); - staticValue = StaticValue(intValue); -} +IntExpr::IntExpr(int64_t intValue) + : AcceptorExtend(), value(std::to_string(intValue)), intValue(intValue) {} IntExpr::IntExpr(const std::string &value, std::string suffix) - : Expr(), value(), suffix(std::move(suffix)) { + : AcceptorExtend(), value(), suffix(std::move(suffix)) { for (auto c : value) if (c != '_') this->value += c; try { if (startswith(this->value, "0b") || startswith(this->value, "0B")) - intValue = - std::make_unique(std::stoull(this->value.substr(2), nullptr, 2)); + intValue = std::stoull(this->value.substr(2), nullptr, 2); else - intValue = std::make_unique(std::stoull(this->value, nullptr, 0)); + intValue = std::stoull(this->value, nullptr, 0); } catch (std::out_of_range &) { - intValue = nullptr; } } -IntExpr::IntExpr(const IntExpr &expr) - : Expr(expr), value(expr.value), suffix(expr.suffix) { - intValue = expr.intValue ? std::make_unique(*(expr.intValue)) : nullptr; +IntExpr::IntExpr(const IntExpr &expr, bool clean) + : AcceptorExtend(expr, clean), value(expr.value), suffix(expr.suffix), + intValue(expr.intValue) {} +std::pair IntExpr::getRawData() const { + return {value, suffix}; +} +bool IntExpr::hasStoredValue() const { return intValue.has_value(); } +int64_t IntExpr::getValue() const { + seqassertn(hasStoredValue(), "value not set"); + return intValue.value(); } -std::string IntExpr::toString() const { +std::string IntExpr::toString(int) const { return wrapType(format("int {}{}", value, suffix.empty() ? "" : format(" #:suffix \"{}\"", suffix))); } -ACCEPT_IMPL(IntExpr, ASTVisitor); FloatExpr::FloatExpr(double floatValue) - : Expr(), value(fmt::format("{:g}", floatValue)) { - this->floatValue = std::make_unique(floatValue); + : AcceptorExtend(), value(fmt::format("{:g}", floatValue)), floatValue(floatValue) { } FloatExpr::FloatExpr(const std::string &value, std::string suffix) - : Expr(), value(), suffix(std::move(suffix)) { + : AcceptorExtend(), value(), suffix(std::move(suffix)) { this->value.reserve(value.size()); std::copy_if(value.begin(), value.end(), std::back_inserter(this->value), [](char c) { return c != '_'; }); @@ -160,375 +149,463 @@ FloatExpr::FloatExpr(const std::string &value, std::string suffix) auto r = fast_float::from_chars(this->value.data(), this->value.data() + this->value.size(), result); if (r.ec == std::errc() || r.ec == std::errc::result_out_of_range) - floatValue = std::make_unique(result); - else - floatValue = nullptr; + floatValue = result; +} +FloatExpr::FloatExpr(const FloatExpr &expr, bool clean) + : AcceptorExtend(expr, clean), value(expr.value), suffix(expr.suffix), + floatValue(expr.floatValue) {} +std::pair FloatExpr::getRawData() const { + return {value, suffix}; } -FloatExpr::FloatExpr(const FloatExpr &expr) - : Expr(expr), value(expr.value), suffix(expr.suffix) { - floatValue = expr.floatValue ? std::make_unique(*(expr.floatValue)) : nullptr; +bool FloatExpr::hasStoredValue() const { return floatValue.has_value(); } +double FloatExpr::getValue() const { + seqassertn(hasStoredValue(), "value not set"); + return floatValue.value(); } -std::string FloatExpr::toString() const { +std::string FloatExpr::toString(int) const { return wrapType(format("float {}{}", value, suffix.empty() ? "" : format(" #:suffix \"{}\"", suffix))); } -ACCEPT_IMPL(FloatExpr, ASTVisitor); -StringExpr::StringExpr(std::vector> s) - : Expr(), strings(std::move(s)) { - if (strings.size() == 1 && strings.back().second.empty()) - staticValue = StaticValue(strings.back().first); -} +StringExpr::StringExpr(std::vector s) + : AcceptorExtend(), strings(std::move(s)) {} StringExpr::StringExpr(std::string value, std::string prefix) - : StringExpr(std::vector>{{value, prefix}}) {} -std::string StringExpr::toString() const { + : StringExpr(std::vector{{value, prefix}}) {} +StringExpr::StringExpr(const StringExpr &expr, bool clean) + : AcceptorExtend(expr, clean), strings(expr.strings) { + for (auto &s : strings) + s.expr = ast::clone(s.expr); +} +std::string StringExpr::toString(int) const { std::vector s; for (auto &vp : strings) - s.push_back(format("\"{}\"{}", escape(vp.first), - vp.second.empty() ? "" : format(" #:prefix \"{}\"", vp.second))); + s.push_back(format("\"{}\"{}", escape(vp.value), + vp.prefix.empty() ? "" : format(" #:prefix \"{}\"", vp.prefix))); return wrapType(format("string ({})", join(s))); } std::string StringExpr::getValue() const { - seqassert(!strings.empty(), "invalid StringExpr"); - return strings[0].first; + seqassert(isSimple(), "invalid StringExpr"); + return strings[0].value; +} +bool StringExpr::isSimple() const { + return strings.size() == 1 && strings[0].prefix.empty(); } -ACCEPT_IMPL(StringExpr, ASTVisitor); -IdExpr::IdExpr(std::string value) : Expr(), value(std::move(value)) {} -std::string IdExpr::toString() const { - return !type ? format("'{}", value) : wrapType(format("'{}", value)); +IdExpr::IdExpr(std::string value) : AcceptorExtend(), value(std::move(value)) {} +IdExpr::IdExpr(const IdExpr &expr, bool clean) + : AcceptorExtend(expr, clean), value(expr.value) {} +std::string IdExpr::toString(int) const { + return !getType() ? format("'{}", value) : wrapType(format("'{}", value)); } -ACCEPT_IMPL(IdExpr, ASTVisitor); -StarExpr::StarExpr(ExprPtr what) : Expr(), what(std::move(what)) {} -StarExpr::StarExpr(const StarExpr &expr) : Expr(expr), what(ast::clone(expr.what)) {} -std::string StarExpr::toString() const { - return wrapType(format("star {}", what->toString())); +StarExpr::StarExpr(Expr *expr) : AcceptorExtend(), expr(expr) {} +StarExpr::StarExpr(const StarExpr &expr, bool clean) + : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)) {} +std::string StarExpr::toString(int indent) const { + return wrapType(format("star {}", expr->toString(indent))); } -ACCEPT_IMPL(StarExpr, ASTVisitor); -KeywordStarExpr::KeywordStarExpr(ExprPtr what) : Expr(), what(std::move(what)) {} -KeywordStarExpr::KeywordStarExpr(const KeywordStarExpr &expr) - : Expr(expr), what(ast::clone(expr.what)) {} -std::string KeywordStarExpr::toString() const { - return wrapType(format("kwstar {}", what->toString())); +KeywordStarExpr::KeywordStarExpr(Expr *expr) : AcceptorExtend(), expr(expr) {} +KeywordStarExpr::KeywordStarExpr(const KeywordStarExpr &expr, bool clean) + : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)) {} +std::string KeywordStarExpr::toString(int indent) const { + return wrapType(format("kwstar {}", expr->toString(indent))); } -ACCEPT_IMPL(KeywordStarExpr, ASTVisitor); -TupleExpr::TupleExpr(std::vector items) : Expr(), items(std::move(items)) {} -TupleExpr::TupleExpr(const TupleExpr &expr) - : Expr(expr), items(ast::clone(expr.items)) {} -std::string TupleExpr::toString() const { +TupleExpr::TupleExpr(std::vector items) + : AcceptorExtend(), Items(std::move(items)) {} +TupleExpr::TupleExpr(const TupleExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} +std::string TupleExpr::toString(int) const { return wrapType(format("tuple {}", combine(items))); } -ACCEPT_IMPL(TupleExpr, ASTVisitor); -ListExpr::ListExpr(std::vector items) : Expr(), items(std::move(items)) {} -ListExpr::ListExpr(const ListExpr &expr) : Expr(expr), items(ast::clone(expr.items)) {} -std::string ListExpr::toString() const { +ListExpr::ListExpr(std::vector items) + : AcceptorExtend(), Items(std::move(items)) {} +ListExpr::ListExpr(const ListExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} +std::string ListExpr::toString(int) const { return wrapType(!items.empty() ? format("list {}", combine(items)) : "list"); } -ACCEPT_IMPL(ListExpr, ASTVisitor); -SetExpr::SetExpr(std::vector items) : Expr(), items(std::move(items)) {} -SetExpr::SetExpr(const SetExpr &expr) : Expr(expr), items(ast::clone(expr.items)) {} -std::string SetExpr::toString() const { +SetExpr::SetExpr(std::vector items) + : AcceptorExtend(), Items(std::move(items)) {} +SetExpr::SetExpr(const SetExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} +std::string SetExpr::toString(int) const { return wrapType(!items.empty() ? format("set {}", combine(items)) : "set"); } -ACCEPT_IMPL(SetExpr, ASTVisitor); -DictExpr::DictExpr(std::vector items) : Expr(), items(std::move(items)) { - for (auto &i : items) { - auto t = i->getTuple(); - seqassertn(t && t->items.size() == 2, "dictionary items are invalid"); +DictExpr::DictExpr(std::vector items) + : AcceptorExtend(), Items(std::move(items)) { + for (auto *i : *this) { + auto t = cast(i); + seqassertn(t && t->size() == 2, "dictionary items are invalid"); } } -DictExpr::DictExpr(const DictExpr &expr) : Expr(expr), items(ast::clone(expr.items)) {} -std::string DictExpr::toString() const { +DictExpr::DictExpr(const DictExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)) {} +std::string DictExpr::toString(int) const { return wrapType(!items.empty() ? format("dict {}", combine(items)) : "set"); } -ACCEPT_IMPL(DictExpr, ASTVisitor); -GeneratorBody GeneratorBody::clone() const { - return {ast::clone(vars), ast::clone(gen), ast::clone(conds)}; -} - -GeneratorExpr::GeneratorExpr(GeneratorExpr::GeneratorKind kind, ExprPtr expr, - std::vector loops) - : Expr(), kind(kind), expr(std::move(expr)), loops(std::move(loops)) {} -GeneratorExpr::GeneratorExpr(const GeneratorExpr &expr) - : Expr(expr), kind(expr.kind), expr(ast::clone(expr.expr)), - loops(ast::clone_nop(expr.loops)) {} -std::string GeneratorExpr::toString() const { +GeneratorExpr::GeneratorExpr(Cache *cache, GeneratorExpr::GeneratorKind kind, + Expr *expr, std::vector loops) + : AcceptorExtend(), kind(kind) { + this->cache = cache; + seqassert(!loops.empty() && cast(loops[0]), "bad generator constructor"); + loops.push_back(cache->N(cache->N(expr))); + formCompleteStmt(loops); +} +GeneratorExpr::GeneratorExpr(Cache *cache, Expr *key, Expr *expr, + std::vector loops) + : AcceptorExtend(), kind(GeneratorExpr::DictGenerator) { + this->cache = cache; + seqassert(!loops.empty() && cast(loops[0]), "bad generator constructor"); + Expr *t = cache->N(std::vector{key, expr}); + loops.push_back(cache->N(cache->N(t))); + formCompleteStmt(loops); +} +GeneratorExpr::GeneratorExpr(const GeneratorExpr &expr, bool clean) + : AcceptorExtend(expr, clean), kind(expr.kind), + loops(ast::clone(expr.loops, clean)) {} +std::string GeneratorExpr::toString(int indent) const { + auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " "; std::string prefix; if (kind == GeneratorKind::ListGenerator) prefix = "list-"; if (kind == GeneratorKind::SetGenerator) prefix = "set-"; - std::string s; - for (auto &i : loops) { - std::string q; - for (auto &k : i.conds) - q += format(" (if {})", k->toString()); - s += format(" (for {} {}{})", i.vars->toString(), i.gen->toString(), q); + if (kind == GeneratorKind::DictGenerator) + prefix = "dict-"; + auto l = loops->toString(indent >= 0 ? indent + 2 * INDENT_SIZE : -1); + return wrapType(format("{}gen {}", prefix, l)); +} +Expr *GeneratorExpr::getFinalExpr() { + auto s = *(getFinalStmt()); + if (cast(s)) + return cast(s)->getExpr(); + return nullptr; +} +int GeneratorExpr::loopCount() const { + int cnt = 0; + for (Stmt *i = loops;;) { + if (auto sf = cast(i)) { + i = sf->getSuite(); + cnt++; + } else if (auto si = cast(i)) { + i = si->getIf(); + cnt++; + } else if (auto ss = cast(i)) { + if (ss->empty()) + break; + i = ss->back(); + } else + break; } - return wrapType(format("{}gen {}{}", prefix, expr->toString(), s)); -} -ACCEPT_IMPL(GeneratorExpr, ASTVisitor); - -DictGeneratorExpr::DictGeneratorExpr(ExprPtr key, ExprPtr expr, - std::vector loops) - : Expr(), key(std::move(key)), expr(std::move(expr)), loops(std::move(loops)) {} -DictGeneratorExpr::DictGeneratorExpr(const DictGeneratorExpr &expr) - : Expr(expr), key(ast::clone(expr.key)), expr(ast::clone(expr.expr)), - loops(ast::clone_nop(expr.loops)) {} -std::string DictGeneratorExpr::toString() const { - std::string s; - for (auto &i : loops) { - std::string q; - for (auto &k : i.conds) - q += format("( if {})", k->toString()); - s += format(" (for {} {}{})", i.vars->toString(), i.gen->toString(), q); + return cnt; +} +void GeneratorExpr::setFinalExpr(Expr *expr) { + *(getFinalStmt()) = cache->N(expr); +} +void GeneratorExpr::setFinalStmt(Stmt *stmt) { *(getFinalStmt()) = stmt; } +Stmt *GeneratorExpr::getFinalSuite() const { return loops; } +Stmt **GeneratorExpr::getFinalStmt() { + for (Stmt **i = &loops;;) { + if (auto sf = cast(*i)) + i = (Stmt **)&sf->suite; + else if (auto si = cast(*i)) + i = (Stmt **)&si->ifSuite; + else if (auto ss = cast(*i)) { + if (ss->empty()) + return i; + i = &(ss->back()); + } else + return i; + } + seqassert(false, "bad generator"); + return nullptr; +} +void GeneratorExpr::formCompleteStmt(const std::vector &loops) { + Stmt *final = nullptr; + for (size_t i = loops.size(); i-- > 0;) { + if (auto si = cast(loops[i])) + si->ifSuite = SuiteStmt::wrap(final); + else if (auto sf = cast(loops[i])) + sf->suite = SuiteStmt::wrap(final); + final = loops[i]; } - return wrapType(format("dict-gen {} {}{}", key->toString(), expr->toString(), s)); + this->loops = loops[0]; } -ACCEPT_IMPL(DictGeneratorExpr, ASTVisitor); -IfExpr::IfExpr(ExprPtr cond, ExprPtr ifexpr, ExprPtr elsexpr) - : Expr(), cond(std::move(cond)), ifexpr(std::move(ifexpr)), - elsexpr(std::move(elsexpr)) {} -IfExpr::IfExpr(const IfExpr &expr) - : Expr(expr), cond(ast::clone(expr.cond)), ifexpr(ast::clone(expr.ifexpr)), - elsexpr(ast::clone(expr.elsexpr)) {} -std::string IfExpr::toString() const { - return wrapType(format("if-expr {} {} {}", cond->toString(), ifexpr->toString(), - elsexpr->toString())); +IfExpr::IfExpr(Expr *cond, Expr *ifexpr, Expr *elsexpr) + : AcceptorExtend(), cond(cond), ifexpr(ifexpr), elsexpr(elsexpr) {} +IfExpr::IfExpr(const IfExpr &expr, bool clean) + : AcceptorExtend(expr, clean), cond(ast::clone(expr.cond, clean)), + ifexpr(ast::clone(expr.ifexpr, clean)), elsexpr(ast::clone(expr.elsexpr, clean)) { } -ACCEPT_IMPL(IfExpr, ASTVisitor); - -UnaryExpr::UnaryExpr(std::string op, ExprPtr expr) - : Expr(), op(std::move(op)), expr(std::move(expr)) {} -UnaryExpr::UnaryExpr(const UnaryExpr &expr) - : Expr(expr), op(expr.op), expr(ast::clone(expr.expr)) {} -std::string UnaryExpr::toString() const { - return wrapType(format("unary \"{}\" {}", op, expr->toString())); +std::string IfExpr::toString(int indent) const { + return wrapType(format("if-expr {} {} {}", cond->toString(indent), + ifexpr->toString(indent), elsexpr->toString(indent))); } -ACCEPT_IMPL(UnaryExpr, ASTVisitor); -BinaryExpr::BinaryExpr(ExprPtr lexpr, std::string op, ExprPtr rexpr, bool inPlace) - : Expr(), op(std::move(op)), lexpr(std::move(lexpr)), rexpr(std::move(rexpr)), - inPlace(inPlace) {} -BinaryExpr::BinaryExpr(const BinaryExpr &expr) - : Expr(expr), op(expr.op), lexpr(ast::clone(expr.lexpr)), - rexpr(ast::clone(expr.rexpr)), inPlace(expr.inPlace) {} -std::string BinaryExpr::toString() const { - return wrapType(format("binary \"{}\" {} {}{}", op, lexpr->toString(), - rexpr->toString(), inPlace ? " #:in-place" : "")); +UnaryExpr::UnaryExpr(std::string op, Expr *expr) + : AcceptorExtend(), op(std::move(op)), expr(expr) {} +UnaryExpr::UnaryExpr(const UnaryExpr &expr, bool clean) + : AcceptorExtend(expr, clean), op(expr.op), expr(ast::clone(expr.expr, clean)) {} +std::string UnaryExpr::toString(int indent) const { + return wrapType(format("unary \"{}\" {}", op, expr->toString(indent))); } -ACCEPT_IMPL(BinaryExpr, ASTVisitor); -ChainBinaryExpr::ChainBinaryExpr(std::vector> exprs) - : Expr(), exprs(std::move(exprs)) {} -ChainBinaryExpr::ChainBinaryExpr(const ChainBinaryExpr &expr) : Expr(expr) { +BinaryExpr::BinaryExpr(Expr *lexpr, std::string op, Expr *rexpr, bool inPlace) + : AcceptorExtend(), op(std::move(op)), lexpr(lexpr), rexpr(rexpr), + inPlace(inPlace) {} +BinaryExpr::BinaryExpr(const BinaryExpr &expr, bool clean) + : AcceptorExtend(expr, clean), op(expr.op), lexpr(ast::clone(expr.lexpr, clean)), + rexpr(ast::clone(expr.rexpr, clean)), inPlace(expr.inPlace) {} +std::string BinaryExpr::toString(int indent) const { + return wrapType(format("binary \"{}\" {} {}{}", op, lexpr->toString(indent), + rexpr->toString(indent), inPlace ? " #:in-place" : "")); +} + +ChainBinaryExpr::ChainBinaryExpr(std::vector> exprs) + : AcceptorExtend(), exprs(std::move(exprs)) {} +ChainBinaryExpr::ChainBinaryExpr(const ChainBinaryExpr &expr, bool clean) + : AcceptorExtend(expr, clean) { for (auto &e : expr.exprs) - exprs.emplace_back(make_pair(e.first, ast::clone(e.second))); + exprs.emplace_back(make_pair(e.first, ast::clone(e.second, clean))); } -std::string ChainBinaryExpr::toString() const { +std::string ChainBinaryExpr::toString(int indent) const { std::vector s; for (auto &i : exprs) - s.push_back(format("({} \"{}\")", i.first, i.second->toString())); + s.push_back(format("({} \"{}\")", i.first, i.second->toString(indent))); return wrapType(format("chain {}", join(s, " "))); } -ACCEPT_IMPL(ChainBinaryExpr, ASTVisitor); -PipeExpr::Pipe PipeExpr::Pipe::clone() const { return {op, ast::clone(expr)}; } +Pipe Pipe::clone(bool clean) const { return {op, ast::clone(expr, clean)}; } -PipeExpr::PipeExpr(std::vector items) - : Expr(), items(std::move(items)) { - for (auto &i : this->items) { - if (auto call = i.expr->getCall()) { - for (auto &a : call->args) - if (auto el = a.value->getEllipsis()) +PipeExpr::PipeExpr(std::vector items) + : AcceptorExtend(), Items(std::move(items)) { + for (auto &i : *this) { + if (auto call = cast(i.expr)) { + for (auto &a : *call) + if (auto el = cast(a.value)) el->mode = EllipsisExpr::PIPE; } } } -PipeExpr::PipeExpr(const PipeExpr &expr) - : Expr(expr), items(ast::clone_nop(expr.items)), inTypes(expr.inTypes) {} -void PipeExpr::validate() const {} -std::string PipeExpr::toString() const { +PipeExpr::PipeExpr(const PipeExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)), + inTypes(expr.inTypes) {} +std::string PipeExpr::toString(int indent) const { std::vector s; for (auto &i : items) - s.push_back(format("({} \"{}\")", i.expr->toString(), i.op)); + s.push_back(format("({} \"{}\")", i.expr->toString(indent), i.op)); return wrapType(format("pipe {}", join(s, " "))); } -ACCEPT_IMPL(PipeExpr, ASTVisitor); -IndexExpr::IndexExpr(ExprPtr expr, ExprPtr index) - : Expr(), expr(std::move(expr)), index(std::move(index)) {} -IndexExpr::IndexExpr(const IndexExpr &expr) - : Expr(expr), expr(ast::clone(expr.expr)), index(ast::clone(expr.index)) {} -std::string IndexExpr::toString() const { - return wrapType(format("index {} {}", expr->toString(), index->toString())); +IndexExpr::IndexExpr(Expr *expr, Expr *index) + : AcceptorExtend(), expr(expr), index(index) {} +IndexExpr::IndexExpr(const IndexExpr &expr, bool clean) + : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)), + index(ast::clone(expr.index, clean)) {} +std::string IndexExpr::toString(int indent) const { + return wrapType( + format("index {} {}", expr->toString(indent), index->toString(indent))); } -ACCEPT_IMPL(IndexExpr, ASTVisitor); -CallExpr::Arg CallExpr::Arg::clone() const { return {name, ast::clone(value)}; } -CallExpr::Arg::Arg(const SrcInfo &info, const std::string &name, ExprPtr value) +CallArg CallArg::clone(bool clean) const { return {name, ast::clone(value, clean)}; } +CallArg::CallArg(const SrcInfo &info, const std::string &name, Expr *value) : name(name), value(value) { setSrcInfo(info); } -CallExpr::Arg::Arg(const std::string &name, ExprPtr value) : name(name), value(value) { +CallArg::CallArg(const std::string &name, Expr *value) : name(name), value(value) { if (value) setSrcInfo(value->getSrcInfo()); } -CallExpr::Arg::Arg(ExprPtr value) : CallExpr::Arg("", value) {} +CallArg::CallArg(Expr *value) : CallArg("", value) {} -CallExpr::CallExpr(const CallExpr &expr) - : Expr(expr), expr(ast::clone(expr.expr)), args(ast::clone_nop(expr.args)), - ordered(expr.ordered) {} -CallExpr::CallExpr(ExprPtr expr, std::vector args) - : Expr(), expr(std::move(expr)), args(std::move(args)), ordered(false) { - validate(); +CallExpr::CallExpr(const CallExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)), + expr(ast::clone(expr.expr, clean)), ordered(expr.ordered), partial(expr.partial) { +} +CallExpr::CallExpr(Expr *expr, std::vector args) + : AcceptorExtend(), Items(std::move(args)), expr(expr), ordered(false), + partial(false) { } -CallExpr::CallExpr(ExprPtr expr, std::vector args) - : expr(std::move(expr)), ordered(false) { - for (auto &a : args) +CallExpr::CallExpr(Expr *expr, std::vector args) + : AcceptorExtend(), Items({}), expr(expr), ordered(false), partial(false) { + for (auto a : args) if (a) - this->args.push_back({"", std::move(a)}); - validate(); -} -void CallExpr::validate() const { - bool namesStarted = false, foundEllipsis = false; - for (auto &a : args) { - if (a.name.empty() && namesStarted && - !(CAST(a.value, KeywordStarExpr) || a.value->getEllipsis())) - E(Error::CALL_NAME_ORDER, a.value); - if (!a.name.empty() && (a.value->getStar() || CAST(a.value, KeywordStarExpr))) - E(Error::CALL_NAME_STAR, a.value); - if (a.value->getEllipsis() && foundEllipsis) - E(Error::CALL_ELLIPSIS, a.value); - foundEllipsis |= bool(a.value->getEllipsis()); - namesStarted |= !a.name.empty(); - } + items.emplace_back("", a); } -std::string CallExpr::toString() const { - std::string s; - for (auto &i : args) - if (i.name.empty()) - s += " " + i.value->toString(); - else - s += format("({}{})", i.value->toString(), - i.name.empty() ? "" : format(" #:name '{}", i.name)); - return wrapType(format("call {} {}", expr->toString(), s)); +std::string CallExpr::toString(int indent) const { + std::vector s; + auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " "; + for (auto &i : *this) { + if (!i.name.empty()) + s.emplace_back(pad + format("#:name '{}", i.name)); + s.emplace_back(pad + + i.value->toString(indent >= 0 ? indent + 2 * INDENT_SIZE : -1)); + } + return wrapType(format("call{} {}{}", partial ? "-partial" : "", + expr->toString(indent), fmt::join(s, ""))); } -ACCEPT_IMPL(CallExpr, ASTVisitor); -DotExpr::DotExpr(ExprPtr expr, std::string member) - : Expr(), expr(std::move(expr)), member(std::move(member)) {} -DotExpr::DotExpr(const std::string &left, std::string member) - : Expr(), expr(std::make_shared(left)), member(std::move(member)) {} -DotExpr::DotExpr(const DotExpr &expr) - : Expr(expr), expr(ast::clone(expr.expr)), member(expr.member) {} -std::string DotExpr::toString() const { - return wrapType(format("dot {} '{}", expr->toString(), member)); +DotExpr::DotExpr(Expr *expr, std::string member) + : AcceptorExtend(), expr(expr), member(std::move(member)) {} +DotExpr::DotExpr(const DotExpr &expr, bool clean) + : AcceptorExtend(expr, clean), expr(ast::clone(expr.expr, clean)), + member(expr.member) {} +std::string DotExpr::toString(int indent) const { + return wrapType(format("dot {} '{}", expr->toString(indent), member)); } -ACCEPT_IMPL(DotExpr, ASTVisitor); -SliceExpr::SliceExpr(ExprPtr start, ExprPtr stop, ExprPtr step) - : Expr(), start(std::move(start)), stop(std::move(stop)), step(std::move(step)) {} -SliceExpr::SliceExpr(const SliceExpr &expr) - : Expr(expr), start(ast::clone(expr.start)), stop(ast::clone(expr.stop)), - step(ast::clone(expr.step)) {} -std::string SliceExpr::toString() const { +SliceExpr::SliceExpr(Expr *start, Expr *stop, Expr *step) + : AcceptorExtend(), start(start), stop(stop), step(step) {} +SliceExpr::SliceExpr(const SliceExpr &expr, bool clean) + : AcceptorExtend(expr, clean), start(ast::clone(expr.start, clean)), + stop(ast::clone(expr.stop, clean)), step(ast::clone(expr.step, clean)) {} +std::string SliceExpr::toString(int indent) const { return wrapType(format("slice{}{}{}", - start ? format(" #:start {}", start->toString()) : "", - stop ? format(" #:end {}", stop->toString()) : "", - step ? format(" #:step {}", step->toString()) : "")); + start ? format(" #:start {}", start->toString(indent)) : "", + stop ? format(" #:end {}", stop->toString(indent)) : "", + step ? format(" #:step {}", step->toString(indent)) : "")); } -ACCEPT_IMPL(SliceExpr, ASTVisitor); -EllipsisExpr::EllipsisExpr(EllipsisType mode) : Expr(), mode(mode) {} -std::string EllipsisExpr::toString() const { +EllipsisExpr::EllipsisExpr(EllipsisType mode) : AcceptorExtend(), mode(mode) {} +EllipsisExpr::EllipsisExpr(const EllipsisExpr &expr, bool clean) + : AcceptorExtend(expr, clean), mode(expr.mode) {} +std::string EllipsisExpr::toString(int) const { return wrapType(format( - "ellipsis{}", mode == PIPE ? " #:pipe" : (mode == PARTIAL ? "#:partial" : ""))); -} -ACCEPT_IMPL(EllipsisExpr, ASTVisitor); - -LambdaExpr::LambdaExpr(std::vector vars, ExprPtr expr) - : Expr(), vars(std::move(vars)), expr(std::move(expr)) {} -LambdaExpr::LambdaExpr(const LambdaExpr &expr) - : Expr(expr), vars(expr.vars), expr(ast::clone(expr.expr)) {} -std::string LambdaExpr::toString() const { - return wrapType(format("lambda ({}) {}", join(vars, " "), expr->toString())); -} -ACCEPT_IMPL(LambdaExpr, ASTVisitor); - -YieldExpr::YieldExpr() : Expr() {} -std::string YieldExpr::toString() const { return "yield-expr"; } -ACCEPT_IMPL(YieldExpr, ASTVisitor); - -AssignExpr::AssignExpr(ExprPtr var, ExprPtr expr) - : Expr(), var(std::move(var)), expr(std::move(expr)) {} -AssignExpr::AssignExpr(const AssignExpr &expr) - : Expr(expr), var(ast::clone(expr.var)), expr(ast::clone(expr.expr)) {} -std::string AssignExpr::toString() const { - return wrapType(format("assign-expr '{} {}", var->toString(), expr->toString())); + "ellipsis{}", mode == PIPE ? " #:pipe" : (mode == PARTIAL ? " #:partial" : ""))); +} + +LambdaExpr::LambdaExpr(std::vector vars, Expr *expr) + : AcceptorExtend(), Items(std::move(vars)), expr(expr) {} +LambdaExpr::LambdaExpr(const LambdaExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)), + expr(ast::clone(expr.expr, clean)) {} +std::string LambdaExpr::toString(int indent) const { + std::vector as; + for (auto &a : items) + as.push_back(a.toString(indent)); + return wrapType(format("lambda ({}) {}", join(as, " "), expr->toString(indent))); +} + +YieldExpr::YieldExpr() : AcceptorExtend() {} +YieldExpr::YieldExpr(const YieldExpr &expr, bool clean) : AcceptorExtend(expr, clean) {} +std::string YieldExpr::toString(int) const { return "yield-expr"; } + +AssignExpr::AssignExpr(Expr *var, Expr *expr) + : AcceptorExtend(), var(var), expr(expr) {} +AssignExpr::AssignExpr(const AssignExpr &expr, bool clean) + : AcceptorExtend(expr, clean), var(ast::clone(expr.var, clean)), + expr(ast::clone(expr.expr, clean)) {} +std::string AssignExpr::toString(int indent) const { + return wrapType( + format("assign-expr '{} {}", var->toString(indent), expr->toString(indent))); } -ACCEPT_IMPL(AssignExpr, ASTVisitor); -RangeExpr::RangeExpr(ExprPtr start, ExprPtr stop) - : Expr(), start(std::move(start)), stop(std::move(stop)) {} -RangeExpr::RangeExpr(const RangeExpr &expr) - : Expr(expr), start(ast::clone(expr.start)), stop(ast::clone(expr.stop)) {} -std::string RangeExpr::toString() const { - return wrapType(format("range {} {}", start->toString(), stop->toString())); +RangeExpr::RangeExpr(Expr *start, Expr *stop) + : AcceptorExtend(), start(start), stop(stop) {} +RangeExpr::RangeExpr(const RangeExpr &expr, bool clean) + : AcceptorExtend(expr, clean), start(ast::clone(expr.start, clean)), + stop(ast::clone(expr.stop, clean)) {} +std::string RangeExpr::toString(int indent) const { + return wrapType( + format("range {} {}", start->toString(indent), stop->toString(indent))); } -ACCEPT_IMPL(RangeExpr, ASTVisitor); -StmtExpr::StmtExpr(std::vector> stmts, ExprPtr expr) - : Expr(), stmts(std::move(stmts)), expr(std::move(expr)) {} -StmtExpr::StmtExpr(std::shared_ptr stmt, ExprPtr expr) - : Expr(), expr(std::move(expr)) { - stmts.push_back(std::move(stmt)); +StmtExpr::StmtExpr(std::vector stmts, Expr *expr) + : AcceptorExtend(), Items(std::move(stmts)), expr(expr) {} +StmtExpr::StmtExpr(Stmt *stmt, Expr *expr) : AcceptorExtend(), Items({}), expr(expr) { + items.push_back(stmt); } -StmtExpr::StmtExpr(std::shared_ptr stmt, std::shared_ptr stmt2, - ExprPtr expr) - : Expr(), expr(std::move(expr)) { - stmts.push_back(std::move(stmt)); - stmts.push_back(std::move(stmt2)); +StmtExpr::StmtExpr(Stmt *stmt, Stmt *stmt2, Expr *expr) + : AcceptorExtend(), Items({}), expr(expr) { + items.push_back(stmt); + items.push_back(stmt2); } -StmtExpr::StmtExpr(const StmtExpr &expr) - : Expr(expr), stmts(ast::clone(expr.stmts)), expr(ast::clone(expr.expr)) {} -std::string StmtExpr::toString() const { - return wrapType(format("stmt-expr ({}) {}", combine(stmts, " "), expr->toString())); -} -ACCEPT_IMPL(StmtExpr, ASTVisitor); - -InstantiateExpr::InstantiateExpr(ExprPtr typeExpr, std::vector typeParams) - : Expr(), typeExpr(std::move(typeExpr)), typeParams(std::move(typeParams)) {} -InstantiateExpr::InstantiateExpr(ExprPtr typeExpr, ExprPtr typeParam) - : Expr(), typeExpr(std::move(typeExpr)) { - typeParams.push_back(std::move(typeParam)); -} -InstantiateExpr::InstantiateExpr(const InstantiateExpr &expr) - : Expr(expr), typeExpr(ast::clone(expr.typeExpr)), - typeParams(ast::clone(expr.typeParams)) {} -std::string InstantiateExpr::toString() const { +StmtExpr::StmtExpr(const StmtExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)), + expr(ast::clone(expr.expr, clean)) {} +std::string StmtExpr::toString(int indent) const { + auto pad = indent >= 0 ? ("\n" + std::string(indent + 2 * INDENT_SIZE, ' ')) : " "; + std::vector s; + s.reserve(items.size()); + for (auto &i : items) + s.emplace_back(pad + i->toString(indent >= 0 ? indent + 2 * INDENT_SIZE : -1)); return wrapType( - format("instantiate {} {}", typeExpr->toString(), combine(typeParams))); -} -ACCEPT_IMPL(InstantiateExpr, ASTVisitor); - -StaticValue::Type getStaticGeneric(Expr *e) { - if (e && e->getIndex() && e->getIndex()->expr->isId("Static")) { - if (e->getIndex()->index && e->getIndex()->index->isId("str")) - return StaticValue::Type::STRING; - if (e->getIndex()->index && e->getIndex()->index->isId("int")) - return StaticValue::Type::INT; - return StaticValue::Type::NOT_SUPPORTED; + format("stmt-expr {} ({})", expr->toString(indent), fmt::join(s, ""))); +} + +InstantiateExpr::InstantiateExpr(Expr *expr, std::vector typeParams) + : AcceptorExtend(), Items(std::move(typeParams)), expr(expr) {} +InstantiateExpr::InstantiateExpr(Expr *expr, Expr *typeParam) + : AcceptorExtend(), Items({typeParam}), expr(expr) {} +InstantiateExpr::InstantiateExpr(const InstantiateExpr &expr, bool clean) + : AcceptorExtend(expr, clean), Items(ast::clone(expr.items, clean)), + expr(ast::clone(expr.expr, clean)) {} +std::string InstantiateExpr::toString(int indent) const { + return wrapType(format("instantiate {} {}", expr->toString(indent), combine(items))); +} + +bool isId(Expr *e, const std::string &s) { + auto ie = cast(e); + return ie && ie->getValue() == s; +} + +char getStaticGeneric(Expr *e) { + auto ie = cast(e); + if (!ie) + return 0; + if (cast(ie->getExpr()) && + cast(ie->getExpr())->getValue() == "Static") { + auto ixe = cast(ie->getIndex()); + if (!ixe) + return 0; + if (ixe->getValue() == "bool") + return 3; + if (ixe->getValue() == "str") + return 2; + if (ixe->getValue() == "int") + return 1; + return 4; } - return StaticValue::Type::NOT_STATIC; + return 0; } +const char ASTNode::NodeId = 0; +const char Expr::NodeId = 0; +ACCEPT_IMPL(NoneExpr, ASTVisitor); +ACCEPT_IMPL(BoolExpr, ASTVisitor); +ACCEPT_IMPL(IntExpr, ASTVisitor); +ACCEPT_IMPL(FloatExpr, ASTVisitor); +ACCEPT_IMPL(StringExpr, ASTVisitor); +ACCEPT_IMPL(IdExpr, ASTVisitor); +ACCEPT_IMPL(StarExpr, ASTVisitor); +ACCEPT_IMPL(KeywordStarExpr, ASTVisitor); +ACCEPT_IMPL(TupleExpr, ASTVisitor); +ACCEPT_IMPL(ListExpr, ASTVisitor); +ACCEPT_IMPL(SetExpr, ASTVisitor); +ACCEPT_IMPL(DictExpr, ASTVisitor); +ACCEPT_IMPL(GeneratorExpr, ASTVisitor); +ACCEPT_IMPL(IfExpr, ASTVisitor); +ACCEPT_IMPL(UnaryExpr, ASTVisitor); +ACCEPT_IMPL(BinaryExpr, ASTVisitor); +ACCEPT_IMPL(ChainBinaryExpr, ASTVisitor); +ACCEPT_IMPL(PipeExpr, ASTVisitor); +ACCEPT_IMPL(IndexExpr, ASTVisitor); +ACCEPT_IMPL(CallExpr, ASTVisitor); +ACCEPT_IMPL(DotExpr, ASTVisitor); +ACCEPT_IMPL(SliceExpr, ASTVisitor); +ACCEPT_IMPL(EllipsisExpr, ASTVisitor); +ACCEPT_IMPL(LambdaExpr, ASTVisitor); +ACCEPT_IMPL(YieldExpr, ASTVisitor); +ACCEPT_IMPL(AssignExpr, ASTVisitor); +ACCEPT_IMPL(RangeExpr, ASTVisitor); +ACCEPT_IMPL(StmtExpr, ASTVisitor); +ACCEPT_IMPL(InstantiateExpr, ASTVisitor); + } // namespace codon::ast diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index 877ea195..969f2810 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -8,685 +8,643 @@ #include #include +#include "codon/parser/ast/attr.h" +#include "codon/parser/ast/node.h" #include "codon/parser/ast/types.h" #include "codon/parser/common.h" +#include "codon/util/serialize.h" namespace codon::ast { -#define ACCEPT(X) \ - ExprPtr clone() const override; \ - void accept(X &visitor) override +#define ACCEPT(CLASS, VISITOR, ...) \ + static const char NodeId; \ + using AcceptorExtend::clone; \ + using AcceptorExtend::accept; \ + ASTNode *clone(bool c) const override; \ + void accept(VISITOR &visitor) override; \ + std::string toString(int) const override; \ + friend class TypecheckVisitor; \ + template friend struct CallbackASTVisitor; \ + friend struct ReplacingCallbackASTVisitor; \ + inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \ + SERIALIZE(CLASS, BASE(Expr), ##__VA_ARGS__) // Forward declarations -struct ASTVisitor; -struct BinaryExpr; -struct CallExpr; -struct DotExpr; -struct EllipsisExpr; -struct IdExpr; -struct IfExpr; -struct IndexExpr; -struct IntExpr; -struct ListExpr; -struct NoneExpr; -struct StarExpr; -struct StmtExpr; -struct StringExpr; -struct TupleExpr; -struct UnaryExpr; struct Stmt; -struct StaticValue { - std::variant value; - enum Type { NOT_STATIC = 0, STRING = 1, INT = 2, NOT_SUPPORTED = 3 } type; - bool evaluated; - - explicit StaticValue(Type); - // Static(bool); - explicit StaticValue(int64_t); - explicit StaticValue(std::string); - bool operator==(const StaticValue &s) const; - std::string toString() const; - int64_t getInt() const; - std::string getString() const; -}; - /** * A Seq AST expression. * Each AST expression is intended to be instantiated as a shared_ptr. */ -struct Expr : public codon::SrcObject { +struct Expr : public AcceptorExtend { using base_type = Expr; - // private: - /// Type of the expression. nullptr by default. - types::TypePtr type; - /// Flag that indicates if an expression describes a type (e.g. int or list[T]). - /// Used by transformation and type-checking stages. - bool isTypeExpr; - /// Flag that indicates if an expression is a compile-time static expression. - /// Such expression is of a form: - /// an integer (IntExpr) without any suffix that is within i64 range - /// a static generic - /// [-,not] a - /// a [+,-,*,//,%,and,or,==,!=,<,<=,>,>=] b - /// (note: and/or will NOT short-circuit) - /// a if cond else b - /// (note: cond is static, and is true if non-zero, false otherwise). - /// (note: both branches will be evaluated). - StaticValue staticValue; - /// Flag that indicates if all types in an expression are inferred (i.e. if a - /// type-checking procedure was successful). - bool done; - - /// Set of attributes. - int attributes; - - /// Original (pre-transformation) expression - std::shared_ptr origExpr; - -public: Expr(); - Expr(const Expr &expr) = default; - - /// Convert a node to an S-expression. - virtual std::string toString() const = 0; - /// Validate a node. Throw ParseASTException if a node is not valid. - void validate() const; - /// Deep copy a node. - virtual std::shared_ptr clone() const = 0; - /// Accept an AST visitor. - virtual void accept(ASTVisitor &visitor) = 0; + Expr(const Expr &); + Expr(const Expr &, bool); /// Get a node type. /// @return Type pointer or a nullptr if a type is not set. - types::TypePtr getType() const; - /// Set a node type. - void setType(types::TypePtr type); - /// @return true if a node describes a type expression. - bool isType() const; - /// Marks a node as a type expression. - void markType(); - /// True if a node is static expression. - bool isStatic() const; - - /// Allow pretty-printing to C++ streams. - friend std::ostream &operator<<(std::ostream &out, const Expr &expr) { - return out << expr.toString(); - } - - /// Convenience virtual functions to avoid unnecessary dynamic_cast calls. - virtual bool isId(const std::string &val) const { return false; } - virtual BinaryExpr *getBinary() { return nullptr; } - virtual CallExpr *getCall() { return nullptr; } - virtual DotExpr *getDot() { return nullptr; } - virtual EllipsisExpr *getEllipsis() { return nullptr; } - virtual IdExpr *getId() { return nullptr; } - virtual IfExpr *getIf() { return nullptr; } - virtual IndexExpr *getIndex() { return nullptr; } - virtual IntExpr *getInt() { return nullptr; } - virtual ListExpr *getList() { return nullptr; } - virtual NoneExpr *getNone() { return nullptr; } - virtual StarExpr *getStar() { return nullptr; } - virtual StmtExpr *getStmtExpr() { return nullptr; } - virtual StringExpr *getString() { return nullptr; } - virtual TupleExpr *getTuple() { return nullptr; } - virtual UnaryExpr *getUnary() { return nullptr; } - - /// Attribute helpers - bool hasAttr(int attr) const; - void setAttr(int attr); - + types::Type *getType() const { return type.get(); } + void setType(const types::TypePtr &t) { type = t; } + types::ClassType *getClassType() const; bool isDone() const { return done; } void setDone() { done = true; } + Expr *getOrigExpr() const { return origExpr; } + void setOrigExpr(Expr *orig) { origExpr = orig; } - /// @return Type name for IdExprs or instantiations. - std::string getTypeName(); + static const char NodeId; + SERIALIZE(Expr, BASE(ASTNode), /*type,*/ done, origExpr); + + Expr *operator<<(types::Type *t); protected: /// Add a type to S-expression string. std::string wrapType(const std::string &sexpr) const; + +private: + /// Type of the expression. nullptr by default. + types::TypePtr type; + /// Flag that indicates if all types in an expression are inferred (i.e. if a + /// type-checking procedure was successful). + bool done; + /// Original (pre-transformation) expression + Expr *origExpr; }; -using ExprPtr = std::shared_ptr; /// Function signature parameter helper node (name: type = defaultValue). struct Param : public codon::SrcObject { std::string name; - ExprPtr type; - ExprPtr defaultValue; + Expr *type; + Expr *defaultValue; enum { - Normal, + Value, Generic, HiddenGeneric } status; // 1 for normal generic, 2 for hidden generic - explicit Param(std::string name = "", ExprPtr type = nullptr, - ExprPtr defaultValue = nullptr, int generic = 0); - explicit Param(const SrcInfo &info, std::string name = "", ExprPtr type = nullptr, - ExprPtr defaultValue = nullptr, int generic = 0); + explicit Param(std::string name = "", Expr *type = nullptr, + Expr *defaultValue = nullptr, int generic = 0); + explicit Param(const SrcInfo &info, std::string name = "", Expr *type = nullptr, + Expr *defaultValue = nullptr, int generic = 0); - std::string toString() const; - Param clone() const; + std::string getName() const { return name; } + Expr *getType() const { return type; } + Expr *getDefault() const { return defaultValue; } + bool isValue() const { return status == Value; } + bool isGeneric() const { return status == Generic; } + bool isHiddenGeneric() const { return status == HiddenGeneric; } + std::pair getNameWithStars() const; + + SERIALIZE(Param, name, type, defaultValue); + Param clone(bool) const; + std::string toString(int) const; }; /// None expression. /// @li None -struct NoneExpr : public Expr { +struct NoneExpr : public AcceptorExtend { NoneExpr(); - NoneExpr(const NoneExpr &expr) = default; - - std::string toString() const override; - ACCEPT(ASTVisitor); + NoneExpr(const NoneExpr &, bool); - NoneExpr *getNone() override { return this; } + ACCEPT(NoneExpr, ASTVisitor); }; /// Bool expression (value). /// @li True -struct BoolExpr : public Expr { - bool value; +struct BoolExpr : public AcceptorExtend { + explicit BoolExpr(bool value = false); + BoolExpr(const BoolExpr &, bool); + + bool getValue() const; - explicit BoolExpr(bool value); - BoolExpr(const BoolExpr &expr) = default; + ACCEPT(BoolExpr, ASTVisitor, value); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + bool value; }; /// Int expression (value.suffix). /// @li 12 /// @li 13u /// @li 000_010b -struct IntExpr : public Expr { - /// Expression value is stored as a string that is parsed during the simplify stage. +struct IntExpr : public AcceptorExtend { + explicit IntExpr(int64_t intValue = 0); + explicit IntExpr(const std::string &value, std::string suffix = ""); + IntExpr(const IntExpr &, bool); + + bool hasStoredValue() const; + int64_t getValue() const; + std::pair getRawData() const; + + ACCEPT(IntExpr, ASTVisitor, value, suffix, intValue); + +private: + /// Expression value is stored as a string that is parsed during typechecking. std::string value; /// Number suffix (e.g. "u" for "123u"). std::string suffix; - /// Parsed value and sign for "normal" 64-bit integers. - std::unique_ptr intValue; - - explicit IntExpr(int64_t intValue); - explicit IntExpr(const std::string &value, std::string suffix = ""); - IntExpr(const IntExpr &expr); - - std::string toString() const override; - ACCEPT(ASTVisitor); - - IntExpr *getInt() override { return this; } + std::optional intValue; }; /// Float expression (value.suffix). /// @li 12.1 /// @li 13.15z /// @li e-12 -struct FloatExpr : public Expr { - /// Expression value is stored as a string that is parsed during the simplify stage. +struct FloatExpr : public AcceptorExtend { + explicit FloatExpr(double floatValue = 0.0); + explicit FloatExpr(const std::string &value, std::string suffix = ""); + FloatExpr(const FloatExpr &, bool); + + bool hasStoredValue() const; + double getValue() const; + std::pair getRawData() const; + + ACCEPT(FloatExpr, ASTVisitor, value, suffix, floatValue); + +private: + /// Expression value is stored as a string that is parsed during typechecking. std::string value; /// Number suffix (e.g. "u" for "123u"). std::string suffix; - /// Parsed value for 64-bit floats. - std::unique_ptr floatValue; - - explicit FloatExpr(double floatValue); - explicit FloatExpr(const std::string &value, std::string suffix = ""); - FloatExpr(const FloatExpr &expr); - - std::string toString() const override; - ACCEPT(ASTVisitor); + std::optional floatValue; }; /// String expression (prefix"value"). /// @li s'ACGT' /// @li "fff" -struct StringExpr : public Expr { +struct StringExpr : public AcceptorExtend { + struct FormatSpec { + std::string text; + std::string conversion; + std::string spec; + + SERIALIZE(FormatSpec, text, conversion, spec); + }; + // Vector of {value, prefix} strings. - std::vector> strings; + struct String : public SrcObject { + std::string value; + std::string prefix; + Expr *expr; + FormatSpec format; + + String(std::string v, std::string p = "", Expr *e = nullptr) + : value(std::move(v)), prefix(std::move(p)), expr(e), format() {} - explicit StringExpr(std::string value, std::string prefix = ""); - explicit StringExpr(std::vector> strings); - StringExpr(const StringExpr &expr) = default; + SERIALIZE(String, value, prefix, expr, format); + }; - std::string toString() const override; - ACCEPT(ASTVisitor); + explicit StringExpr(std::string value = "", std::string prefix = ""); + explicit StringExpr(std::vector strings); + StringExpr(const StringExpr &, bool); - StringExpr *getString() override { return this; } std::string getValue() const; + bool isSimple() const; + + ACCEPT(StringExpr, ASTVisitor, strings); + +private: + std::vector strings; + + auto begin() { return strings.begin(); } + auto end() { return strings.end(); } + + friend class ScopingVisitor; }; /// Identifier expression (value). -struct IdExpr : public Expr { - std::string value; +struct IdExpr : public AcceptorExtend { + explicit IdExpr(std::string value = ""); + IdExpr(const IdExpr &, bool); - explicit IdExpr(std::string value); - IdExpr(const IdExpr &expr) = default; + std::string getValue() const { return value; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(IdExpr, ASTVisitor, value); + +private: + std::string value; - bool isId(const std::string &val) const override { return this->value == val; } - IdExpr *getId() override { return this; } + void setValue(const std::string &s) { value = s; } + + friend class ScopingVisitor; }; /// Star (unpacking) expression (*what). /// @li *args -struct StarExpr : public Expr { - ExprPtr what; +struct StarExpr : public AcceptorExtend { + explicit StarExpr(Expr *what = nullptr); + StarExpr(const StarExpr &, bool); - explicit StarExpr(ExprPtr what); - StarExpr(const StarExpr &expr); + Expr *getExpr() const { return expr; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(StarExpr, ASTVisitor, expr); - StarExpr *getStar() override { return this; } +private: + Expr *expr; }; /// KeywordStar (unpacking) expression (**what). /// @li **kwargs -struct KeywordStarExpr : public Expr { - ExprPtr what; +struct KeywordStarExpr : public AcceptorExtend { + explicit KeywordStarExpr(Expr *what = nullptr); + KeywordStarExpr(const KeywordStarExpr &, bool); + + Expr *getExpr() const { return expr; } - explicit KeywordStarExpr(ExprPtr what); - KeywordStarExpr(const KeywordStarExpr &expr); + ACCEPT(KeywordStarExpr, ASTVisitor, expr); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + Expr *expr; }; /// Tuple expression ((items...)). /// @li (1, a) -struct TupleExpr : public Expr { - std::vector items; - - explicit TupleExpr(std::vector items = {}); - TupleExpr(const TupleExpr &expr); +struct TupleExpr : public AcceptorExtend, Items { + explicit TupleExpr(std::vector items = {}); + TupleExpr(const TupleExpr &, bool); - std::string toString() const override; - ACCEPT(ASTVisitor); - - TupleExpr *getTuple() override { return this; } + ACCEPT(TupleExpr, ASTVisitor, items); }; /// List expression ([items...]). /// @li [1, 2] -struct ListExpr : public Expr { - std::vector items; - - explicit ListExpr(std::vector items); - ListExpr(const ListExpr &expr); +struct ListExpr : public AcceptorExtend, Items { + explicit ListExpr(std::vector items = {}); + ListExpr(const ListExpr &, bool); - std::string toString() const override; - ACCEPT(ASTVisitor); - - ListExpr *getList() override { return this; } + ACCEPT(ListExpr, ASTVisitor, items); }; /// Set expression ({items...}). /// @li {1, 2} -struct SetExpr : public Expr { - std::vector items; - - explicit SetExpr(std::vector items); - SetExpr(const SetExpr &expr); +struct SetExpr : public AcceptorExtend, Items { + explicit SetExpr(std::vector items = {}); + SetExpr(const SetExpr &, bool); - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(SetExpr, ASTVisitor, items); }; /// Dictionary expression ({(key: value)...}). /// Each (key, value) pair is stored as a TupleExpr. /// @li {'s': 1, 't': 2} -struct DictExpr : public Expr { - std::vector items; - - explicit DictExpr(std::vector items); - DictExpr(const DictExpr &expr); - - std::string toString() const override; - ACCEPT(ASTVisitor); -}; - -/// Generator body node helper [for vars in gen (if conds)...]. -/// @li for i in lst if a if b -struct GeneratorBody { - ExprPtr vars; - ExprPtr gen; - std::vector conds; +struct DictExpr : public AcceptorExtend, Items { + explicit DictExpr(std::vector items = {}); + DictExpr(const DictExpr &, bool); - GeneratorBody clone() const; + ACCEPT(DictExpr, ASTVisitor, items); }; /// Generator or comprehension expression [(expr (loops...))]. /// @li [i for i in j] /// @li (f + 1 for j in k if j for f in j) -struct GeneratorExpr : public Expr { +struct GeneratorExpr : public AcceptorExtend { /// Generator kind: normal generator, list comprehension, set comprehension. - enum GeneratorKind { Generator, ListGenerator, SetGenerator }; + enum GeneratorKind { + Generator, + ListGenerator, + SetGenerator, + TupleGenerator, + DictGenerator + }; - GeneratorKind kind; - ExprPtr expr; - std::vector loops; + GeneratorExpr() : kind(Generator), loops(nullptr) {} + GeneratorExpr(Cache *cache, GeneratorKind kind, Expr *expr, + std::vector loops); + GeneratorExpr(Cache *cache, Expr *key, Expr *expr, std::vector loops); + GeneratorExpr(const GeneratorExpr &, bool); - GeneratorExpr(GeneratorKind kind, ExprPtr expr, std::vector loops); - GeneratorExpr(const GeneratorExpr &expr); + int loopCount() const; + Stmt *getFinalSuite() const; + Expr *getFinalExpr(); - std::string toString() const override; - ACCEPT(ASTVisitor); -}; + ACCEPT(GeneratorExpr, ASTVisitor, kind, loops); -/// Dictionary comprehension expression [{key: expr (loops...)}]. -/// @li {i: j for i, j in z.items()} -struct DictGeneratorExpr : public Expr { - ExprPtr key, expr; - std::vector loops; +private: + GeneratorKind kind; + Stmt *loops; - DictGeneratorExpr(ExprPtr key, ExprPtr expr, std::vector loops); - DictGeneratorExpr(const DictGeneratorExpr &expr); + Stmt **getFinalStmt(); + void setFinalExpr(Expr *); + void setFinalStmt(Stmt *); + void formCompleteStmt(const std::vector &); - std::string toString() const override; - ACCEPT(ASTVisitor); + friend class TranslateVisitor; }; /// Conditional expression [cond if ifexpr else elsexpr]. /// @li 1 if a else 2 -struct IfExpr : public Expr { - ExprPtr cond, ifexpr, elsexpr; +struct IfExpr : public AcceptorExtend { + IfExpr(Expr *cond = nullptr, Expr *ifexpr = nullptr, Expr *elsexpr = nullptr); + IfExpr(const IfExpr &, bool); - IfExpr(ExprPtr cond, ExprPtr ifexpr, ExprPtr elsexpr); - IfExpr(const IfExpr &expr); + Expr *getCond() const { return cond; } + Expr *getIf() const { return ifexpr; } + Expr *getElse() const { return elsexpr; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(IfExpr, ASTVisitor, cond, ifexpr, elsexpr); - IfExpr *getIf() override { return this; } +private: + Expr *cond, *ifexpr, *elsexpr; }; /// Unary expression [op expr]. /// @li -56 -struct UnaryExpr : public Expr { - std::string op; - ExprPtr expr; +struct UnaryExpr : public AcceptorExtend { + UnaryExpr(std::string op = "", Expr *expr = nullptr); + UnaryExpr(const UnaryExpr &, bool); - UnaryExpr(std::string op, ExprPtr expr); - UnaryExpr(const UnaryExpr &expr); + std::string getOp() const { return op; } + Expr *getExpr() const { return expr; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(UnaryExpr, ASTVisitor, op, expr); - UnaryExpr *getUnary() override { return this; } +private: + std::string op; + Expr *expr; }; /// Binary expression [lexpr op rexpr]. /// @li 1 + 2 /// @li 3 or 4 -struct BinaryExpr : public Expr { - std::string op; - ExprPtr lexpr, rexpr; +struct BinaryExpr : public AcceptorExtend { + BinaryExpr(Expr *lexpr = nullptr, std::string op = "", Expr *rexpr = nullptr, + bool inPlace = false); + BinaryExpr(const BinaryExpr &, bool); - /// True if an expression modifies lhs in-place (e.g. a += b). - bool inPlace; + std::string getOp() const { return op; } + Expr *getLhs() const { return lexpr; } + Expr *getRhs() const { return rexpr; } + bool isInPlace() const { return inPlace; } - BinaryExpr(ExprPtr lexpr, std::string op, ExprPtr rexpr, bool inPlace = false); - BinaryExpr(const BinaryExpr &expr); + ACCEPT(BinaryExpr, ASTVisitor, op, lexpr, rexpr); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + std::string op; + Expr *lexpr, *rexpr; - BinaryExpr *getBinary() override { return this; } + /// True if an expression modifies lhs in-place (e.g. a += b). + bool inPlace; }; /// Chained binary expression. /// @li 1 <= x <= 2 -struct ChainBinaryExpr : public Expr { - std::vector> exprs; +struct ChainBinaryExpr : public AcceptorExtend { + ChainBinaryExpr(std::vector> exprs = {}); + ChainBinaryExpr(const ChainBinaryExpr &, bool); - ChainBinaryExpr(std::vector> exprs); - ChainBinaryExpr(const ChainBinaryExpr &expr); + ACCEPT(ChainBinaryExpr, ASTVisitor, exprs); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + std::vector> exprs; +}; + +struct Pipe { + std::string op; + Expr *expr; + + SERIALIZE(Pipe, op, expr); + Pipe clone(bool) const; }; /// Pipe expression [(op expr)...]. /// op is either "" (only the first item), "|>" or "||>". /// @li a |> b ||> c -struct PipeExpr : public Expr { - struct Pipe { - std::string op; - ExprPtr expr; +struct PipeExpr : public AcceptorExtend, Items { + explicit PipeExpr(std::vector items = {}); + PipeExpr(const PipeExpr &, bool); - Pipe clone() const; - }; + ACCEPT(PipeExpr, ASTVisitor, items); - std::vector items; +private: /// Output type of a "prefix" pipe ending at the index position. /// Example: for a |> b |> c, inTypes[1] is typeof(a |> b). std::vector inTypes; - - explicit PipeExpr(std::vector items); - PipeExpr(const PipeExpr &expr); - - std::string toString() const override; - void validate() const; - ACCEPT(ASTVisitor); }; /// Index expression (expr[index]). /// @li a[5] -struct IndexExpr : public Expr { - ExprPtr expr, index; +struct IndexExpr : public AcceptorExtend { + IndexExpr(Expr *expr = nullptr, Expr *index = nullptr); + IndexExpr(const IndexExpr &, bool); - IndexExpr(ExprPtr expr, ExprPtr index); - IndexExpr(const IndexExpr &expr); + Expr *getExpr() const { return expr; } + Expr *getIndex() const { return index; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(IndexExpr, ASTVisitor, expr, index); - IndexExpr *getIndex() override { return this; } +private: + Expr *expr, *index; }; -/// Call expression (expr((name=value)...)). -/// @li a(1, b=2) -struct CallExpr : public Expr { - /// Each argument can have a name (e.g. foo(1, b=5)) - struct Arg : public codon::SrcObject { - std::string name; - ExprPtr value; +struct CallArg : public codon::SrcObject { + std::string name; + Expr *value; - Arg clone() const; + CallArg(const std::string &name = "", Expr *value = nullptr); + CallArg(const SrcInfo &info, const std::string &name, Expr *value); + CallArg(Expr *value); - Arg(const SrcInfo &info, const std::string &name, ExprPtr value); - Arg(const std::string &name, ExprPtr value); - Arg(ExprPtr value); - }; + std::string getName() const { return name; } + Expr *getExpr() const { return value; } + operator Expr *() const { return value; } - ExprPtr expr; - std::vector args; - /// True if type-checker has processed and re-ordered args. - bool ordered; + SERIALIZE(CallArg, name, value); + CallArg clone(bool) const; +}; - CallExpr(ExprPtr expr, std::vector args = {}); +/// Call expression (expr((name=value)...)). +/// @li a(1, b=2) +struct CallExpr : public AcceptorExtend, Items { + /// Each argument can have a name (e.g. foo(1, b=5)) + CallExpr(Expr *expr = nullptr, std::vector args = {}); /// Convenience constructors - CallExpr(ExprPtr expr, std::vector args); + CallExpr(Expr *expr, std::vector args); template - CallExpr(ExprPtr expr, ExprPtr arg, Ts... args) - : CallExpr(expr, std::vector{arg, args...}) {} - CallExpr(const CallExpr &expr); + CallExpr(Expr *expr, Expr *arg, Ts... args) + : CallExpr(expr, std::vector{arg, args...}) {} + CallExpr(const CallExpr &, bool); - void validate() const; - std::string toString() const override; - ACCEPT(ASTVisitor); + Expr *getExpr() const { return expr; } + bool isOrdered() const { return ordered; } + bool isPartial() const { return partial; } - CallExpr *getCall() override { return this; } + ACCEPT(CallExpr, ASTVisitor, expr, items, ordered, partial); + +private: + Expr *expr; + /// True if type-checker has processed and re-ordered args. + bool ordered; + /// True if the call is partial + bool partial = false; }; /// Dot (access) expression (expr.member). /// @li a.b -struct DotExpr : public Expr { - ExprPtr expr; - std::string member; +struct DotExpr : public AcceptorExtend { + DotExpr() : expr(nullptr), member() {} + DotExpr(Expr *expr, std::string member); + DotExpr(const DotExpr &, bool); - DotExpr(ExprPtr expr, std::string member); - /// Convenience constructor. - DotExpr(const std::string &left, std::string member); - DotExpr(const DotExpr &expr); + Expr *getExpr() const { return expr; } + std::string getMember() const { return member; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(DotExpr, ASTVisitor, expr, member); - DotExpr *getDot() override { return this; } +private: + Expr *expr; + std::string member; }; /// Slice expression (st:stop:step). /// @li 1:10:3 /// @li s::-1 /// @li ::: -struct SliceExpr : public Expr { - /// Any of these can be nullptr to account for partial slices. - ExprPtr start, stop, step; +struct SliceExpr : public AcceptorExtend { + SliceExpr(Expr *start = nullptr, Expr *stop = nullptr, Expr *step = nullptr); + SliceExpr(const SliceExpr &, bool); + + Expr *getStart() const { return start; } + Expr *getStop() const { return stop; } + Expr *getStep() const { return step; } - SliceExpr(ExprPtr start, ExprPtr stop, ExprPtr step); - SliceExpr(const SliceExpr &expr); + ACCEPT(SliceExpr, ASTVisitor, start, stop, step); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + /// Any of these can be nullptr to account for partial slices. + Expr *start, *stop, *step; }; /// Ellipsis expression. /// @li ... -struct EllipsisExpr : public Expr { +struct EllipsisExpr : public AcceptorExtend { /// True if this is a target partial argument within a PipeExpr. /// If true, this node will be handled differently during the type-checking stage. - enum EllipsisType { PIPE, PARTIAL, STANDALONE } mode; + enum EllipsisType { PIPE, PARTIAL, STANDALONE }; explicit EllipsisExpr(EllipsisType mode = STANDALONE); - EllipsisExpr(const EllipsisExpr &expr) = default; + EllipsisExpr(const EllipsisExpr &, bool); + + EllipsisType getMode() const { return mode; } + bool isStandalone() const { return mode == STANDALONE; } + bool isPipe() const { return mode == PIPE; } + bool isPartial() const { return mode == PARTIAL; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(EllipsisExpr, ASTVisitor, mode); - EllipsisExpr *getEllipsis() override { return this; } +private: + EllipsisType mode; + + friend class PipeExpr; }; /// Lambda expression (lambda (vars)...: expr). /// @li lambda a, b: a + b -struct LambdaExpr : public Expr { - std::vector vars; - ExprPtr expr; +struct LambdaExpr : public AcceptorExtend, Items { + LambdaExpr(std::vector vars = {}, Expr *expr = nullptr); + LambdaExpr(const LambdaExpr &, bool); + + Expr *getExpr() const { return expr; } - LambdaExpr(std::vector vars, ExprPtr expr); - LambdaExpr(const LambdaExpr &); + ACCEPT(LambdaExpr, ASTVisitor, expr, items); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + Expr *expr; }; /// Yield (send to generator) expression. /// @li (yield) -struct YieldExpr : public Expr { +struct YieldExpr : public AcceptorExtend { YieldExpr(); - YieldExpr(const YieldExpr &expr) = default; + YieldExpr(const YieldExpr &, bool); - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(YieldExpr, ASTVisitor); }; /// Assignment (walrus) expression (var := expr). /// @li a := 5 + 3 -struct AssignExpr : public Expr { - ExprPtr var, expr; +struct AssignExpr : public AcceptorExtend { + AssignExpr(Expr *var = nullptr, Expr *expr = nullptr); + AssignExpr(const AssignExpr &, bool); - AssignExpr(ExprPtr var, ExprPtr expr); - AssignExpr(const AssignExpr &); + Expr *getVar() const { return var; } + Expr *getExpr() const { return expr; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(AssignExpr, ASTVisitor, var, expr); + +private: + Expr *var, *expr; }; /// Range expression (start ... end). /// Used only in match-case statements. /// @li 1 ... 2 -struct RangeExpr : public Expr { - ExprPtr start, stop; +struct RangeExpr : public AcceptorExtend { + RangeExpr(Expr *start = nullptr, Expr *stop = nullptr); + RangeExpr(const RangeExpr &, bool); + + Expr *getStart() const { return start; } + Expr *getStop() const { return stop; } - RangeExpr(ExprPtr start, ExprPtr stop); - RangeExpr(const RangeExpr &); + ACCEPT(RangeExpr, ASTVisitor, start, stop); - std::string toString() const override; - ACCEPT(ASTVisitor); +private: + Expr *start, *stop; }; -/// The following nodes are created after the simplify stage. +/// The following nodes are created during typechecking. /// Statement expression (stmts...; expr). /// Statements are evaluated only if the expression is evaluated /// (to support short-circuiting). /// @li (a = 1; b = 2; a + b) -struct StmtExpr : public Expr { - std::vector> stmts; - ExprPtr expr; +struct StmtExpr : public AcceptorExtend, Items { + StmtExpr(Stmt *stmt = nullptr, Expr *expr = nullptr); + StmtExpr(std::vector stmts, Expr *expr); + StmtExpr(Stmt *stmt, Stmt *stmt2, Expr *expr); + StmtExpr(const StmtExpr &, bool); - StmtExpr(std::vector> stmts, ExprPtr expr); - StmtExpr(std::shared_ptr stmt, ExprPtr expr); - StmtExpr(std::shared_ptr stmt, std::shared_ptr stmt2, ExprPtr expr); - StmtExpr(const StmtExpr &expr); + Expr *getExpr() const { return expr; } - std::string toString() const override; - ACCEPT(ASTVisitor); + ACCEPT(StmtExpr, ASTVisitor, expr, items); - StmtExpr *getStmtExpr() override { return this; } +private: + Expr *expr; }; /// Static tuple indexing expression (expr[index]). /// @li (1, 2, 3)[2] -struct InstantiateExpr : Expr { - ExprPtr typeExpr; - std::vector typeParams; - - InstantiateExpr(ExprPtr typeExpr, std::vector typeParams); +struct InstantiateExpr : public AcceptorExtend, Items { + InstantiateExpr(Expr *expr = nullptr, std::vector typeParams = {}); /// Convenience constructor for a single type parameter. - InstantiateExpr(ExprPtr typeExpr, ExprPtr typeParam); - InstantiateExpr(const InstantiateExpr &expr); + InstantiateExpr(Expr *expr, Expr *typeParam); + InstantiateExpr(const InstantiateExpr &, bool); - std::string toString() const override; - ACCEPT(ASTVisitor); -}; + Expr *getExpr() const { return expr; } -#undef ACCEPT + ACCEPT(InstantiateExpr, ASTVisitor, expr, items); -enum ExprAttr { - SequenceItem, - StarSequenceItem, - List, - Set, - Dict, - Partial, - Dominated, - StarArgument, - KwStarArgument, - OrderedCall, - ExternVar, - __LAST__ +private: + Expr *expr; }; -StaticValue::Type getStaticGeneric(Expr *e); +#undef ACCEPT + +bool isId(Expr *e, const std::string &s); +char getStaticGeneric(Expr *e); } // namespace codon::ast -template -struct fmt::formatter< - T, std::enable_if_t::value, char>> - : fmt::ostream_formatter {}; - template <> -struct fmt::formatter : fmt::formatter { +struct fmt::formatter : fmt::formatter { template - auto format(const codon::ast::CallExpr::Arg &p, FormatContext &ctx) const + auto format(const codon::ast::CallArg &p, FormatContext &ctx) const -> decltype(ctx.out()) { return fmt::format_to(ctx.out(), "({}{})", - p.name.empty() ? "" : fmt::format("{} = ", p.name), p.value); + p.name.empty() ? "" : fmt::format("{} = ", p.name), + p.value ? p.value->toString(0) : "-"); } }; @@ -695,17 +653,30 @@ struct fmt::formatter : fmt::formatter { template auto format(const codon::ast::Param &p, FormatContext &ctx) const -> decltype(ctx.out()) { - return fmt::format_to(ctx.out(), "{}", p.toString()); + return fmt::format_to(ctx.out(), "{}", p.toString(0)); } }; -template -struct fmt::formatter< - T, std::enable_if_t< - std::is_convertible>::value, char>> - : fmt::formatter { - template - auto format(const T &p, FormatContext &ctx) const -> decltype(ctx.out()) { - return fmt::format_to(ctx.out(), "{}", p ? p->toString() : ""); +namespace tser { +using Archive = BinaryArchive; +static void operator<<(codon::ast::Expr *t, Archive &a) { + using S = codon::PolymorphicSerializer; + a.save(t != nullptr); + if (t) { + auto typ = t->dynamicNodeId(); + auto key = S::_serializers[(void *)typ]; + a.save(key); + S::save(key, t, a); } -}; +} +static void operator>>(codon::ast::Expr *&t, Archive &a) { + using S = codon::PolymorphicSerializer; + bool empty = a.load(); + if (!empty) { + std::string key = a.load(); + S::load(key, t, a); + } else { + t = nullptr; + } +} +} // namespace tser diff --git a/codon/parser/ast/node.h b/codon/parser/ast/node.h new file mode 100644 index 00000000..a02279ca --- /dev/null +++ b/codon/parser/ast/node.h @@ -0,0 +1,115 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#pragma once + +#include +#include + +#include "codon/cir/base.h" + +namespace codon::ast { + +using ir::cast; + +// Forward declarations +struct Cache; +struct ASTVisitor; + +struct ASTNode : public ir::Node { + static const char NodeId; + using ir::Node::Node; + + /// See LLVM documentation. + static const void *nodeId() { return &NodeId; } + const void *dynamicNodeId() const override { return &NodeId; } + /// See LLVM documentation. + virtual bool isConvertible(const void *other) const override { + return other == nodeId() || ir::Node::isConvertible(other); + } + + Cache *cache; + + ASTNode() = default; + ASTNode(const ASTNode &); + virtual ~ASTNode() = default; + + /// Convert a node to an S-expression. + virtual std::string toString(int) const = 0; + virtual std::string toString() const { return toString(-1); } + + /// Deep copy a node. + virtual ASTNode *clone(bool clean) const = 0; + ASTNode *clone() const { return clone(false); } + + /// Accept an AST visitor. + virtual void accept(ASTVisitor &visitor) = 0; + + /// Allow pretty-printing to C++ streams. + friend std::ostream &operator<<(std::ostream &out, const ASTNode &expr) { + return out << expr.toString(); + } + + void setAttribute(const std::string &key, std::unique_ptr value) { + attributes[key] = std::move(value); + } + void setAttribute(const std::string &key, const std::string &value) { + attributes[key] = std::make_unique(value); + } + void setAttribute(const std::string &key, int64_t value) { + attributes[key] = std::make_unique(value); + } + void setAttribute(const std::string &key) { + attributes[key] = std::make_unique(); + } + + inline decltype(auto) members() { + int a = 0; + return std::tie(a); + } +}; + +template void E(error::Error e, ASTNode *o, const TA &...args) { + E(e, o->getSrcInfo(), args...); +} +template void E(error::Error e, const ASTNode &o, const TA &...args) { + E(e, o.getSrcInfo(), args...); +} + +template class AcceptorExtend : public Parent { +public: + using Parent::Parent; + + /// See LLVM documentation. + static const void *nodeId() { return &Derived::NodeId; } + const void *dynamicNodeId() const override { return &Derived::NodeId; } + /// See LLVM documentation. + virtual bool isConvertible(const void *other) const override { + return other == nodeId() || Parent::isConvertible(other); + } +}; + +template struct Items { + Items(std::vector items) : items(std::move(items)) {} + const T &operator[](int i) const { return items[i]; } + T &operator[](int i) { return items[i]; } + auto begin() { return items.begin(); } + auto end() { return items.end(); } + auto begin() const { return items.begin(); } + auto end() const { return items.end(); } + auto size() const { return items.size(); } + bool empty() const { return items.empty(); } + const T &front() const { return items.front(); } + const T &back() const { return items.back(); } + T &front() { return items.front(); } + T &back() { return items.back(); } + +protected: + std::vector items; +}; + +} // namespace codon::ast + +template +struct fmt::formatter< + T, std::enable_if_t::value, char>> + : fmt::ostream_formatter {}; diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index b2490c19..5bc980ea 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -11,452 +11,356 @@ #include "codon/parser/visitors/visitor.h" #define ACCEPT_IMPL(T, X) \ - StmtPtr T::clone() const { return std::make_shared(*this); } \ - void T::accept(X &visitor) { visitor.visit(this); } + ASTNode *T::clone(bool clean) const { return cache->N(*this, clean); } \ + void T::accept(X &visitor) { visitor.visit(this); } \ + const char T::NodeId = 0; using fmt::format; using namespace codon::error; -const int INDENT_SIZE = 2; - namespace codon::ast { -Stmt::Stmt() : done(false), age(-1) {} -Stmt::Stmt(const codon::SrcInfo &s) : done(false), age(-1) { setSrcInfo(s); } -std::string Stmt::toString() const { return toString(-1); } -void Stmt::validate() const {} - -SuiteStmt::SuiteStmt(std::vector stmts) : Stmt() { - for (auto &s : stmts) - flatten(std::move(s), this->stmts); -} -SuiteStmt::SuiteStmt(const SuiteStmt &stmt) - : Stmt(stmt), stmts(ast::clone(stmt.stmts)) {} +Stmt::Stmt() : AcceptorExtend(), done(false) {} +Stmt::Stmt(const Stmt &stmt) : AcceptorExtend(stmt), done(stmt.done) {} +Stmt::Stmt(const codon::SrcInfo &s) : AcceptorExtend() { setSrcInfo(s); } +Stmt::Stmt(const Stmt &expr, bool clean) : AcceptorExtend(expr) { + if (clean) + done = false; +} +std::string Stmt::wrapStmt(const std::string &s) const { + // if (auto a = ir::Node::getAttribute(Attr::ExprTime)) + // return format("(${}...{}", + // a->value, + // s.substr(1)); + return s; +} + +SuiteStmt::SuiteStmt(std::vector stmts) + : AcceptorExtend(), Items(std::move(stmts)) {} +SuiteStmt::SuiteStmt(const SuiteStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)) {} std::string SuiteStmt::toString(int indent) const { + if (indent == -1) + return ""; std::string pad = indent >= 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string s; - for (int i = 0; i < stmts.size(); i++) - if (stmts[i]) { - auto is = stmts[i]->toString(indent >= 0 ? indent + INDENT_SIZE : -1); - if (stmts[i]->done) + for (int i = 0; i < size(); i++) + if (items[i]) { + auto is = items[i]->toString(indent >= 0 ? indent + INDENT_SIZE : -1); + if (items[i]->isDone()) is.insert(findStar(is), "*"); s += (i ? pad : "") + is; } - return format("(suite{})", s.empty() ? s : " " + pad + s); -} -ACCEPT_IMPL(SuiteStmt, ASTVisitor); -void SuiteStmt::flatten(const StmtPtr &s, std::vector &stmts) { - if (!s) - return; - if (!s->getSuite()) { - stmts.push_back(s); - } else { - for (auto &ss : s->getSuite()->stmts) - stmts.push_back(ss); - } -} -StmtPtr *SuiteStmt::lastInBlock() { - if (stmts.empty()) - return nullptr; - if (auto s = stmts.back()->getSuite()) { - auto l = s->lastInBlock(); - if (l) - return l; + return wrapStmt( + format("({}suite{})", (isDone() ? "*" : ""), (s.empty() ? s : " " + pad + s))); +} +void SuiteStmt::flatten() { + std::vector ns; + for (auto &s : items) { + if (!s) + continue; + if (!cast(s)) { + ns.push_back(s); + } else { + for (auto *ss : *cast(s)) + ns.push_back(ss); + } } - return &(stmts.back()); -} - -std::string BreakStmt::toString(int) const { return "(break)"; } -ACCEPT_IMPL(BreakStmt, ASTVisitor); - -std::string ContinueStmt::toString(int) const { return "(continue)"; } -ACCEPT_IMPL(ContinueStmt, ASTVisitor); - -ExprStmt::ExprStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} -ExprStmt::ExprStmt(const ExprStmt &stmt) : Stmt(stmt), expr(ast::clone(stmt.expr)) {} -std::string ExprStmt::toString(int) const { - return format("(expr {})", expr->toString()); -} -ACCEPT_IMPL(ExprStmt, ASTVisitor); - -AssignStmt::AssignStmt(ExprPtr lhs, ExprPtr rhs, ExprPtr type) - : Stmt(), lhs(std::move(lhs)), rhs(std::move(rhs)), type(std::move(type)), - update(Assign) {} -AssignStmt::AssignStmt(const AssignStmt &stmt) - : Stmt(stmt), lhs(ast::clone(stmt.lhs)), rhs(ast::clone(stmt.rhs)), - type(ast::clone(stmt.type)), update(stmt.update) {} -std::string AssignStmt::toString(int) const { - return format("({} {}{}{})", update != Assign ? "update" : "assign", lhs->toString(), - rhs ? " " + rhs->toString() : "", - type ? format(" #:type {}", type->toString()) : ""); -} -ACCEPT_IMPL(AssignStmt, ASTVisitor); - -DelStmt::DelStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} -DelStmt::DelStmt(const DelStmt &stmt) : Stmt(stmt), expr(ast::clone(stmt.expr)) {} -std::string DelStmt::toString(int) const { - return format("(del {})", expr->toString()); -} -ACCEPT_IMPL(DelStmt, ASTVisitor); - -PrintStmt::PrintStmt(std::vector items, bool isInline) - : Stmt(), items(std::move(items)), isInline(isInline) {} -PrintStmt::PrintStmt(const PrintStmt &stmt) - : Stmt(stmt), items(ast::clone(stmt.items)), isInline(stmt.isInline) {} -std::string PrintStmt::toString(int) const { - return format("(print {}{})", isInline ? "#:inline " : "", combine(items)); -} -ACCEPT_IMPL(PrintStmt, ASTVisitor); - -ReturnStmt::ReturnStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} -ReturnStmt::ReturnStmt(const ReturnStmt &stmt) - : Stmt(stmt), expr(ast::clone(stmt.expr)) {} -std::string ReturnStmt::toString(int) const { - return expr ? format("(return {})", expr->toString()) : "(return)"; -} -ACCEPT_IMPL(ReturnStmt, ASTVisitor); - -YieldStmt::YieldStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} -YieldStmt::YieldStmt(const YieldStmt &stmt) : Stmt(stmt), expr(ast::clone(stmt.expr)) {} -std::string YieldStmt::toString(int) const { - return expr ? format("(yield {})", expr->toString()) : "(yield)"; -} -ACCEPT_IMPL(YieldStmt, ASTVisitor); - -AssertStmt::AssertStmt(ExprPtr expr, ExprPtr message) - : Stmt(), expr(std::move(expr)), message(std::move(message)) {} -AssertStmt::AssertStmt(const AssertStmt &stmt) - : Stmt(stmt), expr(ast::clone(stmt.expr)), message(ast::clone(stmt.message)) {} -std::string AssertStmt::toString(int) const { - return format("(assert {}{})", expr->toString(), message ? message->toString() : ""); -} -ACCEPT_IMPL(AssertStmt, ASTVisitor); - -WhileStmt::WhileStmt(ExprPtr cond, StmtPtr suite, StmtPtr elseSuite) - : Stmt(), cond(std::move(cond)), suite(std::move(suite)), - elseSuite(std::move(elseSuite)) {} -WhileStmt::WhileStmt(const WhileStmt &stmt) - : Stmt(stmt), cond(ast::clone(stmt.cond)), suite(ast::clone(stmt.suite)), - elseSuite(ast::clone(stmt.elseSuite)) {} + items = ns; +} +void SuiteStmt::addStmt(Stmt *s) { + if (s) + items.push_back(s); +} +SuiteStmt *SuiteStmt::wrap(Stmt *s) { + if (s && !cast(s)) + return s->cache->NS(s, s); + return (SuiteStmt *)s; +} + +BreakStmt::BreakStmt(const BreakStmt &stmt, bool clean) : AcceptorExtend(stmt, clean) {} +std::string BreakStmt::toString(int indent) const { return wrapStmt("(break)"); } + +ContinueStmt::ContinueStmt(const ContinueStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean) {} +std::string ContinueStmt::toString(int indent) const { return wrapStmt("(continue)"); } + +ExprStmt::ExprStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} +ExprStmt::ExprStmt(const ExprStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} +std::string ExprStmt::toString(int indent) const { + return wrapStmt(format("(expr {})", expr->toString(indent))); +} + +AssignStmt::AssignStmt(Expr *lhs, Expr *rhs, Expr *type, UpdateMode update) + : AcceptorExtend(), lhs(lhs), rhs(rhs), type(type), update(update) {} +AssignStmt::AssignStmt(const AssignStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), lhs(ast::clone(stmt.lhs, clean)), + rhs(ast::clone(stmt.rhs, clean)), type(ast::clone(stmt.type, clean)), + update(stmt.update) {} +std::string AssignStmt::toString(int indent) const { + return wrapStmt(format("({} {}{}{})", update != Assign ? "update" : "assign", + lhs->toString(indent), rhs ? " " + rhs->toString(indent) : "", + type ? format(" #:type {}", type->toString(indent)) : "")); +} + +DelStmt::DelStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} +DelStmt::DelStmt(const DelStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} +std::string DelStmt::toString(int indent) const { + return wrapStmt(format("(del {})", expr->toString(indent))); +} + +PrintStmt::PrintStmt(std::vector items, bool noNewline) + : AcceptorExtend(), Items(std::move(items)), noNewline(noNewline) {} +PrintStmt::PrintStmt(const PrintStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + noNewline(stmt.noNewline) {} +std::string PrintStmt::toString(int indent) const { + return wrapStmt(format("(print {}{})", noNewline ? "#:inline " : "", combine(items))); +} + +ReturnStmt::ReturnStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} +ReturnStmt::ReturnStmt(const ReturnStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} +std::string ReturnStmt::toString(int indent) const { + return wrapStmt(expr ? format("(return {})", expr->toString(indent)) : "(return)"); +} + +YieldStmt::YieldStmt(Expr *expr) : AcceptorExtend(), expr(expr) {} +YieldStmt::YieldStmt(const YieldStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} +std::string YieldStmt::toString(int indent) const { + return wrapStmt(expr ? format("(yield {})", expr->toString(indent)) : "(yield)"); +} + +AssertStmt::AssertStmt(Expr *expr, Expr *message) + : AcceptorExtend(), expr(expr), message(message) {} +AssertStmt::AssertStmt(const AssertStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)), + message(ast::clone(stmt.message, clean)) {} +std::string AssertStmt::toString(int indent) const { + return wrapStmt(format("(assert {}{})", expr->toString(indent), + message ? message->toString(indent) : "")); +} + +WhileStmt::WhileStmt(Expr *cond, Stmt *suite, Stmt *elseSuite) + : AcceptorExtend(), cond(cond), suite(SuiteStmt::wrap(suite)), + elseSuite(SuiteStmt::wrap(elseSuite)) {} +WhileStmt::WhileStmt(const WhileStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), cond(ast::clone(stmt.cond, clean)), + suite(ast::clone(stmt.suite, clean)), + elseSuite(ast::clone(stmt.elseSuite, clean)) {} std::string WhileStmt::toString(int indent) const { + if (indent == -1) + return wrapStmt(format("(while {})", cond->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - if (elseSuite && elseSuite->firstInBlock()) - return format("(while-else {}{}{}{}{})", cond->toString(), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, - elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); - else - return format("(while {}{}{})", cond->toString(), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); + if (elseSuite && elseSuite->firstInBlock()) { + return wrapStmt( + format("(while-else {}{}{}{}{})", cond->toString(indent), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } else { + return wrapStmt(format("(while {}{}{})", cond->toString(indent), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } } -ACCEPT_IMPL(WhileStmt, ASTVisitor); -ForStmt::ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite, - ExprPtr decorator, std::vector ompArgs) - : Stmt(), var(std::move(var)), iter(std::move(iter)), suite(std::move(suite)), - elseSuite(std::move(elseSuite)), decorator(std::move(decorator)), +ForStmt::ForStmt(Expr *var, Expr *iter, Stmt *suite, Stmt *elseSuite, Expr *decorator, + std::vector ompArgs) + : AcceptorExtend(), var(var), iter(iter), suite(SuiteStmt::wrap(suite)), + elseSuite(SuiteStmt::wrap(elseSuite)), decorator(decorator), ompArgs(std::move(ompArgs)), wrapped(false), flat(false) {} -ForStmt::ForStmt(const ForStmt &stmt) - : Stmt(stmt), var(ast::clone(stmt.var)), iter(ast::clone(stmt.iter)), - suite(ast::clone(stmt.suite)), elseSuite(ast::clone(stmt.elseSuite)), - decorator(ast::clone(stmt.decorator)), ompArgs(ast::clone_nop(stmt.ompArgs)), - wrapped(stmt.wrapped), flat(stmt.flat) {} +ForStmt::ForStmt(const ForStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), var(ast::clone(stmt.var, clean)), + iter(ast::clone(stmt.iter, clean)), suite(ast::clone(stmt.suite, clean)), + elseSuite(ast::clone(stmt.elseSuite, clean)), + decorator(ast::clone(stmt.decorator, clean)), + ompArgs(ast::clone(stmt.ompArgs, clean)), wrapped(stmt.wrapped), flat(stmt.flat) { +} std::string ForStmt::toString(int indent) const { + auto vs = var->toString(indent); + if (indent == -1) + return wrapStmt(format("(for {} {})", vs, iter->toString(indent))); + std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string attr; if (decorator) - attr += " " + decorator->toString(); + attr += " " + decorator->toString(indent); if (!attr.empty()) attr = " #:attr" + attr; - if (elseSuite && elseSuite->firstInBlock()) - return format("(for-else {} {}{}{}{}{}{})", var->toString(), iter->toString(), attr, - pad, suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, - elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); - else - return format("(for {} {}{}{}{})", var->toString(), iter->toString(), attr, pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); + if (elseSuite && elseSuite->firstInBlock()) { + return wrapStmt( + format("(for-else {} {}{}{}{}{}{})", vs, iter->toString(indent), attr, pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } else { + return wrapStmt(format("(for {} {}{}{}{})", vs, iter->toString(indent), attr, pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); + } } -ACCEPT_IMPL(ForStmt, ASTVisitor); -IfStmt::IfStmt(ExprPtr cond, StmtPtr ifSuite, StmtPtr elseSuite) - : Stmt(), cond(std::move(cond)), ifSuite(std::move(ifSuite)), - elseSuite(std::move(elseSuite)) {} -IfStmt::IfStmt(const IfStmt &stmt) - : Stmt(stmt), cond(ast::clone(stmt.cond)), ifSuite(ast::clone(stmt.ifSuite)), - elseSuite(ast::clone(stmt.elseSuite)) {} +IfStmt::IfStmt(Expr *cond, Stmt *ifSuite, Stmt *elseSuite) + : AcceptorExtend(), cond(cond), ifSuite(SuiteStmt::wrap(ifSuite)), + elseSuite(SuiteStmt::wrap(elseSuite)) {} +IfStmt::IfStmt(const IfStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), cond(ast::clone(stmt.cond, clean)), + ifSuite(ast::clone(stmt.ifSuite, clean)), + elseSuite(ast::clone(stmt.elseSuite, clean)) {} std::string IfStmt::toString(int indent) const { + if (indent == -1) + return wrapStmt(format("(if {})", cond->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - return format("(if {}{}{}{})", cond->toString(), pad, - ifSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), - elseSuite - ? pad + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : ""); + return wrapStmt(format( + "(if {}{}{}{})", cond->toString(indent), pad, + ifSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), + elseSuite ? pad + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) + : "")); } -ACCEPT_IMPL(IfStmt, ASTVisitor); -MatchStmt::MatchCase MatchStmt::MatchCase::clone() const { - return {ast::clone(pattern), ast::clone(guard), ast::clone(suite)}; +MatchCase::MatchCase(Expr *pattern, Expr *guard, Stmt *suite) + : pattern(pattern), guard(guard), suite(SuiteStmt::wrap(suite)) {} +MatchCase MatchCase::clone(bool clean) const { + return {ast::clone(pattern, clean), ast::clone(guard, clean), + ast::clone(suite, clean)}; } -MatchStmt::MatchStmt(ExprPtr what, std::vector cases) - : Stmt(), what(std::move(what)), cases(std::move(cases)) {} -MatchStmt::MatchStmt(const MatchStmt &stmt) - : Stmt(stmt), what(ast::clone(stmt.what)), cases(ast::clone_nop(stmt.cases)) {} +MatchStmt::MatchStmt(Expr *expr, std::vector cases) + : AcceptorExtend(), Items(std::move(cases)), expr(expr) {} +MatchStmt::MatchStmt(const MatchStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + expr(ast::clone(stmt.expr, clean)) {} std::string MatchStmt::toString(int indent) const { + if (indent == -1) + return wrapStmt(format("(match {})", expr->toString(indent))); std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : ""; std::vector s; - for (auto &c : cases) - s.push_back(format("(case {}{}{}{})", c.pattern->toString(), - c.guard ? " #:guard " + c.guard->toString() : "", pad + padExtra, + for (auto &c : items) + s.push_back(format("(case {}{}{}{})", c.pattern->toString(indent), + c.guard ? " #:guard " + c.guard->toString(indent) : "", + pad + padExtra, c.suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2))); - return format("(match {}{}{})", what->toString(), pad, join(s, pad)); + return wrapStmt(format("(match {}{}{})", expr->toString(indent), pad, join(s, pad))); } -ACCEPT_IMPL(MatchStmt, ASTVisitor); -ImportStmt::ImportStmt(ExprPtr from, ExprPtr what, std::vector args, ExprPtr ret, +ImportStmt::ImportStmt(Expr *from, Expr *what, std::vector args, Expr *ret, std::string as, size_t dots, bool isFunction) - : Stmt(), from(std::move(from)), what(std::move(what)), as(std::move(as)), - dots(dots), args(std::move(args)), ret(std::move(ret)), isFunction(isFunction) { - validate(); -} -ImportStmt::ImportStmt(const ImportStmt &stmt) - : Stmt(stmt), from(ast::clone(stmt.from)), what(ast::clone(stmt.what)), as(stmt.as), - dots(stmt.dots), args(ast::clone_nop(stmt.args)), ret(ast::clone(stmt.ret)), + : AcceptorExtend(), from(from), what(what), as(std::move(as)), dots(dots), + args(std::move(args)), ret(ret), isFunction(isFunction) {} +ImportStmt::ImportStmt(const ImportStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), from(ast::clone(stmt.from, clean)), + what(ast::clone(stmt.what, clean)), as(stmt.as), dots(stmt.dots), + args(ast::clone(stmt.args, clean)), ret(ast::clone(stmt.ret, clean)), isFunction(stmt.isFunction) {} -std::string ImportStmt::toString(int) const { +std::string ImportStmt::toString(int indent) const { std::vector va; for (auto &a : args) - va.push_back(a.toString()); - return format("(import {}{}{}{}{}{})", from->toString(), - as.empty() ? "" : format(" #:as '{}", as), - what ? format(" #:what {}", what->toString()) : "", - dots ? format(" #:dots {}", dots) : "", - va.empty() ? "" : format(" #:args ({})", join(va)), - ret ? format(" #:ret {}", ret->toString()) : ""); -} -void ImportStmt::validate() const { - if (from) { - Expr *e = from.get(); - while (auto d = e->getDot()) - e = d->expr.get(); - if (!from->isId("C") && !from->isId("python")) { - if (!e->getId()) - E(Error::IMPORT_IDENTIFIER, e); - if (!args.empty()) - E(Error::IMPORT_FN, args[0]); - if (ret) - E(Error::IMPORT_FN, ret); - if (what && !what->getId()) - E(Error::IMPORT_IDENTIFIER, what); - } - if (!isFunction && !args.empty()) - E(Error::IMPORT_FN, args[0]); - } -} -ACCEPT_IMPL(ImportStmt, ASTVisitor); - -TryStmt::Catch TryStmt::Catch::clone() const { - return {var, ast::clone(exc), ast::clone(suite)}; -} - -TryStmt::TryStmt(StmtPtr suite, std::vector catches, StmtPtr finally) - : Stmt(), suite(std::move(suite)), catches(std::move(catches)), - finally(std::move(finally)) {} -TryStmt::TryStmt(const TryStmt &stmt) - : Stmt(stmt), suite(ast::clone(stmt.suite)), catches(ast::clone_nop(stmt.catches)), - finally(ast::clone(stmt.finally)) {} -std::string TryStmt::toString(int indent) const { + va.push_back(a.toString(indent)); + return wrapStmt(format("(import {}{}{}{}{}{})", from ? from->toString(indent) : "", + as.empty() ? "" : format(" #:as '{}", as), + what ? format(" #:what {}", what->toString(indent)) : "", + dots ? format(" #:dots {}", dots) : "", + va.empty() ? "" : format(" #:args ({})", join(va)), + ret ? format(" #:ret {}", ret->toString(indent)) : "")); +} + +ExceptStmt::ExceptStmt(const std::string &var, Expr *exc, Stmt *suite) + : var(var), exc(exc), suite(SuiteStmt::wrap(suite)) {} +ExceptStmt::ExceptStmt(const ExceptStmt &stmt, bool clean) + : AcceptorExtend(stmt), var(stmt.var), exc(ast::clone(stmt.exc, clean)), + suite(ast::clone(stmt.suite, clean)) {} +std::string ExceptStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string padExtra = indent > 0 ? std::string(INDENT_SIZE, ' ') : ""; + return wrapStmt( + format("(catch {}{}{}{})", !var.empty() ? format("#:var '{}", var) : "", + exc ? format(" #:exc {}", exc->toString(indent)) : "", pad + padExtra, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2))); +} + +TryStmt::TryStmt(Stmt *suite, std::vector excepts, Stmt *elseSuite, + Stmt *finally) + : AcceptorExtend(), Items(std::move(excepts)), suite(SuiteStmt::wrap(suite)), + elseSuite(SuiteStmt::wrap(elseSuite)), finally(SuiteStmt::wrap(finally)) {} +TryStmt::TryStmt(const TryStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + suite(ast::clone(stmt.suite, clean)), + elseSuite(ast::clone(stmt.elseSuite, clean)), + finally(ast::clone(stmt.finally, clean)) {} +std::string TryStmt::toString(int indent) const { + if (indent == -1) + return wrapStmt(format("(try)")); + std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::vector s; - for (auto &i : catches) - s.push_back( - format("(catch {}{}{}{})", !i.var.empty() ? format("#:var '{}", i.var) : "", - i.exc ? format(" #:exc {}", i.exc->toString()) : "", pad + padExtra, - i.suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1 * 2))); - return format( + for (auto &i : items) + s.push_back(i->toString(indent)); + return wrapStmt(format( "(try{}{}{}{}{})", pad, suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1), pad, join(s, pad), - finally ? format("{}{}", pad, + elseSuite ? format("{}(else {})", pad, + elseSuite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)) + : "", + finally ? format("{}(finally {})", pad, finally->toString(indent >= 0 ? indent + INDENT_SIZE : -1)) - : ""); + : "")); } -ACCEPT_IMPL(TryStmt, ASTVisitor); -ThrowStmt::ThrowStmt(ExprPtr expr, bool transformed) - : Stmt(), expr(std::move(expr)), transformed(transformed) {} -ThrowStmt::ThrowStmt(const ThrowStmt &stmt) - : Stmt(stmt), expr(ast::clone(stmt.expr)), transformed(stmt.transformed) {} -std::string ThrowStmt::toString(int) const { - return format("(throw{})", expr ? " " + expr->toString() : ""); +ThrowStmt::ThrowStmt(Expr *expr, Expr *from, bool transformed) + : AcceptorExtend(), expr(expr), from(from), transformed(transformed) {} +ThrowStmt::ThrowStmt(const ThrowStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)), + from(ast::clone(stmt.from, clean)), transformed(stmt.transformed) {} +std::string ThrowStmt::toString(int indent) const { + return wrapStmt(format("(throw{}{})", expr ? " " + expr->toString(indent) : "", + from ? format(" :from {}", from->toString(indent)) : "")); } -ACCEPT_IMPL(ThrowStmt, ASTVisitor); GlobalStmt::GlobalStmt(std::string var, bool nonLocal) - : Stmt(), var(std::move(var)), nonLocal(nonLocal) {} -std::string GlobalStmt::toString(int) const { - return format("({} '{})", nonLocal ? "nonlocal" : "global", var); -} -ACCEPT_IMPL(GlobalStmt, ASTVisitor); - -Attr::Attr(const std::vector &attrs) - : module(), parentClass(), isAttribute(false) { - for (auto &a : attrs) - set(a); -} -void Attr::set(const std::string &attr) { customAttr.insert(attr); } -void Attr::unset(const std::string &attr) { customAttr.erase(attr); } -bool Attr::has(const std::string &attr) const { return in(customAttr, attr); } - -const std::string Attr::LLVM = "llvm"; -const std::string Attr::Python = "python"; -const std::string Attr::Atomic = "atomic"; -const std::string Attr::Property = "property"; -const std::string Attr::StaticMethod = "staticmethod"; -const std::string Attr::Attribute = "__attribute__"; -const std::string Attr::Internal = "__internal__"; -const std::string Attr::ForceRealize = "__force__"; -const std::string Attr::RealizeWithoutSelf = - "std.internal.attributes.realize_without_self"; -const std::string Attr::HiddenFromUser = "__hidden__"; -const std::string Attr::C = "C"; -const std::string Attr::CVarArg = ".__vararg__"; -const std::string Attr::Method = ".__method__"; -const std::string Attr::Capture = ".__capture__"; -const std::string Attr::HasSelf = ".__hasself__"; -const std::string Attr::IsGenerator = ".__generator__"; -const std::string Attr::Extend = "extend"; -const std::string Attr::Tuple = "tuple"; -const std::string Attr::Test = "std.internal.attributes.test"; -const std::string Attr::Overload = "overload"; -const std::string Attr::Export = "std.internal.attributes.export"; - -FunctionStmt::FunctionStmt(std::string name, ExprPtr ret, std::vector args, - StmtPtr suite, Attr attributes, - std::vector decorators) - : Stmt(), name(std::move(name)), ret(std::move(ret)), args(std::move(args)), - suite(std::move(suite)), attributes(std::move(attributes)), - decorators(std::move(decorators)) { - parseDecorators(); -} -FunctionStmt::FunctionStmt(const FunctionStmt &stmt) - : Stmt(stmt), name(stmt.name), ret(ast::clone(stmt.ret)), - args(ast::clone_nop(stmt.args)), suite(ast::clone(stmt.suite)), - attributes(stmt.attributes), decorators(ast::clone(stmt.decorators)) {} + : AcceptorExtend(), var(std::move(var)), nonLocal(nonLocal) {} +GlobalStmt::GlobalStmt(const GlobalStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), var(stmt.var), nonLocal(stmt.nonLocal) {} +std::string GlobalStmt::toString(int indent) const { + return wrapStmt(format("({} '{})", nonLocal ? "nonlocal" : "global", var)); +} + +FunctionStmt::FunctionStmt(std::string name, Expr *ret, std::vector args, + Stmt *suite, std::vector decorators) + : AcceptorExtend(), Items(std::move(args)), name(std::move(name)), ret(ret), + suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)) {} +FunctionStmt::FunctionStmt(const FunctionStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + name(stmt.name), ret(ast::clone(stmt.ret, clean)), + suite(ast::clone(stmt.suite, clean)), + decorators(ast::clone(stmt.decorators, clean)) {} std::string FunctionStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::vector as; - for (auto &a : args) - as.push_back(a.toString()); - std::vector dec, attr; + for (auto &a : items) + as.push_back(a.toString(indent)); + std::vector dec; for (auto &a : decorators) if (a) - dec.push_back(format("(dec {})", a->toString())); - for (auto &a : attributes.customAttr) - attr.push_back(format("'{}'", a)); - return format("(fn '{} ({}){}{}{}{}{})", name, join(as, " "), - ret ? " #:ret " + ret->toString() : "", - dec.empty() ? "" : format(" (dec {})", join(dec, " ")), - attr.empty() ? "" : format(" (attr {})", join(attr, " ")), pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : "(suite)"); -} -void FunctionStmt::validate() const { - if (!ret && (attributes.has(Attr::LLVM) || attributes.has(Attr::C))) - E(Error::FN_LLVM, getSrcInfo()); - - std::unordered_set seenArgs; - bool defaultsStarted = false, hasStarArg = false, hasKwArg = false; - for (size_t ia = 0; ia < args.size(); ia++) { - auto &a = args[ia]; - auto n = a.name; - int stars = trimStars(n); - if (stars == 2) { - if (hasKwArg) - E(Error::FN_MULTIPLE_ARGS, a); - if (a.defaultValue) - E(Error::FN_DEFAULT_STARARG, a.defaultValue); - if (ia != args.size() - 1) - E(Error::FN_LAST_KWARG, a); - hasKwArg = true; - } else if (stars == 1) { - if (hasStarArg) - E(Error::FN_MULTIPLE_ARGS, a); - if (a.defaultValue) - E(Error::FN_DEFAULT_STARARG, a.defaultValue); - hasStarArg = true; - } - if (in(seenArgs, n)) - E(Error::FN_ARG_TWICE, a, n); - seenArgs.insert(n); - if (!a.defaultValue && defaultsStarted && !stars && a.status == Param::Normal) - E(Error::FN_DEFAULT, a, n); - defaultsStarted |= bool(a.defaultValue); - if (attributes.has(Attr::C)) { - if (a.defaultValue) - E(Error::FN_C_DEFAULT, a.defaultValue, n); - if (stars != 1 && !a.type) - E(Error::FN_C_TYPE, a, n); - } - } + dec.push_back(format("(dec {})", a->toString(indent))); + if (indent == -1) + return wrapStmt(format("(fn '{} ({}){})", name, join(as, " "), + ret ? " #:ret " + ret->toString(indent) : "")); + return wrapStmt(format( + "(fn '{} ({}){}{}{}{})", name, join(as, " "), + ret ? " #:ret " + ret->toString(indent) : "", + dec.empty() ? "" : format(" (dec {})", join(dec, " ")), pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)")); } -ACCEPT_IMPL(FunctionStmt, ASTVisitor); std::string FunctionStmt::signature() const { std::vector s; - for (auto &a : args) + for (auto &a : items) s.push_back(a.type ? a.type->toString() : "-"); return format("{}", join(s, ":")); } -bool FunctionStmt::hasAttr(const std::string &attr) const { - return attributes.has(attr); -} -void FunctionStmt::parseDecorators() { - std::vector newDecorators; - for (auto &d : decorators) { - if (d->isId(Attr::Attribute)) { - if (decorators.size() != 1) - E(Error::FN_SINGLE_DECORATOR, decorators[1], Attr::Attribute); - attributes.isAttribute = true; - } else if (d->isId(Attr::LLVM)) { - attributes.set(Attr::LLVM); - } else if (d->isId(Attr::Python)) { - if (decorators.size() != 1) - E(Error::FN_SINGLE_DECORATOR, decorators[1], Attr::Python); - attributes.set(Attr::Python); - } else if (d->isId(Attr::Internal)) { - attributes.set(Attr::Internal); - } else if (d->isId(Attr::HiddenFromUser)) { - attributes.set(Attr::HiddenFromUser); - } else if (d->isId(Attr::Atomic)) { - attributes.set(Attr::Atomic); - } else if (d->isId(Attr::Property)) { - attributes.set(Attr::Property); - } else if (d->isId(Attr::StaticMethod)) { - attributes.set(Attr::StaticMethod); - } else if (d->isId(Attr::ForceRealize)) { - attributes.set(Attr::ForceRealize); - } else if (d->isId(Attr::C)) { - attributes.set(Attr::C); - } else { - newDecorators.emplace_back(d); - } - } - if (attributes.has(Attr::C)) { - for (auto &a : args) { - if (a.name.size() > 1 && a.name[0] == '*' && a.name[1] != '*') - attributes.set(Attr::CVarArg); - } - } - if (!args.empty() && !args[0].type && args[0].name == "self") { - attributes.set(Attr::HasSelf); - } - decorators = newDecorators; - validate(); -} size_t FunctionStmt::getStarArgs() const { size_t i = 0; - while (i < args.size()) { - if (startswith(args[i].name, "*") && !startswith(args[i].name, "**")) + while (i < items.size()) { + if (startswith(items[i].name, "*") && !startswith(items[i].name, "**")) break; i++; } @@ -464,17 +368,17 @@ size_t FunctionStmt::getStarArgs() const { } size_t FunctionStmt::getKwStarArgs() const { size_t i = 0; - while (i < args.size()) { - if (startswith(args[i].name, "**")) + while (i < items.size()) { + if (startswith(items[i].name, "**")) break; i++; } return i; } -std::string FunctionStmt::getDocstr() { +std::string FunctionStmt::getDocstr() const { if (auto s = suite->firstInBlock()) { - if (auto e = s->getExpr()) { - if (auto ss = e->expr->getString()) + if (auto e = cast(s)) { + if (auto ss = cast(e->getExpr())) return ss->getValue(); } } @@ -488,7 +392,7 @@ class IdSearchVisitor : public CallbackASTVisitor { public: IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {} - bool transform(const std::shared_ptr &expr) override { + bool transform(Expr *expr) override { if (result) return result; IdSearchVisitor v(what); @@ -496,7 +400,7 @@ class IdSearchVisitor : public CallbackASTVisitor { expr->accept(v); return result = v.result; } - bool transform(const std::shared_ptr &stmt) override { + bool transform(Stmt *stmt) override { if (result) return result; IdSearchVisitor v(what); @@ -505,19 +409,19 @@ class IdSearchVisitor : public CallbackASTVisitor { return result = v.result; } void visit(IdExpr *expr) override { - if (expr->value == what) + if (expr->getValue() == what) result = true; } }; /// Check if a function can be called with the given arguments. /// See @c reorderNamedArgs for details. -std::unordered_set FunctionStmt::getNonInferrableGenerics() { +std::unordered_set FunctionStmt::getNonInferrableGenerics() const { std::unordered_set nonInferrableGenerics; - for (auto &a : args) { + for (const auto &a : items) { if (a.status == Param::Generic && !a.defaultValue) { bool inferrable = false; - for (auto &b : args) + for (const auto &b : items) if (b.type && IdSearchVisitor(a.name).transform(b.type)) { inferrable = true; break; @@ -531,256 +435,168 @@ std::unordered_set FunctionStmt::getNonInferrableGenerics() { return nonInferrableGenerics; } -ClassStmt::ClassStmt(std::string name, std::vector args, StmtPtr suite, - std::vector decorators, std::vector baseClasses, - std::vector staticBaseClasses) - : Stmt(), name(std::move(name)), args(std::move(args)), suite(std::move(suite)), - decorators(std::move(decorators)), +ClassStmt::ClassStmt(std::string name, std::vector args, Stmt *suite, + std::vector decorators, std::vector baseClasses, + std::vector staticBaseClasses) + : AcceptorExtend(), Items(std::move(args)), name(std::move(name)), + suite(SuiteStmt::wrap(suite)), decorators(std::move(decorators)), staticBaseClasses(std::move(staticBaseClasses)) { for (auto &b : baseClasses) { - if (b->getIndex() && b->getIndex()->expr->isId("Static")) { - this->staticBaseClasses.push_back(b->getIndex()->index); + if (cast(b) && isId(cast(b)->getExpr(), "Static")) { + this->staticBaseClasses.push_back(cast(b)->getIndex()); } else { this->baseClasses.push_back(b); } } - parseDecorators(); -} -ClassStmt::ClassStmt(std::string name, std::vector args, StmtPtr suite, - Attr attr) - : Stmt(), name(std::move(name)), args(std::move(args)), suite(std::move(suite)), - attributes(std::move(attr)) { - validate(); -} -ClassStmt::ClassStmt(const ClassStmt &stmt) - : Stmt(stmt), name(stmt.name), args(ast::clone_nop(stmt.args)), - suite(ast::clone(stmt.suite)), attributes(stmt.attributes), - decorators(ast::clone(stmt.decorators)), - baseClasses(ast::clone(stmt.baseClasses)), - staticBaseClasses(ast::clone(stmt.staticBaseClasses)) {} +} +ClassStmt::ClassStmt(const ClassStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + name(stmt.name), suite(ast::clone(stmt.suite, clean)), + decorators(ast::clone(stmt.decorators, clean)), + baseClasses(ast::clone(stmt.baseClasses, clean)), + staticBaseClasses(ast::clone(stmt.staticBaseClasses, clean)) {} std::string ClassStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::vector bases; for (auto &b : baseClasses) - bases.push_back(b->toString()); + bases.push_back(b->toString(indent)); for (auto &b : staticBaseClasses) - bases.push_back(fmt::format("(static {})", b->toString())); + bases.push_back(fmt::format("(static {})", b->toString(indent))); std::string as; - for (int i = 0; i < args.size(); i++) - as += (i ? pad : "") + args[i].toString(); + for (int i = 0; i < items.size(); i++) + as += (i ? pad : "") + items[i].toString(indent); std::vector attr; for (auto &a : decorators) - attr.push_back(format("(dec {})", a->toString())); - return format("(class '{}{}{}{}{}{})", name, - bases.empty() ? "" : format(" (bases {})", join(bases, " ")), - attr.empty() ? "" : format(" (attr {})", join(attr, " ")), - as.empty() ? as : pad + as, pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) - : "(suite)"); -} -void ClassStmt::validate() const { - std::unordered_set seen; - if (attributes.has(Attr::Extend) && !args.empty()) - E(Error::CLASS_EXTENSION, args[0]); - if (attributes.has(Attr::Extend) && - !(baseClasses.empty() && staticBaseClasses.empty())) - E(Error::CLASS_EXTENSION, - baseClasses.empty() ? staticBaseClasses[0] : baseClasses[0]); - for (auto &a : args) { - if (!a.type && !a.defaultValue) - E(Error::CLASS_MISSING_TYPE, a, a.name); - if (in(seen, a.name)) - E(Error::CLASS_ARG_TWICE, a, a.name); - seen.insert(a.name); - } -} -ACCEPT_IMPL(ClassStmt, ASTVisitor); -bool ClassStmt::isRecord() const { return hasAttr(Attr::Tuple); } -bool ClassStmt::hasAttr(const std::string &attr) const { return attributes.has(attr); } -void ClassStmt::parseDecorators() { - // @tuple(init=, repr=, eq=, order=, hash=, pickle=, container=, python=, add=, - // internal=...) - // @dataclass(...) - // @extend - - std::map tupleMagics = { - {"new", true}, {"repr", false}, {"hash", false}, - {"eq", false}, {"ne", false}, {"lt", false}, - {"le", false}, {"gt", false}, {"ge", false}, - {"pickle", true}, {"unpickle", true}, {"to_py", false}, - {"from_py", false}, {"iter", false}, {"getitem", false}, - {"len", false}, {"to_gpu", false}, {"from_gpu", false}, - {"from_gpu_new", false}, {"tuplesize", true}}; - - for (auto &d : decorators) { - if (d->isId("deduce")) { - attributes.customAttr.insert("deduce"); - } else if (d->isId("__notuple__")) { - attributes.customAttr.insert("__notuple__"); - } else if (auto c = d->getCall()) { - if (c->expr->isId(Attr::Tuple)) { - attributes.set(Attr::Tuple); - for (auto &m : tupleMagics) - m.second = true; - } else if (!c->expr->isId("dataclass")) { - E(Error::CLASS_BAD_DECORATOR, c->expr); - } else if (attributes.has(Attr::Tuple)) { - E(Error::CLASS_CONFLICT_DECORATOR, c, "dataclass", Attr::Tuple); - } - for (auto &a : c->args) { - auto b = CAST(a.value, BoolExpr); - if (!b) - E(Error::CLASS_NONSTATIC_DECORATOR, a); - char val = char(b->value); - if (a.name == "init") { - tupleMagics["new"] = val; - } else if (a.name == "repr") { - tupleMagics["repr"] = val; - } else if (a.name == "eq") { - tupleMagics["eq"] = tupleMagics["ne"] = val; - } else if (a.name == "order") { - tupleMagics["lt"] = tupleMagics["le"] = tupleMagics["gt"] = - tupleMagics["ge"] = val; - } else if (a.name == "hash") { - tupleMagics["hash"] = val; - } else if (a.name == "pickle") { - tupleMagics["pickle"] = tupleMagics["unpickle"] = val; - } else if (a.name == "python") { - tupleMagics["to_py"] = tupleMagics["from_py"] = val; - } else if (a.name == "gpu") { - tupleMagics["to_gpu"] = tupleMagics["from_gpu"] = - tupleMagics["from_gpu_new"] = val; - } else if (a.name == "container") { - tupleMagics["iter"] = tupleMagics["getitem"] = val; - } else { - E(Error::CLASS_BAD_DECORATOR_ARG, a); - } - } - } else if (d->isId(Attr::Tuple)) { - if (attributes.has(Attr::Tuple)) - E(Error::CLASS_MULTIPLE_DECORATORS, d, Attr::Tuple); - attributes.set(Attr::Tuple); - for (auto &m : tupleMagics) { - m.second = true; - } - } else if (d->isId(Attr::Extend)) { - attributes.set(Attr::Extend); - if (decorators.size() != 1) - E(Error::CLASS_SINGLE_DECORATOR, decorators[decorators[0] == d], Attr::Extend); - } else if (d->isId(Attr::Internal)) { - attributes.set(Attr::Internal); - } else { - E(Error::CLASS_BAD_DECORATOR, d); - } - } - if (attributes.has("deduce")) - tupleMagics["new"] = false; - if (!attributes.has(Attr::Tuple)) { - tupleMagics["init"] = tupleMagics["new"]; - tupleMagics["new"] = tupleMagics["raw"] = true; - tupleMagics["len"] = false; - } - tupleMagics["dict"] = true; - // Internal classes do not get any auto-generated members. - attributes.magics.clear(); - if (!attributes.has(Attr::Internal)) { - for (auto &m : tupleMagics) - if (m.second) - attributes.magics.insert(m.first); - } - - validate(); -} + attr.push_back(format("(dec {})", a->toString(indent))); + if (indent == -1) + return wrapStmt(format("(class '{} ({}))", name, as)); + return wrapStmt(format( + "(class '{}{}{}{}{}{})", name, + bases.empty() ? "" : format(" (bases {})", join(bases, " ")), + attr.empty() ? "" : format(" (attr {})", join(attr, " ")), + as.empty() ? as : pad + as, pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "(suite)")); +} +bool ClassStmt::isRecord() const { return hasAttribute(Attr::Tuple); } bool ClassStmt::isClassVar(const Param &p) { if (!p.defaultValue) return false; if (!p.type) return true; - if (auto i = p.type->getIndex()) - return i->expr->isId("ClassVar"); + if (auto i = cast(p.type)) + return isId(i->getExpr(), "ClassVar"); return false; } -std::string ClassStmt::getDocstr() { +std::string ClassStmt::getDocstr() const { if (auto s = suite->firstInBlock()) { - if (auto e = s->getExpr()) { - if (auto ss = e->expr->getString()) + if (auto e = cast(s)) { + if (auto ss = cast(e->getExpr())) return ss->getValue(); } } return ""; } -YieldFromStmt::YieldFromStmt(ExprPtr expr) : Stmt(), expr(std::move(expr)) {} -YieldFromStmt::YieldFromStmt(const YieldFromStmt &stmt) - : Stmt(stmt), expr(ast::clone(stmt.expr)) {} -std::string YieldFromStmt::toString(int) const { - return format("(yield-from {})", expr->toString()); +YieldFromStmt::YieldFromStmt(Expr *expr) : AcceptorExtend(), expr(std::move(expr)) {} +YieldFromStmt::YieldFromStmt(const YieldFromStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), expr(ast::clone(stmt.expr, clean)) {} +std::string YieldFromStmt::toString(int indent) const { + return wrapStmt(format("(yield-from {})", expr->toString(indent))); } -ACCEPT_IMPL(YieldFromStmt, ASTVisitor); -WithStmt::WithStmt(std::vector items, std::vector vars, - StmtPtr suite) - : Stmt(), items(std::move(items)), vars(std::move(vars)), suite(std::move(suite)) { +WithStmt::WithStmt(std::vector items, std::vector vars, + Stmt *suite) + : AcceptorExtend(), Items(std::move(items)), vars(std::move(vars)), + suite(SuiteStmt::wrap(suite)) { seqassert(this->items.size() == this->vars.size(), "vector size mismatch"); } -WithStmt::WithStmt(std::vector> itemVarPairs, StmtPtr suite) - : Stmt(), suite(std::move(suite)) { - for (auto &i : itemVarPairs) { - items.push_back(std::move(i.first)); - if (i.second) { - if (!i.second->getId()) - throw; - vars.push_back(i.second->getId()->value); +WithStmt::WithStmt(std::vector> itemVarPairs, Stmt *suite) + : AcceptorExtend(), Items({}), suite(SuiteStmt::wrap(suite)) { + for (auto [i, j] : itemVarPairs) { + items.push_back(i); + if (auto je = cast(j)) { + vars.push_back(je->getValue()); } else { vars.emplace_back(); } } } -WithStmt::WithStmt(const WithStmt &stmt) - : Stmt(stmt), items(ast::clone(stmt.items)), vars(stmt.vars), - suite(ast::clone(stmt.suite)) {} +WithStmt::WithStmt(const WithStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), Items(ast::clone(stmt.items, clean)), + vars(stmt.vars), suite(ast::clone(stmt.suite, clean)) {} std::string WithStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::vector as; as.reserve(items.size()); for (int i = 0; i < items.size(); i++) { as.push_back(!vars[i].empty() - ? format("({} #:var '{})", items[i]->toString(), vars[i]) - : items[i]->toString()); + ? format("({} #:var '{})", items[i]->toString(indent), vars[i]) + : items[i]->toString(indent)); } - return format("(with ({}){}{})", join(as, " "), pad, - suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)); -} -ACCEPT_IMPL(WithStmt, ASTVisitor); - -CustomStmt::CustomStmt(std::string keyword, ExprPtr expr, StmtPtr suite) - : Stmt(), keyword(std::move(keyword)), expr(std::move(expr)), - suite(std::move(suite)) {} -CustomStmt::CustomStmt(const CustomStmt &stmt) - : Stmt(stmt), keyword(stmt.keyword), expr(ast::clone(stmt.expr)), - suite(ast::clone(stmt.suite)) {} + if (indent == -1) + return wrapStmt(format("(with ({}))", join(as, " "))); + return wrapStmt(format("(with ({}){}{})", join(as, " "), pad, + suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1))); +} + +CustomStmt::CustomStmt(std::string keyword, Expr *expr, Stmt *suite) + : AcceptorExtend(), keyword(std::move(keyword)), expr(expr), + suite(SuiteStmt::wrap(suite)) {} +CustomStmt::CustomStmt(const CustomStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), keyword(stmt.keyword), + expr(ast::clone(stmt.expr, clean)), suite(ast::clone(stmt.suite, clean)) {} std::string CustomStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; - return format("(custom-{} {}{}{})", keyword, - expr ? format(" #:expr {}", expr->toString()) : "", pad, - suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : ""); + return wrapStmt( + format("(custom-{} {}{}{})", keyword, + expr ? format(" #:expr {}", expr->toString(indent)) : "", pad, + suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1) : "")); } -ACCEPT_IMPL(CustomStmt, ASTVisitor); -AssignMemberStmt::AssignMemberStmt(ExprPtr lhs, std::string member, ExprPtr rhs) - : Stmt(), lhs(std::move(lhs)), member(std::move(member)), rhs(std::move(rhs)) {} -AssignMemberStmt::AssignMemberStmt(const AssignMemberStmt &stmt) - : Stmt(stmt), lhs(ast::clone(stmt.lhs)), member(stmt.member), - rhs(ast::clone(stmt.rhs)) {} -std::string AssignMemberStmt::toString(int) const { - return format("(assign-member {} {} {})", lhs->toString(), member, rhs->toString()); +AssignMemberStmt::AssignMemberStmt(Expr *lhs, std::string member, Expr *rhs) + : AcceptorExtend(), lhs(lhs), member(std::move(member)), rhs(rhs) {} +AssignMemberStmt::AssignMemberStmt(const AssignMemberStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), lhs(ast::clone(stmt.lhs, clean)), + member(stmt.member), rhs(ast::clone(stmt.rhs, clean)) {} +std::string AssignMemberStmt::toString(int indent) const { + return wrapStmt(format("(assign-member {} {} {})", lhs->toString(indent), member, + rhs->toString(indent))); } -ACCEPT_IMPL(AssignMemberStmt, ASTVisitor); -CommentStmt::CommentStmt(std::string comment) : Stmt(), comment(std::move(comment)) {} -std::string CommentStmt::toString(int) const { - return format("(comment \"{}\")", comment); +CommentStmt::CommentStmt(std::string comment) + : AcceptorExtend(), comment(std::move(comment)) {} +CommentStmt::CommentStmt(const CommentStmt &stmt, bool clean) + : AcceptorExtend(stmt, clean), comment(stmt.comment) {} +std::string CommentStmt::toString(int indent) const { + return wrapStmt(format("(comment \"{}\")", comment)); } + +const char Stmt::NodeId = 0; +ACCEPT_IMPL(SuiteStmt, ASTVisitor); +ACCEPT_IMPL(BreakStmt, ASTVisitor); +ACCEPT_IMPL(ContinueStmt, ASTVisitor); +ACCEPT_IMPL(ExprStmt, ASTVisitor); +ACCEPT_IMPL(AssignStmt, ASTVisitor); +ACCEPT_IMPL(DelStmt, ASTVisitor); +ACCEPT_IMPL(PrintStmt, ASTVisitor); +ACCEPT_IMPL(ReturnStmt, ASTVisitor); +ACCEPT_IMPL(YieldStmt, ASTVisitor); +ACCEPT_IMPL(AssertStmt, ASTVisitor); +ACCEPT_IMPL(WhileStmt, ASTVisitor); +ACCEPT_IMPL(ForStmt, ASTVisitor); +ACCEPT_IMPL(IfStmt, ASTVisitor); +ACCEPT_IMPL(MatchStmt, ASTVisitor); +ACCEPT_IMPL(ImportStmt, ASTVisitor); +ACCEPT_IMPL(ExceptStmt, ASTVisitor); +ACCEPT_IMPL(TryStmt, ASTVisitor); +ACCEPT_IMPL(ThrowStmt, ASTVisitor); +ACCEPT_IMPL(GlobalStmt, ASTVisitor); +ACCEPT_IMPL(FunctionStmt, ASTVisitor); +ACCEPT_IMPL(ClassStmt, ASTVisitor); +ACCEPT_IMPL(YieldFromStmt, ASTVisitor); +ACCEPT_IMPL(WithStmt, ASTVisitor); +ACCEPT_IMPL(CustomStmt, ASTVisitor); +ACCEPT_IMPL(AssignMemberStmt, ASTVisitor); ACCEPT_IMPL(CommentStmt, ASTVisitor); } // namespace codon::ast diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index f5dfb611..aacc7150 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -10,270 +10,282 @@ #include "codon/parser/ast/expr.h" #include "codon/parser/ast/types.h" #include "codon/parser/common.h" +#include "codon/util/serialize.h" namespace codon::ast { -#define ACCEPT(X) \ - using Stmt::toString; \ - StmtPtr clone() const override; \ - void accept(X &visitor) override +#define ACCEPT(CLASS, VISITOR, ...) \ + static const char NodeId; \ + using AcceptorExtend::clone; \ + using AcceptorExtend::accept; \ + ASTNode *clone(bool c) const override; \ + void accept(VISITOR &visitor) override; \ + std::string toString(int) const override; \ + friend class TypecheckVisitor; \ + template friend struct CallbackASTVisitor; \ + friend struct ReplacingCallbackASTVisitor; \ + inline decltype(auto) match_members() const { return std::tie(__VA_ARGS__); } \ + SERIALIZE(CLASS, BASE(Stmt), ##__VA_ARGS__) // Forward declarations struct ASTVisitor; -struct AssignStmt; -struct ClassStmt; -struct ExprStmt; -struct SuiteStmt; -struct FunctionStmt; /** * A Seq AST statement. * Each AST statement is intended to be instantiated as a shared_ptr. */ -struct Stmt : public codon::SrcObject { +struct Stmt : public AcceptorExtend { using base_type = Stmt; - /// Flag that indicates if all types in a statement are inferred (i.e. if a - /// type-checking procedure was successful). - bool done; - /// Statement age. - int age; - -public: Stmt(); - Stmt(const Stmt &s) = default; + Stmt(const Stmt &s); + Stmt(const Stmt &, bool); explicit Stmt(const codon::SrcInfo &s); - /// Convert a node to an S-expression. - std::string toString() const; - virtual std::string toString(int indent) const = 0; - /// Validate a node. Throw ParseASTException if a node is not valid. - void validate() const; - /// Deep copy a node. - virtual std::shared_ptr clone() const = 0; - /// Accept an AST visitor. - virtual void accept(ASTVisitor &) = 0; - - /// Allow pretty-printing to C++ streams. - friend std::ostream &operator<<(std::ostream &out, const Stmt &stmt) { - return out << stmt.toString(); - } - - /// Convenience virtual functions to avoid unnecessary dynamic_cast calls. - virtual AssignStmt *getAssign() { return nullptr; } - virtual ClassStmt *getClass() { return nullptr; } - virtual ExprStmt *getExpr() { return nullptr; } - virtual SuiteStmt *getSuite() { return nullptr; } - virtual FunctionStmt *getFunction() { return nullptr; } - + bool isDone() const { return done; } + void setDone() { done = true; } /// @return the first statement in a suite; if a statement is not a suite, returns the /// statement itself virtual Stmt *firstInBlock() { return this; } - bool isDone() const { return done; } - void setDone() { done = true; } + static const char NodeId; + SERIALIZE(Stmt, BASE(ASTNode), done); + + virtual std::string wrapStmt(const std::string &) const; + +private: + /// Flag that indicates if all types in a statement are inferred (i.e. if a + /// type-checking procedure was successful). + bool done; }; -using StmtPtr = std::shared_ptr; /// Suite (block of statements) statement (stmt...). /// @li a = 5; foo(1) -struct SuiteStmt : public Stmt { - using Stmt::Stmt; - - std::vector stmts; - - /// These constructors flattens the provided statement vector (see flatten() below). - explicit SuiteStmt(std::vector stmts = {}); +struct SuiteStmt : public AcceptorExtend, Items { + explicit SuiteStmt(std::vector stmts = {}); /// Convenience constructor template - SuiteStmt(StmtPtr stmt, Ts... stmts) : stmts({stmt, stmts...}) {} - SuiteStmt(const SuiteStmt &stmt); - - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + SuiteStmt(Stmt *stmt, Ts... stmts) : Items({stmt, stmts...}) {} + SuiteStmt(const SuiteStmt &, bool); - SuiteStmt *getSuite() override { return this; } Stmt *firstInBlock() override { - return stmts.empty() ? nullptr : stmts[0]->firstInBlock(); + return items.empty() ? nullptr : items[0]->firstInBlock(); } - StmtPtr *lastInBlock(); + void flatten(); + void addStmt(Stmt *s); - /// Flatten all nested SuiteStmt objects that do not own a block in the statement - /// vector. This is shallow flattening. - static void flatten(const StmtPtr &s, std::vector &stmts); + static SuiteStmt *wrap(Stmt *); + + ACCEPT(SuiteStmt, ASTVisitor, items); }; /// Break statement. /// @li break -struct BreakStmt : public Stmt { +struct BreakStmt : public AcceptorExtend { BreakStmt() = default; - BreakStmt(const BreakStmt &stmt) = default; + BreakStmt(const BreakStmt &, bool); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(BreakStmt, ASTVisitor); }; /// Continue statement. /// @li continue -struct ContinueStmt : public Stmt { +struct ContinueStmt : public AcceptorExtend { ContinueStmt() = default; - ContinueStmt(const ContinueStmt &stmt) = default; + ContinueStmt(const ContinueStmt &, bool); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(ContinueStmt, ASTVisitor); }; /// Expression statement (expr). /// @li 3 + foo() -struct ExprStmt : public Stmt { - ExprPtr expr; +struct ExprStmt : public AcceptorExtend { + explicit ExprStmt(Expr *expr = nullptr); + ExprStmt(const ExprStmt &, bool); - explicit ExprStmt(ExprPtr expr); - ExprStmt(const ExprStmt &stmt); + Expr *getExpr() const { return expr; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(ExprStmt, ASTVisitor, expr); - ExprStmt *getExpr() override { return this; } +private: + Expr *expr; }; /// Assignment statement (lhs: type = rhs). /// @li a = 5 /// @li a: Optional[int] = 5 /// @li a, b, c = 5, *z -struct AssignStmt : public Stmt { - ExprPtr lhs, rhs, type; +struct AssignStmt : public AcceptorExtend { + enum UpdateMode { Assign, Update, UpdateAtomic }; - AssignStmt(ExprPtr lhs, ExprPtr rhs, ExprPtr type = nullptr); - AssignStmt(const AssignStmt &stmt); + AssignStmt() + : lhs(nullptr), rhs(nullptr), type(nullptr), update(UpdateMode::Assign) {} + AssignStmt(Expr *lhs, Expr *rhs, Expr *type = nullptr, + UpdateMode update = UpdateMode::Assign); + AssignStmt(const AssignStmt &, bool); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + Expr *getLhs() const { return lhs; } + Expr *getRhs() const { return rhs; } + Expr *getTypeExpr() const { return type; } - AssignStmt *getAssign() override { return this; } - - bool isUpdate() const { return update != Assign; } + bool isAssignment() const { return update == Assign; } + bool isUpdate() const { return update == Update; } bool isAtomicUpdate() const { return update == UpdateAtomic; } void setUpdate() { update = Update; } void setAtomicUpdate() { update = UpdateAtomic; } + ACCEPT(AssignStmt, ASTVisitor, lhs, rhs, type, update); + private: - enum { Assign, Update, UpdateAtomic } update; + Expr *lhs, *rhs, *type; + UpdateMode update; }; /// Deletion statement (del expr). /// @li del a /// @li del a[5] -struct DelStmt : public Stmt { - ExprPtr expr; +struct DelStmt : public AcceptorExtend { + explicit DelStmt(Expr *expr = nullptr); + DelStmt(const DelStmt &, bool); - explicit DelStmt(ExprPtr expr); - DelStmt(const DelStmt &stmt); + Expr *getExpr() const { return expr; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(DelStmt, ASTVisitor, expr); + +private: + Expr *expr; }; /// Print statement (print expr). /// @li print a, b -struct PrintStmt : public Stmt { - std::vector items; - /// True if there is a dangling comma after print: print a, - bool isInline; +struct PrintStmt : public AcceptorExtend, Items { + explicit PrintStmt(std::vector items = {}, bool noNewline = false); + PrintStmt(const PrintStmt &, bool); + + bool hasNewline() const { return !noNewline; } - explicit PrintStmt(std::vector items, bool isInline); - PrintStmt(const PrintStmt &stmt); + ACCEPT(PrintStmt, ASTVisitor, items, noNewline); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + /// True if there is a dangling comma after print: print a, + bool noNewline; }; /// Return statement (return expr). /// @li return /// @li return a -struct ReturnStmt : public Stmt { - /// nullptr if this is an empty return/yield statements. - ExprPtr expr; +struct ReturnStmt : public AcceptorExtend { + explicit ReturnStmt(Expr *expr = nullptr); + ReturnStmt(const ReturnStmt &, bool); + + Expr *getExpr() const { return expr; } - explicit ReturnStmt(ExprPtr expr = nullptr); - ReturnStmt(const ReturnStmt &stmt); + ACCEPT(ReturnStmt, ASTVisitor, expr); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + /// nullptr if this is an empty return/yield statements. + Expr *expr; }; /// Yield statement (yield expr). /// @li yield /// @li yield a -struct YieldStmt : public Stmt { - /// nullptr if this is an empty return/yield statements. - ExprPtr expr; +struct YieldStmt : public AcceptorExtend { + explicit YieldStmt(Expr *expr = nullptr); + YieldStmt(const YieldStmt &, bool); + + Expr *getExpr() const { return expr; } - explicit YieldStmt(ExprPtr expr = nullptr); - YieldStmt(const YieldStmt &stmt); + ACCEPT(YieldStmt, ASTVisitor, expr); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + /// nullptr if this is an empty return/yield statements. + Expr *expr; }; /// Assert statement (assert expr). /// @li assert a /// @li assert a, "Message" -struct AssertStmt : public Stmt { - ExprPtr expr; - /// nullptr if there is no message. - ExprPtr message; +struct AssertStmt : public AcceptorExtend { + explicit AssertStmt(Expr *expr = nullptr, Expr *message = nullptr); + AssertStmt(const AssertStmt &, bool); - explicit AssertStmt(ExprPtr expr, ExprPtr message = nullptr); - AssertStmt(const AssertStmt &stmt); + Expr *getExpr() const { return expr; } + Expr *getMessage() const { return message; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(AssertStmt, ASTVisitor, expr, message); + +private: + Expr *expr; + /// nullptr if there is no message. + Expr *message; }; /// While loop statement (while cond: suite; else: elseSuite). /// @li while True: print /// @li while True: break /// else: print -struct WhileStmt : public Stmt { - ExprPtr cond; - StmtPtr suite; +struct WhileStmt : public AcceptorExtend { + WhileStmt() : cond(nullptr), suite(nullptr), elseSuite(nullptr), gotoVar() {} + WhileStmt(Expr *cond, Stmt *suite, Stmt *elseSuite = nullptr); + WhileStmt(const WhileStmt &, bool); + + Expr *getCond() const { return cond; } + SuiteStmt *getSuite() const { return suite; } + SuiteStmt *getElse() const { return elseSuite; } + + ACCEPT(WhileStmt, ASTVisitor, cond, suite, elseSuite, gotoVar); + +private: + Expr *cond; + SuiteStmt *suite; /// nullptr if there is no else suite. - StmtPtr elseSuite; + SuiteStmt *elseSuite; + /// Set if a while loop is used to emulate goto statement /// (as `while gotoVar: ...`). - std::string gotoVar = ""; - - WhileStmt(ExprPtr cond, StmtPtr suite, StmtPtr elseSuite = nullptr); - WhileStmt(const WhileStmt &stmt); - - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + std::string gotoVar; }; /// For loop statement (for var in iter: suite; else elseSuite). /// @li for a, b in c: print /// @li for i in j: break /// else: print -struct ForStmt : public Stmt { - ExprPtr var; - ExprPtr iter; - StmtPtr suite; - StmtPtr elseSuite; - ExprPtr decorator; - std::vector ompArgs; +struct ForStmt : public AcceptorExtend { + ForStmt() + : var(nullptr), iter(nullptr), suite(nullptr), elseSuite(nullptr), + decorator(nullptr), ompArgs(), wrapped(false), flat(false) {} + ForStmt(Expr *var, Expr *iter, Stmt *suite, Stmt *elseSuite = nullptr, + Expr *decorator = nullptr, std::vector ompArgs = {}); + ForStmt(const ForStmt &, bool); + + Expr *getVar() const { return var; } + Expr *getIter() const { return iter; } + SuiteStmt *getSuite() const { return suite; } + SuiteStmt *getElse() const { return elseSuite; } + Expr *getDecorator() const { return decorator; } + void setDecorator(Expr *e) { decorator = e; } + bool isWrapped() const { return wrapped; } + bool isFlat() const { return flat; } + + ACCEPT(ForStmt, ASTVisitor, var, iter, suite, elseSuite, decorator, ompArgs, wrapped, + flat); + +private: + Expr *var; + Expr *iter; + SuiteStmt *suite; + SuiteStmt *elseSuite; + Expr *decorator; + std::vector ompArgs; /// Indicates if iter was wrapped with __iter__() call. bool wrapped; /// True if there are no break/continue within the loop bool flat; - ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite = nullptr, - ExprPtr decorator = nullptr, std::vector ompArgs = {}); - ForStmt(const ForStmt &stmt); - - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + friend class GeneratorExpr; + friend class ScopingVisitor; }; /// If block statement (if cond: suite; (elif cond: suite)...). @@ -283,38 +295,59 @@ struct ForStmt : public Stmt { /// @li if a: foo() /// elif b: bar() /// else: baz() -struct IfStmt : public Stmt { - ExprPtr cond; +struct IfStmt : public AcceptorExtend { + IfStmt(Expr *cond = nullptr, Stmt *ifSuite = nullptr, Stmt *elseSuite = nullptr); + IfStmt(const IfStmt &, bool); + + Expr *getCond() const { return cond; } + SuiteStmt *getIf() const { return ifSuite; } + SuiteStmt *getElse() const { return elseSuite; } + + ACCEPT(IfStmt, ASTVisitor, cond, ifSuite, elseSuite); + +private: + Expr *cond; /// elseSuite can be nullptr (if no else is found). - StmtPtr ifSuite, elseSuite; + SuiteStmt *ifSuite, *elseSuite; + + friend class GeneratorExpr; +}; + +struct MatchCase { + MatchCase(Expr *pattern = nullptr, Expr *guard = nullptr, Stmt *suite = nullptr); - IfStmt(ExprPtr cond, StmtPtr ifSuite, StmtPtr elseSuite = nullptr); - IfStmt(const IfStmt &stmt); + Expr *getPattern() const { return pattern; } + Expr *getGuard() const { return guard; } + SuiteStmt *getSuite() const { return suite; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + MatchCase clone(bool) const; + SERIALIZE(MatchCase, pattern, guard, suite); + +private: + Expr *pattern; + Expr *guard; + SuiteStmt *suite; + + friend class MatchStmt; + friend class TypecheckVisitor; + template friend struct CallbackASTVisitor; + friend struct ReplacingCallbackASTVisitor; }; /// Match statement (match what: (case pattern: case)...). /// @li match a: /// case 1: print /// case _: pass -struct MatchStmt : public Stmt { - struct MatchCase { - ExprPtr pattern; - ExprPtr guard; - StmtPtr suite; - - MatchCase clone() const; - }; - ExprPtr what; - std::vector cases; - - MatchStmt(ExprPtr what, std::vector cases); - MatchStmt(const MatchStmt &stmt); - - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +struct MatchStmt : public AcceptorExtend, Items { + MatchStmt(Expr *what = nullptr, std::vector cases = {}); + MatchStmt(const MatchStmt &, bool); + + Expr *getExpr() const { return expr; } + + ACCEPT(MatchStmt, ASTVisitor, items, expr); + +private: + Expr *expr; }; /// Import statement. @@ -329,163 +362,152 @@ struct MatchStmt : public Stmt { /// @li from c import foo(int) -> int as bar /// @li from python.numpy import array /// @li from python import numpy.array(int) -> int as na -struct ImportStmt : public Stmt { - ExprPtr from, what; +struct ImportStmt : public AcceptorExtend { + ImportStmt(Expr *from = nullptr, Expr *what = nullptr, std::vector args = {}, + Expr *ret = nullptr, std::string as = "", size_t dots = 0, + bool isFunction = true); + ImportStmt(const ImportStmt &, bool); + + Expr *getFrom() const { return from; } + Expr *getWhat() const { return what; } + std::string getAs() const { return as; } + size_t getDots() const { return dots; } + Expr *getReturnType() const { return ret; } + const std::vector &getArgs() const { return args; } + bool isCVar() const { return !isFunction; } + + ACCEPT(ImportStmt, ASTVisitor, from, what, as, dots, args, ret, isFunction); + +private: + Expr *from, *what; std::string as; /// Number of dots in a relative import (e.g. dots is 3 for "from ...foo"). size_t dots; /// Function argument types for C imports. std::vector args; /// Function return type for C imports. - ExprPtr ret; + Expr *ret; /// Set if this is a function C import (not variable import) bool isFunction; +}; - ImportStmt(ExprPtr from, ExprPtr what, std::vector args = {}, - ExprPtr ret = nullptr, std::string as = "", size_t dots = 0, - bool isFunction = true); - ImportStmt(const ImportStmt &stmt); +struct ExceptStmt : public AcceptorExtend { + ExceptStmt(const std::string &var = "", Expr *exc = nullptr, Stmt *suite = nullptr); + ExceptStmt(const ExceptStmt &, bool); - std::string toString(int indent) const override; - void validate() const; - ACCEPT(ASTVisitor); + std::string getVar() const { return var; } + Expr *getException() const { return exc; } + SuiteStmt *getSuite() const { return suite; } + + ACCEPT(ExceptStmt, ASTVisitor, var, exc, suite); + +private: + /// empty string if an except is unnamed. + std::string var; + /// nullptr if there is no explicit exception type. + Expr *exc; + SuiteStmt *suite; + + friend class ScopingVisitor; }; -/// Try-catch statement (try: suite; (catch var (as exc): suite)...; finally: finally). +/// Try-except statement (try: suite; (except var (as exc): suite)...; finally: +/// finally). /// @li: try: a -/// catch e: pass -/// catch e as Exc: pass -/// catch: pass +/// except e: pass +/// except e as Exc: pass +/// except: pass /// finally: print -struct TryStmt : public Stmt { - struct Catch { - /// empty string if a catch is unnamed. - std::string var; - /// nullptr if there is no explicit exception type. - ExprPtr exc; - StmtPtr suite; - - Catch clone() const; - }; - - StmtPtr suite; - std::vector catches; - /// nullptr if there is no finally block. - StmtPtr finally; +struct TryStmt : public AcceptorExtend, Items { + TryStmt(Stmt *suite = nullptr, std::vector catches = {}, + Stmt *elseSuite = nullptr, Stmt *finally = nullptr); + TryStmt(const TryStmt &, bool); - TryStmt(StmtPtr suite, std::vector catches, StmtPtr finally = nullptr); - TryStmt(const TryStmt &stmt); + SuiteStmt *getSuite() const { return suite; } + SuiteStmt *getElse() const { return elseSuite; } + SuiteStmt *getFinally() const { return finally; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(TryStmt, ASTVisitor, items, suite, elseSuite, finally); + +private: + SuiteStmt *suite; + /// nullptr if there is no else block. + SuiteStmt *elseSuite; + /// nullptr if there is no finally block. + SuiteStmt *finally; }; /// Throw statement (raise expr). /// @li: raise a -struct ThrowStmt : public Stmt { - ExprPtr expr; +struct ThrowStmt : public AcceptorExtend { + explicit ThrowStmt(Expr *expr = nullptr, Expr *from = nullptr, + bool transformed = false); + ThrowStmt(const ThrowStmt &, bool); + + Expr *getExpr() const { return expr; } + Expr *getFrom() const { return from; } + bool isTransformed() const { return transformed; } + + ACCEPT(ThrowStmt, ASTVisitor, expr, from, transformed); + +private: + Expr *expr; + Expr *from; // True if a statement was transformed during type-checking stage // (to avoid setting up ExcHeader multiple times). bool transformed; - - explicit ThrowStmt(ExprPtr expr, bool transformed = false); - ThrowStmt(const ThrowStmt &stmt); - - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); }; /// Global variable statement (global var). /// @li: global a -struct GlobalStmt : public Stmt { - std::string var; - bool nonLocal; +struct GlobalStmt : public AcceptorExtend { + explicit GlobalStmt(std::string var = "", bool nonLocal = false); + GlobalStmt(const GlobalStmt &, bool); - explicit GlobalStmt(std::string var, bool nonLocal = false); - GlobalStmt(const GlobalStmt &stmt) = default; + std::string getVar() const { return var; } + bool isNonLocal() const { return nonLocal; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); -}; + ACCEPT(GlobalStmt, ASTVisitor, var, nonLocal); -struct Attr { - // Toplevel attributes - const static std::string LLVM; - const static std::string Python; - const static std::string Atomic; - const static std::string Property; - const static std::string StaticMethod; - const static std::string Attribute; - const static std::string C; - // Internal attributes - const static std::string Internal; - const static std::string HiddenFromUser; - const static std::string ForceRealize; - const static std::string RealizeWithoutSelf; // not internal - // Compiler-generated attributes - const static std::string CVarArg; - const static std::string Method; - const static std::string Capture; - const static std::string HasSelf; - const static std::string IsGenerator; - // Class attributes - const static std::string Extend; - const static std::string Tuple; - // Standard library attributes - const static std::string Test; - const static std::string Overload; - const static std::string Export; - // Function module - std::string module; - // Parent class (set for methods only) - std::string parentClass; - // True if a function is decorated with __attribute__ - bool isAttribute; - - std::set magics; - - // Set of attributes - std::set customAttr; - - explicit Attr(const std::vector &attrs = std::vector()); - void set(const std::string &attr); - void unset(const std::string &attr); - bool has(const std::string &attr) const; +private: + std::string var; + bool nonLocal; }; /// Function statement (@(attributes...) def name[funcs...](args...) -> ret: suite). /// @li: @decorator /// def foo[T=int, U: int](a, b: int = 0) -> list[T]: pass -struct FunctionStmt : public Stmt { - std::string name; - /// nullptr if return type is not specified. - ExprPtr ret; - std::vector args; - StmtPtr suite; - Attr attributes; - std::vector decorators; - - FunctionStmt(std::string name, ExprPtr ret, std::vector args, StmtPtr suite, - Attr attributes = Attr(), std::vector decorators = {}); - FunctionStmt(const FunctionStmt &stmt); - - std::string toString(int indent) const override; - void validate() const; - ACCEPT(ASTVisitor); +struct FunctionStmt : public AcceptorExtend, Items { + FunctionStmt(std::string name = "", Expr *ret = nullptr, std::vector args = {}, + Stmt *suite = nullptr, std::vector decorators = {}); + FunctionStmt(const FunctionStmt &, bool); + + std::string getName() const { return name; } + Expr *getReturn() const { return ret; } + SuiteStmt *getSuite() const { return suite; } + void setSuite(SuiteStmt *s) { suite = s; } + const std::vector &getDecorators() const { return decorators; } + void setDecorators(const std::vector &d) { decorators = d; } /// @return a function signature that consists of generics and arguments in a /// S-expression form. /// @li (T U (int 0)) std::string signature() const; - bool hasAttr(const std::string &attr) const; - void parseDecorators(); - size_t getStarArgs() const; size_t getKwStarArgs() const; + std::string getDocstr() const; + std::unordered_set getNonInferrableGenerics() const; + + ACCEPT(FunctionStmt, ASTVisitor, name, items, ret, suite, decorators); - FunctionStmt *getFunction() override { return this; } - std::string getDocstr(); - std::unordered_set getNonInferrableGenerics(); +private: + std::string name; + /// nullptr if return type is not specified. + Expr *ret; + SuiteStmt *suite; + std::vector decorators; + + friend class Cache; }; /// Class statement (@(attributes...) class name[generics...]: args... ; suite). @@ -493,122 +515,144 @@ struct FunctionStmt : public Stmt { /// class F[T]: /// m: T /// def __new__() -> F[T]: ... -struct ClassStmt : public Stmt { - std::string name; - std::vector args; - StmtPtr suite; - Attr attributes; - std::vector decorators; - std::vector baseClasses; - std::vector staticBaseClasses; - - ClassStmt(std::string name, std::vector args, StmtPtr suite, - std::vector decorators = {}, std::vector baseClasses = {}, - std::vector staticBaseClasses = {}); - ClassStmt(std::string name, std::vector args, StmtPtr suite, Attr attr); - ClassStmt(const ClassStmt &stmt); - - std::string toString(int indent) const override; - void validate() const; - ACCEPT(ASTVisitor); +struct ClassStmt : public AcceptorExtend, Items { + ClassStmt(std::string name = "", std::vector args = {}, Stmt *suite = nullptr, + std::vector decorators = {}, std::vector baseClasses = {}, + std::vector staticBaseClasses = {}); + ClassStmt(const ClassStmt &, bool); + + std::string getName() const { return name; } + SuiteStmt *getSuite() const { return suite; } + const std::vector &getDecorators() const { return decorators; } + void setDecorators(const std::vector &d) { decorators = d; } + const std::vector &getBaseClasses() const { return baseClasses; } + const std::vector &getStaticBaseClasses() const { return staticBaseClasses; } /// @return true if a class is a tuple-like record (e.g. has a "@tuple" attribute) bool isRecord() const; - bool hasAttr(const std::string &attr) const; - - ClassStmt *getClass() override { return this; } + std::string getDocstr() const; - void parseDecorators(); static bool isClassVar(const Param &p); - std::string getDocstr(); + + ACCEPT(ClassStmt, ASTVisitor, name, suite, items, decorators, baseClasses, + staticBaseClasses); + +private: + std::string name; + SuiteStmt *suite; + std::vector decorators; + std::vector baseClasses; + std::vector staticBaseClasses; }; /// Yield-from statement (yield from expr). /// @li: yield from it -struct YieldFromStmt : public Stmt { - ExprPtr expr; +struct YieldFromStmt : public AcceptorExtend { + explicit YieldFromStmt(Expr *expr = nullptr); + YieldFromStmt(const YieldFromStmt &, bool); - explicit YieldFromStmt(ExprPtr expr); - YieldFromStmt(const YieldFromStmt &stmt); + Expr *getExpr() const { return expr; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(YieldFromStmt, ASTVisitor, expr); + +private: + Expr *expr; }; /// With statement (with (item as var)...: suite). /// @li: with foo(), bar() as b: pass -struct WithStmt : public Stmt { - std::vector items; - /// empty string if a corresponding item is unnamed - std::vector vars; - StmtPtr suite; +struct WithStmt : public AcceptorExtend, Items { + WithStmt(std::vector items = {}, std::vector vars = {}, + Stmt *suite = nullptr); + WithStmt(std::vector> items, Stmt *suite); + WithStmt(const WithStmt &, bool); - WithStmt(std::vector items, std::vector vars, StmtPtr suite); - WithStmt(std::vector> items, StmtPtr suite); - WithStmt(const WithStmt &stmt); + const std::vector &getVars() const { return vars; } + SuiteStmt *getSuite() const { return suite; } - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); + ACCEPT(WithStmt, ASTVisitor, items, vars, suite); + +private: + /// empty string if a corresponding item is unnamed + std::vector vars; + SuiteStmt *suite; }; /// Custom block statement (foo: ...). /// @li: pt_tree: pass -struct CustomStmt : public Stmt { - std::string keyword; - ExprPtr expr; - StmtPtr suite; +struct CustomStmt : public AcceptorExtend { + CustomStmt(std::string keyword = "", Expr *expr = nullptr, Stmt *suite = nullptr); + CustomStmt(const CustomStmt &, bool); + + std::string getKeyword() const { return keyword; } + Expr *getExpr() const { return expr; } + SuiteStmt *getSuite() const { return suite; } - CustomStmt(std::string keyword, ExprPtr expr, StmtPtr suite); - CustomStmt(const CustomStmt &stmt); + ACCEPT(CustomStmt, ASTVisitor, keyword, expr, suite); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + std::string keyword; + Expr *expr; + SuiteStmt *suite; }; -/// The following nodes are created after the simplify stage. +/// The following nodes are created during typechecking. /// Member assignment statement (lhs.member = rhs). /// @li: a.x = b -struct AssignMemberStmt : public Stmt { - ExprPtr lhs; - std::string member; - ExprPtr rhs; +struct AssignMemberStmt : public AcceptorExtend { + AssignMemberStmt(Expr *lhs = nullptr, std::string member = "", Expr *rhs = nullptr); + AssignMemberStmt(const AssignMemberStmt &, bool); + + Expr *getLhs() const { return lhs; } + std::string getMember() const { return member; } + Expr *getRhs() const { return rhs; } - AssignMemberStmt(ExprPtr lhs, std::string member, ExprPtr rhs); - AssignMemberStmt(const AssignMemberStmt &stmt); + ACCEPT(AssignMemberStmt, ASTVisitor, lhs, member, rhs); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + Expr *lhs; + std::string member; + Expr *rhs; }; /// Comment statement (# comment). /// Currently used only for pretty-printing. -struct CommentStmt : public Stmt { - std::string comment; +struct CommentStmt : public AcceptorExtend { + explicit CommentStmt(std::string comment = ""); + CommentStmt(const CommentStmt &, bool); + + std::string getComment() const { return comment; } - explicit CommentStmt(std::string comment); - CommentStmt(const CommentStmt &stmt) = default; + ACCEPT(CommentStmt, ASTVisitor, comment); - std::string toString(int indent) const override; - ACCEPT(ASTVisitor); +private: + std::string comment; }; #undef ACCEPT } // namespace codon::ast -template -struct fmt::formatter< - T, std::enable_if_t::value, char>> - : fmt::ostream_formatter {}; - -template -struct fmt::formatter< - T, std::enable_if_t< - std::is_convertible>::value, char>> - : fmt::formatter { - template - auto format(const T &p, FormatContext &ctx) const -> decltype(ctx.out()) { - return fmt::format_to(ctx.out(), "{}", p ? p->toString() : ""); +namespace tser { +static void operator<<(codon::ast::Stmt *t, Archive &a) { + using S = codon::PolymorphicSerializer; + a.save(t != nullptr); + if (t) { + auto typ = t->dynamicNodeId(); + auto key = S::_serializers[(void *)typ]; + a.save(key); + S::save(key, t, a); } -}; +} +static void operator>>(codon::ast::Stmt *&t, Archive &a) { + using S = codon::PolymorphicSerializer; + bool empty = a.load(); + if (!empty) { + std::string key = a.load(); + S::load(key, t, a); + } else { + t = nullptr; + } +} +} // namespace tser diff --git a/codon/parser/ast/types/class.cpp b/codon/parser/ast/types/class.cpp index 048fe983..882e49d2 100644 --- a/codon/parser/ast/types/class.cpp +++ b/codon/parser/ast/types/class.cpp @@ -5,31 +5,148 @@ #include #include "codon/parser/ast/types/class.h" -#include "codon/parser/visitors/format/format.h" #include "codon/parser/visitors/typecheck/typecheck.h" namespace codon::ast::types { +std::string ClassType::Generic::debugString(char mode) const { + if (!isStatic && type->getStatic() && mode != 2) + return type->getStatic()->getNonStaticType()->debugString(mode); + return type->debugString(mode); +} + +std::string ClassType::Generic::realizedName() const { + if (!isStatic && type->getStatic()) + return type->getStatic()->getNonStaticType()->realizedName(); + return type->realizedName(); +} + +ClassType::Generic ClassType::Generic::generalize(int atLevel) { + TypePtr t = nullptr; + if (!isStatic && type && type->getStatic()) + t = type->getStatic()->getNonStaticType()->generalize(atLevel); + else if (type) + t = type->generalize(atLevel); + return ClassType::Generic(name, niceName, t, id, isStatic); +} + +ClassType::Generic +ClassType::Generic::instantiate(int atLevel, int *unboundCount, + std::unordered_map *cache) { + TypePtr t = nullptr; + if (!isStatic && type && type->getStatic()) + t = type->getStatic()->getNonStaticType()->instantiate(atLevel, unboundCount, + cache); + else if (type) + t = type->instantiate(atLevel, unboundCount, cache); + return ClassType::Generic(name, niceName, t, id, isStatic); +} + ClassType::ClassType(Cache *cache, std::string name, std::string niceName, std::vector generics, std::vector hiddenGenerics) : Type(cache), name(std::move(name)), niceName(std::move(niceName)), generics(std::move(generics)), hiddenGenerics(std::move(hiddenGenerics)) {} -ClassType::ClassType(const ClassTypePtr &base) - : Type(base), name(base->name), niceName(base->niceName), generics(base->generics), - hiddenGenerics(base->hiddenGenerics) {} +ClassType::ClassType(ClassType *base) + : Type(*base), name(base->name), niceName(base->niceName), generics(base->generics), + hiddenGenerics(base->hiddenGenerics), isTuple(base->isTuple) {} int ClassType::unify(Type *typ, Unification *us) { if (auto tc = typ->getClass()) { + if (name == "int" && tc->name == "Int") + return tc->unify(this, us); + if (tc->name == "int" && name == "Int") { + auto t64 = std::make_shared(cache, 64); + return generics[0].type->unify(t64.get(), us); + } + if (name == "unrealized_type" && tc->name == name) { + // instantiate + unify! + std::unordered_map genericCache; + auto l = generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache); + genericCache.clear(); + auto r = + tc->generics[0].type->instantiate(0, &(cache->unboundCount), &genericCache); + return l->unify(r.get(), us); + } + + int s1 = 3, s = 0; + if (name == "__NTuple__" && tc->name == name) { + auto n1 = generics[0].getType()->getIntStatic(); + auto n2 = tc->generics[0].getType()->getIntStatic(); + if (n1 && n2) { + auto t1 = generics[1].getType()->getClass(); + auto t2 = tc->generics[1].getType()->getClass(); + seqassert(t1 && t2, "bad ntuples"); + if (n1->value * t1->generics.size() != n2->value * t2->generics.size()) + return -1; + for (size_t i = 0; i < t1->generics.size() * n1->value; i++) { + if ((s = t1->generics[i % t1->generics.size()].getType()->unify( + t2->generics[i % t2->generics.size()].getType(), us)) == -1) + return -1; + s1 += s; + } + return s1; + } + } else if (tc->name == "__NTuple__") { + return tc->unify(this, us); + } else if (name == "__NTuple__" && tc->name == TYPE_TUPLE) { + auto n1 = generics[0].getType()->getIntStatic(); + if (!n1) { + auto n = tc->generics.size(); + auto tn = std::make_shared(cache, n); + // If we are unifying NT[N, T] and T[X, X, ...], we assume that N is number of + // X's + if (generics[0].type->unify(tn.get(), us) == -1) + return -1; + + auto tv = TypecheckVisitor(cache->typeCtx); + TypePtr tt; + if (n) { + tt = tv.instantiateType(tv.generateTuple(1), {tc->generics[0].getType()}); + for (size_t i = 1; i < tc->generics.size(); i++) { + if ((s = tt->getClass()->generics[0].getType()->unify( + tc->generics[i].getType(), us)) == -1) + return -1; + s1 += s; + } + } else { + tt = tv.instantiateType(tv.generateTuple(1)); + // tt = tv.instantiateType(tv.generateTuple(0)); + } + if (generics[1].type->unify(tt.get(), us) == -1) + return -1; + } else { + auto t1 = generics[1].getType()->getClass(); + seqassert(t1, "bad ntuples"); + if (n1->value * t1->generics.size() != tc->generics.size()) + return -1; + for (size_t i = 0; i < t1->generics.size() * n1->value; i++) { + if ((s = t1->generics[i % t1->generics.size()].getType()->unify( + tc->generics[i].getType(), us)) == -1) + return -1; + s1 += s; + } + } + return s1; + } + // Check names. - if (name != tc->name) + if (name != tc->name) { return -1; + } // Check generics. - int s1 = 3, s = 0; if (generics.size() != tc->generics.size()) return -1; for (int i = 0; i < generics.size(); i++) { - if ((s = generics[i].type->unify(tc->generics[i].type.get(), us)) == -1) + if ((s = generics[i].type->unify(tc->generics[i].type.get(), us)) == -1) { return -1; + } + s1 += s; + } + for (int i = 0; i < hiddenGenerics.size(); i++) { + if ((s = hiddenGenerics[i].type->unify(tc->hiddenGenerics[i].type.get(), us)) == + -1) { + return -1; + } s1 += s; } return s1; @@ -41,30 +158,42 @@ int ClassType::unify(Type *typ, Unification *us) { } TypePtr ClassType::generalize(int atLevel) { - auto g = generics, hg = hiddenGenerics; - for (auto &t : g) - t.type = t.type ? t.type->generalize(atLevel) : nullptr; - for (auto &t : hg) - t.type = t.type ? t.type->generalize(atLevel) : nullptr; + std::vector g, hg; + for (auto &t : generics) + g.push_back(t.generalize(atLevel)); + for (auto &t : hiddenGenerics) + hg.push_back(t.generalize(atLevel)); auto c = std::make_shared(cache, name, niceName, g, hg); + c->isTuple = isTuple; c->setSrcInfo(getSrcInfo()); return c; } TypePtr ClassType::instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) { - auto g = generics, hg = hiddenGenerics; - for (auto &t : g) - t.type = t.type ? t.type->instantiate(atLevel, unboundCount, cache) : nullptr; - for (auto &t : hg) - t.type = t.type ? t.type->instantiate(atLevel, unboundCount, cache) : nullptr; + std::vector g, hg; + for (auto &t : generics) + g.push_back(t.instantiate(atLevel, unboundCount, cache)); + for (auto &t : hiddenGenerics) + hg.push_back(t.instantiate(atLevel, unboundCount, cache)); auto c = std::make_shared(this->cache, name, niceName, g, hg); + c->isTuple = isTuple; c->setSrcInfo(getSrcInfo()); return c; } -std::vector ClassType::getUnbounds() const { - std::vector u; +bool ClassType::hasUnbounds(bool includeGenerics) const { + for (auto &t : generics) + if (t.type && t.type->hasUnbounds(includeGenerics)) + return true; + for (auto &t : hiddenGenerics) + if (t.type && t.type->hasUnbounds(includeGenerics)) + return true; + return false; +} + +std::vector ClassType::getUnbounds() const { + std::vector u; for (auto &t : generics) if (t.type) { auto tu = t.type->getUnbounds(); @@ -79,6 +208,12 @@ std::vector ClassType::getUnbounds() const { } bool ClassType::canRealize() const { + if (name == "type") { + if (!hasUnbounds()) + return true; // always true! + } + if (name == "unrealized_type") + return generics[0].type->getClass() != nullptr; return std::all_of(generics.begin(), generics.end(), [](auto &t) { return !t.type || t.type->canRealize(); }) && std::all_of(hiddenGenerics.begin(), hiddenGenerics.end(), @@ -86,21 +221,57 @@ bool ClassType::canRealize() const { } bool ClassType::isInstantiated() const { + if (name == "unrealized_type") + return generics[0].type->getClass() != nullptr; return std::all_of(generics.begin(), generics.end(), [](auto &t) { return !t.type || t.type->isInstantiated(); }) && std::all_of(hiddenGenerics.begin(), hiddenGenerics.end(), [](auto &t) { return !t.type || t.type->isInstantiated(); }); } +ClassType *ClassType::getHeterogenousTuple() { + seqassert(canRealize(), "{} not realizable", toString()); + seqassert(name == TYPE_TUPLE, "{} not a tuple", toString()); + if (generics.size() > 1) { + std::string first = generics[0].type->realizedName(); + for (int i = 1; i < generics.size(); i++) + if (generics[i].type->realizedName() != first) + return getClass(); + } + return nullptr; +} + std::string ClassType::debugString(char mode) const { + if (name == "Partial" && generics[3].type->getClass()) { + std::vector as; + auto known = getPartialMask(); + auto func = getPartialFunc(); + for (int i = 0, gi = 0; i < known.size(); i++) { + if ((*func->ast)[i].isValue()) + as.emplace_back( + known[i] && generics[1].type->getClass() + ? generics[1].type->getClass()->generics[gi++].debugString(mode) + : "..."); + } + auto fnname = func->ast->getName(); + if (mode == 0) { + fnname = cache->rev(func->ast->getName()); + } else if (mode == 2) { + fnname = func->debugString(mode); + } + return fmt::format("{}[{}{}]", fnname, join(as, ","), + mode == 2 ? fmt::format(";{};{}", generics[1].debugString(mode), + generics[2].debugString(mode)) + : ""); + } std::vector gs; for (auto &a : generics) if (!a.name.empty()) - gs.push_back(a.type->debugString(mode)); + gs.push_back(a.debugString(mode)); if ((mode == 2) && !hiddenGenerics.empty()) { for (auto &a : hiddenGenerics) if (!a.name.empty()) - gs.push_back("-" + a.type->debugString(mode)); + gs.push_back("-" + a.debugString(mode)); } // Special formatting for Functions and Tuples auto n = mode == 0 ? niceName : name; @@ -111,205 +282,49 @@ std::string ClassType::realizedName() const { if (!_rn.empty()) return _rn; + std::string s; std::vector gs; - for (auto &a : generics) - if (!a.name.empty()) - gs.push_back(a.type->realizedName()); - std::string s = join(gs, ","); - if (canRealize()) - const_cast(this)->_rn = - fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s)); - return _rn; -} - -std::string ClassType::realizedTypeName() const { - return this->ClassType::realizedName(); -} - -RecordType::RecordType(Cache *cache, std::string name, std::string niceName, - std::vector generics, std::vector args, - bool noTuple, const std::shared_ptr &repeats) - : ClassType(cache, std::move(name), std::move(niceName), std::move(generics)), - args(std::move(args)), noTuple(false), repeats(repeats) {} - -RecordType::RecordType(const ClassTypePtr &base, std::vector args, - bool noTuple, const std::shared_ptr &repeats) - : ClassType(base), args(std::move(args)), noTuple(noTuple), repeats(repeats) {} - -int RecordType::unify(Type *typ, Unification *us) { - if (auto tr = typ->getRecord()) { - // Handle int <-> Int[64] - if (name == "int" && tr->name == "Int") - return tr->unify(this, us); - if (tr->name == "int" && name == "Int") { - auto t64 = std::make_shared(cache, 64); - return generics[0].type->unify(t64.get(), us); - } - - // TODO: we now support very limited unification strategy where repetitions must - // match. We should expand this later on... - if (repeats || tr->repeats) { - if (!repeats && tr->repeats) { - auto n = std::make_shared(cache, args.size()); - if (tr->repeats->unify(n.get(), us) == -1) - return -1; - } else if (!tr->repeats) { - auto n = std::make_shared(cache, tr->args.size()); - if (repeats->unify(n.get(), us) == -1) - return -1; - } else { - if (repeats->unify(tr->repeats.get(), us) == -1) - return -1; - } - } - if (getRepeats() != -1) - flatten(); - if (tr->getRepeats() != -1) - tr->flatten(); - - int s1 = 2, s = 0; - if (args.size() != tr->args.size()) - return -1; - for (int i = 0; i < args.size(); i++) { - if ((s = args[i]->unify(tr->args[i].get(), us)) != -1) - s1 += s; - else - return -1; - } - // Handle Tuple<->@tuple: when unifying tuples, only record members matter. - if (name == TYPE_TUPLE || tr->name == TYPE_TUPLE) { - if (!args.empty() || (!noTuple && !tr->noTuple)) // prevent POD<->() unification - return s1 + int(name == tr->name); - else - return -1; - } - return this->ClassType::unify(tr.get(), us); - } else if (auto t = typ->getLink()) { - return t->unify(this, us); + if (name == "Partial") { + gs.push_back(generics[3].realizedName()); + for (size_t i = 0; i < generics.size() - 1; i++) + gs.push_back(generics[i].realizedName()); + } else if (name == "Union" && generics[0].type->getClass()) { + std::set gss; + for (auto &a : generics[0].type->getClass()->generics) + gss.insert(a.realizedName()); + gs = {join(gss, " | ")}; } else { - return -1; - } -} - -TypePtr RecordType::generalize(int atLevel) { - auto c = std::static_pointer_cast(this->ClassType::generalize(atLevel)); - auto a = args; - for (auto &t : a) - t = t->generalize(atLevel); - auto r = repeats ? repeats->generalize(atLevel)->getStatic() : nullptr; - return std::make_shared(c, a, noTuple, r); -} - -TypePtr RecordType::instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) { - auto c = std::static_pointer_cast( - this->ClassType::instantiate(atLevel, unboundCount, cache)); - auto a = args; - for (auto &t : a) - t = t->instantiate(atLevel, unboundCount, cache); - auto r = repeats ? repeats->instantiate(atLevel, unboundCount, cache)->getStatic() - : nullptr; - return std::make_shared(c, a, noTuple, r); -} - -std::vector RecordType::getUnbounds() const { - std::vector u; - if (repeats) { - auto tu = repeats->getUnbounds(); - u.insert(u.begin(), tu.begin(), tu.end()); - } - for (auto &a : args) { - auto tu = a->getUnbounds(); - u.insert(u.begin(), tu.begin(), tu.end()); - } - auto tu = this->ClassType::getUnbounds(); - u.insert(u.begin(), tu.begin(), tu.end()); - return u; -} - -bool RecordType::canRealize() const { - return getRepeats() >= 0 && - std::all_of(args.begin(), args.end(), - [](auto &a) { return a->canRealize(); }) && - this->ClassType::canRealize(); -} - -bool RecordType::isInstantiated() const { - return (!repeats || repeats->isInstantiated()) && - std::all_of(args.begin(), args.end(), - [](auto &a) { return a->isInstantiated(); }) && - this->ClassType::isInstantiated(); -} - -std::string RecordType::realizedName() const { - if (!_rn.empty()) - return _rn; - if (name == TYPE_TUPLE) { - std::vector gs; - auto n = getRepeats(); - if (n == -1) - gs.push_back(repeats->realizedName()); - for (int i = 0; i < std::max(n, int64_t(0)); i++) - for (auto &a : args) - gs.push_back(a->realizedName()); - std::string s = join(gs, ","); - if (canRealize()) - const_cast(this)->_rn = - fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s)); - return _rn; - } - return ClassType::realizedName(); -} - -std::string RecordType::debugString(char mode) const { - if (name == TYPE_TUPLE) { - std::vector gs; - auto n = getRepeats(); - if (n == -1) - gs.push_back(repeats->debugString(mode)); - for (int i = 0; i < std::max(n, int64_t(0)); i++) - for (auto &a : args) - gs.push_back(a->debugString(mode)); - return fmt::format("{}{}", name, - gs.empty() ? "" : fmt::format("[{}]", join(gs, ","))); - } else { - return fmt::format("{}{}", repeats ? repeats->debugString(mode) + "," : "", - this->ClassType::debugString(mode)); + for (auto &a : generics) + if (!a.name.empty()) + gs.push_back(a.realizedName()); } + s = join(gs, ","); + s = fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s)); + return s; } -std::string RecordType::realizedTypeName() const { return realizedName(); } - -std::shared_ptr RecordType::getHeterogenousTuple() { - seqassert(canRealize(), "{} not realizable", toString()); - if (args.size() > 1) { - std::string first = args[0]->realizedName(); - for (int i = 1; i < args.size(); i++) - if (args[i]->realizedName() != first) - return getRecord(); - } - return nullptr; +FuncType *ClassType::getPartialFunc() const { + seqassert(name == "Partial", "not a partial"); + auto n = generics[3].type->getClass()->generics[0].type; + seqassert(n->getFunc(), "not a partial func"); + return n->getFunc(); } -/// Returns -1 if the type cannot be realized yet -int64_t RecordType::getRepeats() const { - if (!repeats) - return 1; - if (repeats->canRealize()) - return std::max(repeats->evaluate().getInt(), int64_t(0)); - return -1; +std::vector ClassType::getPartialMask() const { + seqassert(name == "Partial", "not a partial"); + auto n = generics[0].type->getStrStatic()->value; + std::vector r(n.size(), 0); + for (size_t i = 0; i < n.size(); i++) + if (n[i] == '1') + r[i] = 1; + return r; } -void RecordType::flatten() { - auto n = getRepeats(); - seqassert(n >= 0, "bad call to flatten"); - - auto a = args; - args.clear(); - for (int64_t i = 0; i < n; i++) - args.insert(args.end(), a.begin(), a.end()); - - repeats = nullptr; +bool ClassType::isPartialEmpty() const { + auto a = generics[1].type->getClass(); + auto ka = generics[2].type->getClass(); + return a->generics.size() == 1 && a->generics[0].type->getClass()->generics.empty() && + ka->generics[1].type->getClass()->generics.empty(); } } // namespace codon::ast::types diff --git a/codon/parser/ast/types/class.h b/codon/parser/ast/types/class.h index 96125496..c20ac1c1 100644 --- a/codon/parser/ast/types/class.h +++ b/codon/parser/ast/types/class.h @@ -11,6 +11,8 @@ namespace codon::ast::types { +struct FuncType; + /** * A generic class reference type. All Seq types inherit from this class. */ @@ -28,10 +30,19 @@ struct ClassType : public Type { int id; // Pointer to realized type (or generic LinkType). TypePtr type; - - Generic(std::string name, std::string niceName, TypePtr type, int id) - : name(std::move(name)), niceName(std::move(niceName)), id(id), - type(std::move(type)) {} + // Set if this is a static generic + char isStatic; + + Generic(std::string name, std::string niceName, TypePtr type, int id, char isStatic) + : name(std::move(name)), niceName(std::move(niceName)), type(std::move(type)), + id(id), isStatic(isStatic) {} + + types::Type *getType() const { return type.get(); } + Generic generalize(int atLevel); + Generic instantiate(int atLevel, int *unboundCount, + std::unordered_map *cache); + std::string debugString(char mode) const; + std::string realizedName() const; }; /// Canonical type name. @@ -43,12 +54,13 @@ struct ClassType : public Type { std::vector hiddenGenerics; + bool isTuple = false; std::string _rn; explicit ClassType(Cache *cache, std::string name, std::string niceName, std::vector generics = {}, std::vector hiddenGenerics = {}); - explicit ClassType(const std::shared_ptr &base); + explicit ClassType(ClassType *base); public: int unify(Type *typ, Unification *undo) override; @@ -57,58 +69,24 @@ struct ClassType : public Type { std::unordered_map *cache) override; public: - std::vector getUnbounds() const override; + bool hasUnbounds(bool = false) const override; + std::vector getUnbounds() const override; bool canRealize() const override; bool isInstantiated() const override; std::string debugString(char mode) const override; std::string realizedName() const override; - /// Convenience function to get the name of realized type - /// (needed if a subclass realizes something else as well). - virtual std::string realizedTypeName() const; - std::shared_ptr getClass() override { - return std::static_pointer_cast(shared_from_this()); - } -}; -using ClassTypePtr = std::shared_ptr; + ClassType *getClass() override { return this; } + ClassType *getPartial() override { return name == "Partial" ? getClass() : nullptr; } + bool isRecord() const { return isTuple; } -/** - * A generic class tuple (record) type. All Seq tuples inherit from this class. - */ -struct RecordType : public ClassType { - /// List of tuple arguments. - std::vector args; - bool noTuple; - std::shared_ptr repeats = nullptr; - - explicit RecordType( - Cache *cache, std::string name, std::string niceName, - std::vector generics = std::vector(), - std::vector args = std::vector(), bool noTuple = false, - const std::shared_ptr &repeats = nullptr); - RecordType(const ClassTypePtr &base, std::vector args, bool noTuple = false, - const std::shared_ptr &repeats = nullptr); + size_t size() const { return generics.size(); } + Type *operator[](int i) const { return generics[i].getType(); } public: - int unify(Type *typ, Unification *undo) override; - TypePtr generalize(int atLevel) override; - TypePtr instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) override; - -public: - std::vector getUnbounds() const override; - bool canRealize() const override; - bool isInstantiated() const override; - std::string debugString(char mode) const override; - std::string realizedName() const override; - std::string realizedTypeName() const override; - - std::shared_ptr getRecord() override { - return std::static_pointer_cast(shared_from_this()); - } - std::shared_ptr getHeterogenousTuple() override; - - int64_t getRepeats() const; - void flatten(); + ClassType *getHeterogenousTuple() override; + FuncType *getPartialFunc() const; + std::vector getPartialMask() const; + bool isPartialEmpty() const; }; } // namespace codon::ast::types diff --git a/codon/parser/ast/types/function.cpp b/codon/parser/ast/types/function.cpp index 4b681258..eb3673b7 100644 --- a/codon/parser/ast/types/function.cpp +++ b/codon/parser/ast/types/function.cpp @@ -10,10 +10,10 @@ namespace codon::ast::types { -FuncType::FuncType(const std::shared_ptr &baseType, FunctionStmt *ast, +FuncType::FuncType(ClassType *baseType, FunctionStmt *ast, size_t index, std::vector funcGenerics, TypePtr funcParent) - : RecordType(*baseType), ast(ast), funcGenerics(std::move(funcGenerics)), - funcParent(std::move(funcParent)) {} + : ClassType(baseType), ast(ast), index(index), + funcGenerics(std::move(funcGenerics)), funcParent(std::move(funcParent)) {} int FuncType::unify(Type *typ, Unification *us) { if (this == typ) @@ -21,52 +21,83 @@ int FuncType::unify(Type *typ, Unification *us) { int s1 = 2, s = 0; if (auto t = typ->getFunc()) { // Check if names and parents match. - if (ast->name != t->ast->name || (bool(funcParent) ^ bool(t->funcParent))) + if (ast->getName() != t->ast->getName() || index != t->index || + (bool(funcParent) ^ bool(t->funcParent))) return -1; - if (funcParent && (s = funcParent->unify(t->funcParent.get(), us)) == -1) + if (funcParent && (s = funcParent->unify(t->funcParent.get(), us)) == -1) { return -1; + } s1 += s; // Check if function generics match. seqassert(funcGenerics.size() == t->funcGenerics.size(), - "generic size mismatch for {}", ast->name); + "generic size mismatch for {}", ast->getName()); for (int i = 0; i < funcGenerics.size(); i++) { if ((s = funcGenerics[i].type->unify(t->funcGenerics[i].type.get(), us)) == -1) return -1; s1 += s; } } - s = this->RecordType::unify(typ, us); + s = this->ClassType::unify(typ, us); return s == -1 ? s : s1 + s; } TypePtr FuncType::generalize(int atLevel) { - auto g = funcGenerics; - for (auto &t : g) - t.type = t.type ? t.type->generalize(atLevel) : nullptr; + std::vector fg; + for (auto &t : funcGenerics) + fg.push_back(t.generalize(atLevel)); auto p = funcParent ? funcParent->generalize(atLevel) : nullptr; - return std::make_shared( - std::static_pointer_cast(this->RecordType::generalize(atLevel)), ast, - g, p); + + auto r = std::static_pointer_cast(this->ClassType::generalize(atLevel)); + + // // Fix statics + // auto &at = r->generics[0].getType()->getClass()->generics; + // for (size_t i = 0; i < at.size(); i++) { + // bool isStatic = ast && getStaticGeneric((*ast)[i].getType()); + // if (!isStatic && at[i].getType() && at[i].getType()->getStatic()) + // at[i].type = + // at[i].getType()->getStatic()->getNonStaticType()->generalize(atLevel); + // } + // auto rt = r->generics[0].getType(); + // bool isStatic = ast && getStaticGeneric(ast->getReturn()); + // if (!isStatic && rt && rt->getStatic()) + // r->generics[0].type = rt->getStatic()->getNonStaticType()->generalize(atLevel); + + auto t = std::make_shared(r->getClass(), ast, index, fg, p); + return t; } TypePtr FuncType::instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) { - auto g = funcGenerics; - for (auto &t : g) - if (t.type) { - t.type = t.type->instantiate(atLevel, unboundCount, cache); + std::vector fg; + for (auto &t : funcGenerics) { + fg.push_back(t.instantiate(atLevel, unboundCount, cache)); + if (fg.back().type) { if (cache && cache->find(t.id) == cache->end()) - (*cache)[t.id] = t.type; + (*cache)[t.id] = fg.back().type; } + } auto p = funcParent ? funcParent->instantiate(atLevel, unboundCount, cache) : nullptr; - return std::make_shared( - std::static_pointer_cast( - this->RecordType::instantiate(atLevel, unboundCount, cache)), - ast, g, p); + auto r = std::static_pointer_cast( + this->ClassType::instantiate(atLevel, unboundCount, cache)); + + auto t = std::make_shared(r->getClass(), ast, index, fg, p); + return t; +} + +bool FuncType::hasUnbounds(bool includeGenerics) const { + for (auto &t : funcGenerics) + if (t.type && t.type->hasUnbounds(includeGenerics)) + return true; + if (funcParent && funcParent->hasUnbounds(includeGenerics)) + return true; + for (const auto &a : *this) + if (a.getType()->hasUnbounds(includeGenerics)) + return true; + return getRetType()->hasUnbounds(includeGenerics); } -std::vector FuncType::getUnbounds() const { - std::vector u; +std::vector FuncType::getUnbounds() const { + std::vector u; for (auto &t : funcGenerics) if (t.type) { auto tu = t.type->getUnbounds(); @@ -77,8 +108,8 @@ std::vector FuncType::getUnbounds() const { u.insert(u.begin(), tu.begin(), tu.end()); } // Important: return type unbounds are not important, so skip them. - for (auto &a : getArgTypes()) { - auto tu = a->getUnbounds(); + for (const auto &a : *this) { + auto tu = a.getType()->getUnbounds(); u.insert(u.begin(), tu.begin(), tu.end()); } return u; @@ -86,11 +117,10 @@ std::vector FuncType::getUnbounds() const { bool FuncType::canRealize() const { // Important: return type does not have to be realized. - bool skipSelf = ast->hasAttr(Attr::RealizeWithoutSelf); + bool skipSelf = ast->hasAttribute(Attr::RealizeWithoutSelf); - auto args = getArgTypes(); - for (int ai = skipSelf; ai < args.size(); ai++) - if (!args[ai]->getFunc() && !args[ai]->canRealize()) + for (int ai = skipSelf; ai < size(); ai++) + if (!(*this)[ai]->getFunc() && !(*this)[ai]->canRealize()) return false; bool generics = std::all_of(funcGenerics.begin(), funcGenerics.end(), [](auto &a) { return !a.type || a.type->canRealize(); }); @@ -99,10 +129,6 @@ bool FuncType::canRealize() const { return generics; } -std::string FuncType::realizedTypeName() const { - return this->ClassType::realizedName(); -} - bool FuncType::isInstantiated() const { TypePtr removed = nullptr; auto retType = getRetType(); @@ -113,7 +139,7 @@ bool FuncType::isInstantiated() const { auto res = std::all_of(funcGenerics.begin(), funcGenerics.end(), [](auto &a) { return !a.type || a.type->isInstantiated(); }) && (!funcParent || funcParent->isInstantiated()) && - this->RecordType::isInstantiated(); + this->ClassType::isInstantiated(); if (removed) retType->getFunc()->funcParent = removed; return res; @@ -123,23 +149,37 @@ std::string FuncType::debugString(char mode) const { std::vector gs; for (auto &a : funcGenerics) if (!a.name.empty()) - gs.push_back(a.type->debugString(mode)); + gs.push_back(mode < 2 + ? a.type->debugString(mode) + : fmt::format("{}={}", a.niceName, a.type->debugString(mode))); std::string s = join(gs, ","); std::vector as; // Important: return type does not have to be realized. if (mode == 2) - as.push_back(getRetType()->debugString(mode)); - for (auto &a : getArgTypes()) - as.push_back(a->debugString(mode)); + as.push_back(fmt::format("RET={}", getRetType()->debugString(mode))); + + if (mode < 2 || !ast) { + for (const auto &a : *this) { + as.push_back(a.debugString(mode)); + } + } else { + for (size_t i = 0, si = 0; i < ast->size(); i++) { + if ((*ast)[i].isGeneric()) + continue; + as.push_back(fmt::format("{}={}", (*ast)[i].getName(), (*this)[si++]->debugString(mode))); + } + } std::string a = join(as, ","); - s = s.empty() ? a : join(std::vector{a, s}, ","); + s = s.empty() ? a : join(std::vector{s, a}, ";"); - auto fnname = ast->name; + auto fnname = ast->getName(); if (mode == 0) { - fnname = cache->rev(ast->name); - // if (funcParent) - // fnname = fmt::format("{}.{}", funcParent->debugString(mode), fnname); + fnname = cache->rev(ast->getName()); } + if (mode && index) + fnname += fmt::format("/{}", index); + if (mode == 2 && funcParent) + s += fmt::format(";{}", funcParent->debugString(mode)); return fmt::format("{}{}", fnname, s.empty() ? "" : fmt::format("[{}]", s)); } @@ -147,73 +187,39 @@ std::string FuncType::realizedName() const { std::vector gs; for (auto &a : funcGenerics) if (!a.name.empty()) - gs.push_back(a.type->realizedName()); + gs.push_back(a.realizedName()); std::string s = join(gs, ","); std::vector as; // Important: return type does not have to be realized. - for (auto &a : getArgTypes()) - as.push_back(a->getFunc() ? a->getFunc()->realizedName() : a->realizedName()); + for (const auto &a : *this) + as.push_back(a.getType()->getFunc() ? a.getType()->getFunc()->realizedName() + : a.realizedName()); std::string a = join(as, ","); s = s.empty() ? a : join(std::vector{a, s}, ","); - return fmt::format("{}{}{}", funcParent ? funcParent->realizedName() + ":" : "", - ast->name, s.empty() ? "" : fmt::format("[{}]", s)); + return fmt::format("{}{}{}{}", funcParent ? funcParent->realizedName() + ":" : "", + ast->getName(), index ? fmt::format("/{}", index) : "", + s.empty() ? "" : fmt::format("[{}]", s)); } -PartialType::PartialType(const std::shared_ptr &baseType, - std::shared_ptr func, std::vector known) - : RecordType(*baseType), func(std::move(func)), known(std::move(known)) {} +Type *FuncType::getRetType() const { return generics[1].type.get(); } -int PartialType::unify(Type *typ, Unification *us) { - return this->RecordType::unify(typ, us); -} +std::string FuncType::getFuncName() const { return ast->getName(); } -TypePtr PartialType::generalize(int atLevel) { - return std::make_shared( - std::static_pointer_cast(this->RecordType::generalize(atLevel)), func, - known); +Type *FuncType::operator[](int i) const { + return generics[0].type->getClass()->generics[i].getType(); } -TypePtr PartialType::instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) { - auto rec = std::static_pointer_cast( - this->RecordType::instantiate(atLevel, unboundCount, cache)); - return std::make_shared(rec, func, known); +std::vector::iterator FuncType::begin() const { + return generics[0].type->getClass()->generics.begin(); } -std::string PartialType::debugString(char mode) const { - std::vector gs; - for (auto &a : generics) - if (!a.name.empty()) - gs.push_back(a.type->debugString(mode)); - std::vector as; - int i = 0, gi = 0; - for (; i < known.size(); i++) - if (func->ast->args[i].status == Param::Normal) { - if (!known[i]) - as.emplace_back("..."); - else - as.emplace_back(gs[gi++]); - } - auto fnname = func->ast->name; - if (mode == 0) { - fnname = cache->rev(func->ast->name); - // if (func->funcParent) - // fnname = fmt::format("{}.{}", func->funcParent->debugString(mode), fnname); - } else if (mode == 2) { - fnname = func->debugString(mode); - } - return fmt::format("{}[{}{}]", fnname, join(as, ","), - mode == 2 ? fmt::format(";{}", join(gs, ",")) : ""); +std::vector::iterator FuncType::end() const { + return generics[0].type->getClass()->generics.begin() + + generics[0].type->getClass()->generics.size(); } -std::string PartialType::realizedName() const { - std::vector gs; - gs.push_back(func->ast->name); - for (auto &a : generics) - if (!a.name.empty()) - gs.push_back(a.type->realizedName()); - std::string s = join(gs, ","); - return fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s)); -} +size_t FuncType::size() const { return generics[0].type->getClass()->generics.size(); } + +bool FuncType::empty() const { return generics[0].type->getClass()->generics.empty(); } } // namespace codon::ast::types diff --git a/codon/parser/ast/types/function.h b/codon/parser/ast/types/function.h index 4d36d1a1..eb2ab8e8 100644 --- a/codon/parser/ast/types/function.h +++ b/codon/parser/ast/types/function.h @@ -18,13 +18,15 @@ namespace codon::ast::types { /** * A generic type that represents a Seq function instantiation. - * It inherits RecordType that realizes Callable[...]. + * It inherits ClassType that realizes Callable[...]. * * ⚠️ This is not a function pointer (Function[...]) type. */ -struct FuncType : public RecordType { +struct FuncType : public ClassType { /// Canonical AST node. FunctionStmt *ast; + /// Function capture index. + size_t index; /// Function generics (e.g. T in def foo[T](...)). std::vector funcGenerics; /// Enclosing class or a function. @@ -32,7 +34,7 @@ struct FuncType : public RecordType { public: FuncType( - const std::shared_ptr &baseType, FunctionStmt *ast, + ClassType *baseType, FunctionStmt *ast, size_t index = 0, std::vector funcGenerics = std::vector(), TypePtr funcParent = nullptr); @@ -43,55 +45,24 @@ struct FuncType : public RecordType { std::unordered_map *cache) override; public: - std::vector getUnbounds() const override; + bool hasUnbounds(bool = false) const override; + std::vector getUnbounds() const override; bool canRealize() const override; bool isInstantiated() const override; std::string debugString(char mode) const override; std::string realizedName() const override; - std::string realizedTypeName() const override; - std::shared_ptr getFunc() override { - return std::static_pointer_cast(shared_from_this()); - } + FuncType *getFunc() override { return this; } - std::vector &getArgTypes() const { - return generics[0].type->getRecord()->args; - } - TypePtr getRetType() const { return generics[1].type; } -}; -using FuncTypePtr = std::shared_ptr; - -/** - * A generic type that represents a partial Seq function instantiation. - * It inherits RecordType that realizes Tuple[...]. - * - * Note: partials only work on Seq functions. Function pointer partials - * will become a partials of Function.__call__ Seq function. - */ -struct PartialType : public RecordType { - /// Seq function that is being partialized. Always generic (not instantiated). - FuncTypePtr func; - /// Arguments that are already provided (1 for known argument, 0 for expecting). - std::vector known; - -public: - PartialType(const std::shared_ptr &baseType, - std::shared_ptr func, std::vector known); - -public: - int unify(Type *typ, Unification *us) override; - TypePtr generalize(int atLevel) override; - TypePtr instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) override; + Type *getRetType() const; + Type *getParentType() const { return funcParent.get(); } + std::string getFuncName() const; - std::string debugString(char mode) const override; - std::string realizedName() const override; - -public: - std::shared_ptr getPartial() override { - return std::static_pointer_cast(shared_from_this()); - } + Type *operator[](int i) const; + std::vector::iterator begin() const; + std::vector::iterator end() const; + size_t size() const; + bool empty() const; }; -using PartialTypePtr = std::shared_ptr; } // namespace codon::ast::types diff --git a/codon/parser/ast/types/link.cpp b/codon/parser/ast/types/link.cpp index 836572b1..b2183cad 100644 --- a/codon/parser/ast/types/link.cpp +++ b/codon/parser/ast/types/link.cpp @@ -31,46 +31,52 @@ int LinkType::unify(Type *typ, Unification *undo) { if (kind == Link) { // Case 1: Just follow the link return type->unify(typ, undo); - } else if (kind == Generic) { - // Case 2: Generic types cannot be unified. - return -1; } else { // Case 3: Unbound unification - if (isStaticType() != typ->isStaticType()) - return -1; - if (auto ts = typ->getStatic()) { - if (ts->expr->getId()) - return unify(ts->generics[0].type.get(), undo); + if (isStaticType() != typ->isStaticType()) { + if (!isStaticType()) { + // other one is; move this to non-static equivalent + if (undo) { + undo->statics.push_back(shared_from_this()); + isStatic = typ->isStaticType(); + } + } else { + return -1; + } } if (auto t = typ->getLink()) { if (t->kind == Link) return t->type->unify(this, undo); - else if (t->kind == Generic) + if (kind != t->kind) return -1; - else { - if (id == t->id) { - // Identical unbound types get a score of 1 - return 1; - } else if (id < t->id) { - // Always merge a newer type into the older type (e.g. keep the types with - // lower IDs around). - return t->unify(this, undo); - } - } + // Identical unbound types get a score of 1 + if (id == t->id) + return 1; + // Generics must have matching IDs unless we are doing non-destructive unification + if (kind == Generic) + return undo ? -1 : 1; + // Always merge a newer type into the older type (e.g. keep the types with + // lower IDs around). + if (id < t->id) + return t->unify(this, undo); + } else if (kind == Generic) { + return -1; } + // Generics must be handled by now; only unbounds can be unified! + seqassertn(kind == Unbound, "not an unbound"); + // Ensure that we do not have recursive unification! (e.g. unify ?1 with list[?1]) if (occurs(typ, undo)) return -1; - + // Handle traits if (trait && trait->unify(typ, undo) == -1) return -1; - // ⚠️ Unification: destructive part. seqassert(!type, "type has been already unified or is in inconsistent state"); if (undo) { LOG_TYPECHECK("[unify] {} := {}", id, typ->debugString(2)); // Link current type to typ and ensure that this modification is recorded in undo. - undo->linked.push_back(this); + undo->linked.push_back(shared_from_this()); kind = Link; seqassert(!typ->getLink() || typ->getLink()->kind != Unbound || typ->getLink()->id <= id, @@ -78,7 +84,7 @@ int LinkType::unify(Type *typ, Unification *undo) { type = typ->follow(); if (auto t = type->getLink()) if (trait && t->kind == Unbound && !t->trait) { - undo->traits.push_back(t.get()); + undo->traits.push_back(t->shared_from_this()); t->trait = trait; } } @@ -134,14 +140,24 @@ TypePtr LinkType::follow() { return shared_from_this(); } -std::vector LinkType::getUnbounds() const { +std::vector LinkType::getUnbounds() const { if (kind == Unbound) - return {std::const_pointer_cast(shared_from_this())}; + return {(Type *)this}; else if (kind == Link) return type->getUnbounds(); return {}; } +bool LinkType::hasUnbounds(bool includeGenerics) const { + if (kind == Unbound) + return true; + if (includeGenerics && kind == Generic) + return true; + if (kind == Link) + return type->hasUnbounds(includeGenerics); + return false; +} + bool LinkType::canRealize() const { if (kind != Link) return false; @@ -154,55 +170,59 @@ bool LinkType::isInstantiated() const { return kind == Link && type->isInstantia std::string LinkType::debugString(char mode) const { if (kind == Unbound || kind == Generic) { if (mode == 2) { - return fmt::format("{}{}{}{}", genericName.empty() ? "" : genericName + ":", + return fmt::format("{}{}{}{}{}", genericName.empty() ? "" : genericName + ":", kind == Unbound ? '?' : '#', id, - trait ? ":" + trait->debugString(mode) : ""); - } - if (trait) + trait ? ":" + trait->debugString(mode) : "", + isStatic ? fmt::format(":S{}", int(isStatic)) : ""); + } else if (trait) { return trait->debugString(mode); + } return (genericName.empty() ? (mode ? "?" : "") : genericName); } + // if (mode == 2) + // return ">" + type->debugString(mode); return type->debugString(mode); } std::string LinkType::realizedName() const { - if (kind == Unbound || kind == Generic) - return "?"; + if (kind == Unbound) + // return "?"; + return fmt::format("#{}", genericName); + if (kind == Generic) + return fmt::format("#{}", genericName); seqassert(kind == Link, "unexpected generic link"); return type->realizedName(); } -std::shared_ptr LinkType::getLink() { - return std::static_pointer_cast(shared_from_this()); -} +LinkType *LinkType::getLink() { return this; } -std::shared_ptr LinkType::getFunc() { - return kind == Link ? type->getFunc() : nullptr; -} +FuncType *LinkType::getFunc() { return kind == Link ? type->getFunc() : nullptr; } -std::shared_ptr LinkType::getPartial() { +ClassType *LinkType::getPartial() { return kind == Link ? type->getPartial() : nullptr; } -std::shared_ptr LinkType::getClass() { - return kind == Link ? type->getClass() : nullptr; -} +ClassType *LinkType::getClass() { return kind == Link ? type->getClass() : nullptr; } + +StaticType *LinkType::getStatic() { return kind == Link ? type->getStatic() : nullptr; } -std::shared_ptr LinkType::getRecord() { - return kind == Link ? type->getRecord() : nullptr; +IntStaticType *LinkType::getIntStatic() { + return kind == Link ? type->getIntStatic() : nullptr; } -std::shared_ptr LinkType::getStatic() { - return kind == Link ? type->getStatic() : nullptr; +StrStaticType *LinkType::getStrStatic() { + return kind == Link ? type->getStrStatic() : nullptr; } -std::shared_ptr LinkType::getUnion() { - return kind == Link ? type->getUnion() : nullptr; +BoolStaticType *LinkType::getBoolStatic() { + return kind == Link ? type->getBoolStatic() : nullptr; } -std::shared_ptr LinkType::getUnbound() { +UnionType *LinkType::getUnion() { return kind == Link ? type->getUnion() : nullptr; } + +LinkType *LinkType::getUnbound() { if (kind == Unbound) - return std::static_pointer_cast(shared_from_this()); + return this; if (kind == Link) return type->getUnbound(); return nullptr; @@ -216,7 +236,7 @@ bool LinkType::occurs(Type *typ, Type::Unification *undo) { if (tl->trait && occurs(tl->trait.get(), undo)) return true; if (undo && tl->level > level) { - undo->leveled.emplace_back(make_pair(tl.get(), tl->level)); + undo->leveled.emplace_back(tl->shared_from_this(), tl->level); tl->level = level; } return false; @@ -226,19 +246,12 @@ bool LinkType::occurs(Type *typ, Type::Unification *undo) { return false; } } else if (auto ts = typ->getStatic()) { - for (auto &g : ts->generics) - if (g.type && occurs(g.type.get(), undo)) - return true; return false; } if (auto tc = typ->getClass()) { for (auto &g : tc->generics) if (g.type && occurs(g.type.get(), undo)) return true; - if (auto tr = typ->getRecord()) - for (auto &t : tr->args) - if (occurs(t.get(), undo)) - return true; return false; } else { return false; diff --git a/codon/parser/ast/types/link.h b/codon/parser/ast/types/link.h index 862933ca..2c2a6845 100644 --- a/codon/parser/ast/types/link.h +++ b/codon/parser/ast/types/link.h @@ -45,20 +45,23 @@ struct LinkType : public Type { public: TypePtr follow() override; - std::vector getUnbounds() const override; + bool hasUnbounds(bool = false) const override; + std::vector getUnbounds() const override; bool canRealize() const override; bool isInstantiated() const override; std::string debugString(char mode) const override; std::string realizedName() const override; - std::shared_ptr getLink() override; - std::shared_ptr getFunc() override; - std::shared_ptr getPartial() override; - std::shared_ptr getClass() override; - std::shared_ptr getRecord() override; - std::shared_ptr getStatic() override; - std::shared_ptr getUnion() override; - std::shared_ptr getUnbound() override; + LinkType *getLink() override; + FuncType *getFunc() override; + ClassType *getPartial() override; + ClassType *getClass() override; + StaticType *getStatic() override; + IntStaticType *getIntStatic() override; + StrStaticType *getStrStatic() override; + BoolStaticType *getBoolStatic() override; + UnionType *getUnion() override; + LinkType *getUnbound() override; private: /// Checks if a current (unbound) type occurs within a given type. diff --git a/codon/parser/ast/types/static.cpp b/codon/parser/ast/types/static.cpp index b0d4532b..999653f7 100644 --- a/codon/parser/ast/types/static.cpp +++ b/codon/parser/ast/types/static.cpp @@ -13,167 +13,92 @@ namespace codon::ast::types { -StaticType::StaticType(Cache *cache, const std::shared_ptr &e) - : Type(cache), expr(e->clone()) { - if (!expr->isStatic() || !expr->staticValue.evaluated) { - std::unordered_set seen; - parseExpr(expr, seen); - } +StaticType::StaticType(Cache *cache, const std::string &typeName) + : ClassType(cache, typeName, typeName) {} + +TypePtr StaticType::generalize(int atLevel) { return shared_from_this(); } + +TypePtr StaticType::instantiate(int atLevel, int *unboundCount, + std::unordered_map *cache) { + return shared_from_this(); } -StaticType::StaticType(Cache *cache, std::vector generics, - const std::shared_ptr &e) - : Type(cache), generics(std::move(generics)), expr(e->clone()) {} - -StaticType::StaticType(Cache *cache, int64_t i) - : Type(cache), expr(std::make_shared(i)) {} - -StaticType::StaticType(Cache *cache, const std::string &s) - : Type(cache), expr(std::make_shared(s)) {} - -int StaticType::unify(Type *typ, Unification *us) { - if (auto t = typ->getStatic()) { - if (canRealize()) - expr->staticValue = evaluate(); - if (t->canRealize()) - t->expr->staticValue = t->evaluate(); - // Check if both types are already evaluated. - if (expr->staticValue.type != t->expr->staticValue.type) - return -1; - if (expr->staticValue.evaluated && t->expr->staticValue.evaluated) - return expr->staticValue == t->expr->staticValue ? 2 : -1; - else if (expr->staticValue.evaluated && !t->expr->staticValue.evaluated) - return typ->unify(this, us); - - // Right now, *this is not evaluated - // Let us see can we unify it with other _if_ it is a simple IdExpr? - if (expr->getId() && t->expr->staticValue.evaluated) { - return generics[0].type->unify(typ, us); - } - - // At this point, *this is a complex expression (e.g. A+1). - seqassert(!generics.empty(), "unevaluated simple expression"); - if (generics.size() != t->generics.size()) - return -1; - - int s1 = 2, s = 0; - if (!(expr->getId() && t->expr->getId()) && expr->toString() != t->expr->toString()) - return -1; - for (int i = 0; i < generics.size(); i++) { - if ((s = generics[i].type->unify(t->generics[i].type.get(), us)) == -1) - return -1; - s1 += s; - } - return s1; +bool StaticType::canRealize() const { return true; } + +bool StaticType::isInstantiated() const { return true; } + +std::string StaticType::realizedName() const { return debugString(0); } + +Type *StaticType::getNonStaticType() const { return cache->findClass(name); } + +/*****************************************************************/ + +IntStaticType::IntStaticType(Cache *cache, int64_t i) + : StaticType(cache, "int"), value(i) {} + +int IntStaticType::unify(Type *typ, Unification *us) { + if (auto t = typ->getIntStatic()) { + return value == t->value ? 1 : -1; + } else if (auto c = typ->getClass()) { + return ClassType::unify(c, us); } else if (auto tl = typ->getLink()) { return tl->unify(this, us); + } else { + return -1; } - return -1; } -TypePtr StaticType::generalize(int atLevel) { - auto e = generics; - for (auto &t : e) - t.type = t.type ? t.type->generalize(atLevel) : nullptr; - auto c = std::make_shared(cache, e, expr); - c->setSrcInfo(getSrcInfo()); - return c; +std::string IntStaticType::debugString(char mode) const { + return mode == 0 ? fmt::format("{}", value) : fmt::format("Static[{}]", value); } -TypePtr StaticType::instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) { - auto e = generics; - for (auto &t : e) - t.type = t.type ? t.type->instantiate(atLevel, unboundCount, cache) : nullptr; - auto c = std::make_shared(this->cache, e, expr); - c->setSrcInfo(getSrcInfo()); - return c; -} +Expr *IntStaticType::getStaticExpr() const { return cache->N(value); } -std::vector StaticType::getUnbounds() const { - std::vector u; - for (auto &t : generics) - if (t.type) { - auto tu = t.type->getUnbounds(); - u.insert(u.begin(), tu.begin(), tu.end()); - } - return u; -} - -bool StaticType::canRealize() const { - if (!expr->staticValue.evaluated) - for (auto &t : generics) - if (t.type && !t.type->canRealize()) - return false; - return true; -} +/*****************************************************************/ -bool StaticType::isInstantiated() const { return expr->staticValue.evaluated; } +StrStaticType::StrStaticType(Cache *cache, std::string s) + : StaticType(cache, "str"), value(std::move(s)) {} -std::string StaticType::debugString(char mode) const { - if (expr->staticValue.evaluated) - return expr->staticValue.toString(); - if (mode == 2) { - std::vector s; - for (auto &g : generics) - s.push_back(g.type->debugString(mode)); - return fmt::format("Static[{};{}]", join(s, ","), expr->toString()); +int StrStaticType::unify(Type *typ, Unification *us) { + if (auto t = typ->getStrStatic()) { + return value == t->value ? 1 : -1; + } else if (auto c = typ->getClass()) { + return ClassType::unify(c, us); + } else if (auto tl = typ->getLink()) { + return tl->unify(this, us); } else { - return fmt::format("Static[{}]", FormatVisitor::apply(expr)); + return -1; } } -std::string StaticType::realizedName() const { - seqassert(canRealize(), "cannot realize {}", toString()); - std::vector deps; - for (auto &e : generics) - deps.push_back(e.type->realizedName()); - if (!expr->staticValue.evaluated) // If not already evaluated, evaluate! - const_cast(this)->expr->staticValue = evaluate(); - seqassert(expr->staticValue.evaluated, "static value not evaluated"); - return expr->staticValue.toString(); +std::string StrStaticType::debugString(char mode) const { + return mode == 0 ? fmt::format("'{}'", escape(value)) + : fmt::format("Static['{}']", escape(value)); } -StaticValue StaticType::evaluate() const { - if (expr->staticValue.evaluated) - return expr->staticValue; - cache->typeCtx->addBlock(); - for (auto &g : generics) - cache->typeCtx->add(TypecheckItem::Type, g.name, g.type); - auto oldChangedNodes = cache->typeCtx->changedNodes; - auto en = TypecheckVisitor(cache->typeCtx).transform(expr->clone()); - cache->typeCtx->changedNodes = oldChangedNodes; - seqassert(en->isStatic() && en->staticValue.evaluated, "{} cannot be evaluated", en); - cache->typeCtx->popBlock(); - return en->staticValue; -} +Expr *StrStaticType::getStaticExpr() const { return cache->N(value); } + +/*****************************************************************/ -void StaticType::parseExpr(const ExprPtr &e, std::unordered_set &seen) { - e->type = nullptr; - if (auto ei = e->getId()) { - if (!in(seen, ei->value)) { - auto val = cache->typeCtx->find(ei->value); - seqassert(val && val->type->isStaticType(), "invalid static expression"); - auto genTyp = val->type->follow(); - auto id = genTyp->getLink() ? genTyp->getLink()->id - : genTyp->getStatic()->generics.empty() - ? 0 - : genTyp->getStatic()->generics[0].id; - generics.emplace_back(ClassType::Generic( - ei->value, cache->typeCtx->cache->reverseIdentifierLookup[ei->value], genTyp, - id)); - seen.insert(ei->value); - } - } else if (auto eu = e->getUnary()) { - parseExpr(eu->expr, seen); - } else if (auto eb = e->getBinary()) { - parseExpr(eb->lexpr, seen); - parseExpr(eb->rexpr, seen); - } else if (auto ef = e->getIf()) { - parseExpr(ef->cond, seen); - parseExpr(ef->ifexpr, seen); - parseExpr(ef->elsexpr, seen); +BoolStaticType::BoolStaticType(Cache *cache, bool b) + : StaticType(cache, "bool"), value(b) {} + +int BoolStaticType::unify(Type *typ, Unification *us) { + if (auto t = typ->getBoolStatic()) { + return value == t->value ? 1 : -1; + } else if (auto c = typ->getClass()) { + return ClassType::unify(c, us); + } else if (auto tl = typ->getLink()) { + return tl->unify(this, us); + } else { + return -1; } } +std::string BoolStaticType::debugString(char mode) const { + return fmt::format("Static[{}]", value ? "True" : "False"); +} + +Expr *BoolStaticType::getStaticExpr() const { return cache->N(value); } + } // namespace codon::ast::types diff --git a/codon/parser/ast/types/static.h b/codon/parser/ast/types/static.h index f4730865..6c19d167 100644 --- a/codon/parser/ast/types/static.h +++ b/codon/parser/ast/types/static.h @@ -9,52 +9,73 @@ #include "codon/parser/ast/types/class.h" -namespace codon::ast { -struct StaticValue; -} - namespace codon::ast::types { -/** - * A static integer type (e.g. N in def foo[N: int]). Usually an integer, but can point - * to a static expression. - */ -struct StaticType : public Type { - /// List of static variables that a type depends on - /// (e.g. for A+B+2, generics are {A, B}). - std::vector generics; - /// A static expression that needs to be evaluated. - /// Can be nullptr if there is no expression. - std::shared_ptr expr; - - StaticType(Cache *cache, std::vector generics, - const std::shared_ptr &expr); - /// Convenience function that parses expr and populates static type generics. - StaticType(Cache *cache, const std::shared_ptr &expr); - /// Convenience function for static types whose evaluation is already known. - explicit StaticType(Cache *cache, int64_t i); - explicit StaticType(Cache *cache, const std::string &s); +struct StaticType : public ClassType { + explicit StaticType(Cache *, const std::string &); public: - int unify(Type *typ, Unification *undo) override; TypePtr generalize(int atLevel) override; TypePtr instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) override; - -public: - std::vector getUnbounds() const override; bool canRealize() const override; bool isInstantiated() const override; - std::string debugString(char mode) const override; std::string realizedName() const override; + virtual Expr *getStaticExpr() const = 0; + virtual Type *getNonStaticType() const; + StaticType *getStatic() override { return this; } +}; + +struct IntStaticType : public StaticType { + int64_t value; - StaticValue evaluate() const; - std::shared_ptr getStatic() override { - return std::static_pointer_cast(shared_from_this()); - } +public: + explicit IntStaticType(Cache *cache, int64_t); -private: - void parseExpr(const std::shared_ptr &e, std::unordered_set &seen); +public: + int unify(Type *typ, Unification *undo) override; + +public: + std::string debugString(char mode) const override; + Expr *getStaticExpr() const override; + + IntStaticType *getIntStatic() override { return this; } }; +struct StrStaticType : public StaticType { + std::string value; + +public: + explicit StrStaticType(Cache *cache, std::string); + +public: + int unify(Type *typ, Unification *undo) override; + +public: + std::string debugString(char mode) const override; + Expr *getStaticExpr() const override; + + StrStaticType *getStrStatic() override { return this; } +}; + +struct BoolStaticType : public StaticType { + bool value; + +public: + explicit BoolStaticType(Cache *cache, bool); + +public: + int unify(Type *typ, Unification *undo) override; + +public: + std::string debugString(char mode) const override; + Expr *getStaticExpr() const override; + + BoolStaticType *getBoolStatic() override { return this; } +}; + +using StaticTypePtr = std::shared_ptr; +using IntStaticTypePtr = std::shared_ptr; +using StrStaticTypePtr = std::shared_ptr; + } // namespace codon::ast::types diff --git a/codon/parser/ast/types/traits.cpp b/codon/parser/ast/types/traits.cpp index 34193f90..bc79e28a 100644 --- a/codon/parser/ast/types/traits.cpp +++ b/codon/parser/ast/types/traits.cpp @@ -24,106 +24,136 @@ CallableTrait::CallableTrait(Cache *cache, std::vector args) : Trait(cache), args(std::move(args)) {} int CallableTrait::unify(Type *typ, Unification *us) { - if (auto tr = typ->getRecord()) { + if (auto tr = typ->getClass()) { + TypePtr ft = nullptr; + if (typ->is("TypeWrap")) { + TypecheckVisitor tv(cache->typeCtx); + ft = tv.instantiateType( + tv.findMethod(typ->getClass(), "__call_no_self__").front(), typ->getClass()); + tr = ft->getClass(); + } + if (tr->name == "NoneType") return 1; if (tr->name != "Function" && !tr->getPartial()) return -1; + if (!tr->isRecord()) + return -1; if (args.empty()) return 1; std::vector known; + TypePtr func = nullptr; // trFun can point to it auto trFun = tr; if (auto pt = tr->getPartial()) { int ic = 0; std::unordered_map c; - trFun = pt->func->instantiate(0, &ic, &c)->getRecord(); - known = pt->known; + func = pt->getPartialFunc()->instantiate(0, &ic, &c); + trFun = func->getClass(); + known = pt->getPartialMask(); + + auto knownArgTypes = pt->generics[1].type->getClass(); + for (size_t i = 0, j = 0, k = 0; i < known.size(); i++) + if ((*func->getFunc()->ast)[i].isGeneric()) { + j++; + } else if (known[i]) { + if ((*func->getFunc())[i - j]->unify(knownArgTypes->generics[k].type.get(), + us) == -1) + return -1; + k++; + } } else { - known = std::vector(tr->generics[0].type->getRecord()->args.size(), 0); + known = std::vector(tr->generics[0].type->getClass()->generics.size(), 0); } - auto &inArgs = args[0]->getRecord()->args; - auto &trInArgs = trFun->generics[0].type->getRecord()->args; + auto inArgs = args[0]->getClass(); + auto trInArgs = trFun->generics[0].type->getClass(); auto trAst = trFun->getFunc() ? trFun->getFunc()->ast : nullptr; - size_t star = trInArgs.size(), kwStar = trInArgs.size(); + size_t star = 0, kwStar = trInArgs->generics.size(); size_t total = 0; if (trAst) { star = trAst->getStarArgs(); kwStar = trAst->getKwStarArgs(); - if (kwStar < trAst->args.size() && star >= trInArgs.size()) + for (size_t fi = 0; fi < trAst->size(); fi++) { + if (fi < star && !(*trAst)[fi].isValue()) + star--; + if (fi < kwStar && !(*trAst)[fi].isValue()) + kwStar--; + } + if (kwStar < trAst->size() && star >= trInArgs->generics.size()) star -= 1; size_t preStar = 0; - for (size_t fi = 0; fi < trAst->args.size(); fi++) { - if (fi != kwStar && !known[fi] && trAst->args[fi].status == Param::Normal) { + for (size_t fi = 0; fi < trAst->size(); fi++) { + if (fi != kwStar && !known[fi] && (*trAst)[fi].isValue()) { total++; if (fi < star) preStar++; } } if (preStar < total) { - if (inArgs.size() < preStar) + if (inArgs->generics.size() < preStar) return -1; - } else if (inArgs.size() != total) { + } else if (inArgs->generics.size() != total) { return -1; } } else { - total = star = trInArgs.size(); - if (inArgs.size() != total) + total = star = trInArgs->generics.size(); + if (inArgs->generics.size() != total) return -1; } size_t i = 0; - for (size_t fi = 0; i < inArgs.size() && fi < star; fi++) { - if (!known[fi] && trAst->args[fi].status == Param::Normal) { - if (inArgs[i++]->unify(trInArgs[fi].get(), us) == -1) + for (size_t fi = 0; i < inArgs->generics.size() && fi < star; fi++) { + if (!known[fi] && (*trAst)[fi].isValue()) { + if (inArgs->generics[i++].type->unify(trInArgs->generics[fi].type.get(), us) == + -1) return -1; } } // NOTE: *args / **kwargs types will be typecheck when the function is called + auto tv = TypecheckVisitor(cache->typeCtx); if (auto pf = trFun->getFunc()) { // Make sure to set types of *args/**kwargs so that the function that // is being unified with Callable[] can be realized - - if (star < trInArgs.size() - (kwStar < trInArgs.size())) { - std::vector starArgTypes; + if (star < trInArgs->generics.size() - (kwStar < trInArgs->generics.size())) { + std::vector starArgTypes; if (auto tp = tr->getPartial()) { - auto ts = tp->args[tp->args.size() - 2]->getRecord(); - seqassert(ts, "bad partial *args/**kwargs"); - starArgTypes = ts->args; + auto ts = tp->generics[1].type->getClass(); + seqassert(ts && !ts->generics.empty() && + ts->generics[ts->generics.size() - 1].type->getClass(), + "bad partial *args/**kwargs"); + for (auto &tt : + ts->generics[ts->generics.size() - 1].type->getClass()->generics) + starArgTypes.push_back(tt.getType()); } - starArgTypes.insert(starArgTypes.end(), inArgs.begin() + i, inArgs.end()); + for (; i < inArgs->generics.size(); i++) + starArgTypes.push_back(inArgs->generics[i].getType()); - auto tv = TypecheckVisitor(cache->typeCtx); - auto t = cache->typeCtx->instantiateTuple(starArgTypes)->getClass(); - if (t->unify(trInArgs[star].get(), us) == -1) + auto tn = + tv.instantiateType(tv.generateTuple(starArgTypes.size()), starArgTypes); + if (tn->unify(trInArgs->generics[star].type.get(), us) == -1) return -1; } - if (kwStar < trInArgs.size()) { - auto tv = TypecheckVisitor(cache->typeCtx); - std::vector names; - std::vector starArgTypes; + if (kwStar < trInArgs->generics.size()) { + auto tt = tv.generateTuple(0); + size_t id = 0; if (auto tp = tr->getPartial()) { - auto ts = tp->args.back()->getRecord(); - seqassert(ts, "bad partial *args/**kwargs"); - auto ff = tv.getClassFields(ts.get()); - for (size_t i = 0; i < ts->args.size(); i++) { - names.emplace_back(ff[i].name); - starArgTypes.emplace_back(ts->args[i]); - } + auto ts = tp->generics[2].type->getClass(); + seqassert(ts && ts->is("NamedTuple"), "bad partial *args/**kwargs"); + id = ts->generics[0].type->getIntStatic()->value; + tt = ts->generics[1].getType()->getClass(); } - auto name = tv.generateTuple(starArgTypes.size(), TYPE_KWTUPLE, names); - auto t = cache->typeCtx->forceFind(name)->type; - t = cache->typeCtx->instantiateGeneric(t, starArgTypes)->getClass(); - if (t->unify(trInArgs[kwStar].get(), us) == -1) + auto tid = std::make_shared(cache, id); + auto kt = tv.instantiateType(tv.getStdLibType("NamedTuple"), {tid.get(), tt}); + if (kt->unify(trInArgs->generics[kwStar].type.get(), us) == -1) return -1; } if (us && pf->canRealize()) { // Realize if possible to allow deduction of return type - auto rf = TypecheckVisitor(cache->typeCtx).realize(pf); - pf->unify(rf.get(), us); + auto rf = tv.realize(pf); + pf->unify(rf, us); } - if (args[1]->unify(pf->getRetType().get(), us) == -1) + if (args[1]->unify(pf->getRetType(), us) == -1) return -1; } return 1; @@ -172,7 +202,15 @@ std::string CallableTrait::debugString(char mode) const { TypeTrait::TypeTrait(TypePtr typ) : Trait(typ), type(std::move(typ)) {} -int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); } +int TypeTrait::unify(Type *typ, Unification *us) { + if (auto tc = typ->getClass()) { + // does not make sense otherwise and results in infinite cycles + return typ->unify(type.get(), us); + } + if (typ->getUnbound()) + return 0; + return -1; +} TypePtr TypeTrait::generalize(int atLevel) { auto c = std::make_shared(type->generalize(atLevel)); @@ -188,7 +226,7 @@ TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount, } std::string TypeTrait::debugString(char mode) const { - return fmt::format("Trait[{}]", type->debugString(mode)); + return fmt::format("Trait[{}]", type->getClass() ? type->getClass()->name : "-"); } } // namespace codon::ast::types diff --git a/codon/parser/ast/types/traits.h b/codon/parser/ast/types/traits.h index 1a59fd48..5c412b44 100644 --- a/codon/parser/ast/types/traits.h +++ b/codon/parser/ast/types/traits.h @@ -46,16 +46,4 @@ struct TypeTrait : public Trait { std::string debugString(char mode) const override; }; -struct VariableTupleTrait : public Trait { - TypePtr size; - -public: - explicit VariableTupleTrait(TypePtr size); - int unify(Type *typ, Unification *undo) override; - TypePtr generalize(int atLevel) override; - TypePtr instantiate(int atLevel, int *unboundCount, - std::unordered_map *cache) override; - std::string debugString(char mode) const override; -}; - } // namespace codon::ast::types diff --git a/codon/parser/ast/types/type.cpp b/codon/parser/ast/types/type.cpp index 9683e52b..17e39233 100644 --- a/codon/parser/ast/types/type.cpp +++ b/codon/parser/ast/types/type.cpp @@ -13,16 +13,18 @@ namespace codon::ast::types { /// Undo a destructive unification. void Type::Unification::undo() { for (size_t i = linked.size(); i-- > 0;) { - linked[i]->kind = LinkType::Unbound; - linked[i]->type = nullptr; + linked[i]->getLink()->kind = LinkType::Unbound; + linked[i]->getLink()->type = nullptr; } for (size_t i = leveled.size(); i-- > 0;) { - seqassertn(leveled[i].first->kind == LinkType::Unbound, "not unbound [{}]", - leveled[i].first->getSrcInfo()); - leveled[i].first->level = leveled[i].second; + seqassertn(leveled[i].first->getLink()->kind == LinkType::Unbound, + "not unbound [{}]", leveled[i].first->getSrcInfo()); + leveled[i].first->getLink()->level = leveled[i].second; } for (auto &t : traits) - t->trait = nullptr; + t->getLink()->trait = nullptr; + for (auto &t : statics) + t->getLink()->isStatic = 0; } Type::Type(const std::shared_ptr &typ) : cache(typ->cache) { @@ -33,7 +35,9 @@ Type::Type(Cache *cache, const SrcInfo &info) : cache(cache) { setSrcInfo(info); TypePtr Type::follow() { return shared_from_this(); } -std::vector> Type::getUnbounds() const { return {}; } +bool Type::hasUnbounds(bool) const { return false; } + +std::vector Type::getUnbounds() const { return {}; } std::string Type::toString() const { return debugString(1); } @@ -43,24 +47,26 @@ bool Type::is(const std::string &s) { return getClass() && getClass()->name == s char Type::isStaticType() { auto t = follow(); - if (auto s = t->getStatic()) - return char(s->expr->staticValue.type); + if (t->getBoolStatic()) + return 3; + if (t->getStrStatic()) + return 2; + if (t->getIntStatic()) + return 1; if (auto l = t->getLink()) return l->isStatic; - return false; -} - -TypePtr Type::makeType(Cache *cache, const std::string &name, - const std::string &niceName, bool isRecord) { - if (name == "Union") - return std::make_shared(cache); - if (isRecord) - return std::make_shared(cache, name, niceName); - return std::make_shared(cache, name, niceName); + return 0; } -std::shared_ptr Type::makeStatic(Cache *cache, const ExprPtr &expr) { - return std::make_shared(cache, expr); +Type *Type::operator<<(Type *t) { + seqassert(t, "rhs is nullptr"); + types::Type::Unification undo; + if (unify(t, &undo) >= 0) { + return this; + } else { + undo.undo(); + return nullptr; + } } } // namespace codon::ast::types diff --git a/codon/parser/ast/types/type.h b/codon/parser/ast/types/type.h index 30013bb6..10984c2a 100644 --- a/codon/parser/ast/types/type.h +++ b/codon/parser/ast/types/type.h @@ -13,17 +13,19 @@ namespace codon::ast { struct Cache; struct Expr; +struct TypeContext; } // namespace codon::ast namespace codon::ast::types { /// Forward declarations -struct FuncType; struct ClassType; +struct FuncType; struct LinkType; -struct RecordType; -struct PartialType; struct StaticType; +struct IntStaticType; +struct StrStaticType; +struct BoolStaticType; struct UnionType; /** @@ -39,14 +41,13 @@ struct Type : public codon::SrcObject, public std::enable_shared_from_this /// Needed because the unify() is destructive. struct Unification { /// List of unbound types that have been changed. - std::vector linked; + std::vector> linked; /// List of unbound types whose level has been changed. - std::vector> leveled; + std::vector, int>> leveled; /// List of assigned traits. - std::vector traits; - /// List of pointers that are owned by unification process - /// (to avoid memory issues with undoing). - std::vector> ownedTypes; + std::vector> traits; + /// List of unbound types whose static status has been changed. + std::vector> statics; public: /// Undo the unification step. @@ -84,8 +85,10 @@ struct Type : public codon::SrcObject, public std::enable_shared_from_this /// Get the final type (follow through all LinkType links). /// For example, for (a->b->c->d) it returns d. virtual std::shared_ptr follow(); + /// Check if type has unbound/generic types. + virtual bool hasUnbounds(bool = false) const; /// Obtain the list of internal unbound types. - virtual std::vector> getUnbounds() const; + virtual std::vector getUnbounds() const; /// True if a type is realizable. virtual bool canRealize() const = 0; /// True if a type is completely instantiated (has no unbounds or generics). @@ -102,23 +105,22 @@ struct Type : public codon::SrcObject, public std::enable_shared_from_this virtual std::string realizedName() const = 0; /// Convenience virtual functions to avoid unnecessary dynamic_cast calls. - virtual std::shared_ptr getFunc() { return nullptr; } - virtual std::shared_ptr getPartial() { return nullptr; } - virtual std::shared_ptr getClass() { return nullptr; } - virtual std::shared_ptr getRecord() { return nullptr; } - virtual std::shared_ptr getLink() { return nullptr; } - virtual std::shared_ptr getUnbound() { return nullptr; } - virtual std::shared_ptr getStatic() { return nullptr; } - virtual std::shared_ptr getUnion() { return nullptr; } - virtual std::shared_ptr getHeterogenousTuple() { return nullptr; } + virtual FuncType *getFunc() { return nullptr; } + virtual ClassType *getPartial() { return nullptr; } + virtual ClassType *getClass() { return nullptr; } + virtual LinkType *getLink() { return nullptr; } + virtual LinkType *getUnbound() { return nullptr; } + virtual StaticType *getStatic() { return nullptr; } + virtual IntStaticType *getIntStatic() { return nullptr; } + virtual StrStaticType *getStrStatic() { return nullptr; } + virtual BoolStaticType *getBoolStatic() { return nullptr; } + virtual UnionType *getUnion() { return nullptr; } + virtual ClassType *getHeterogenousTuple() { return nullptr; } virtual bool is(const std::string &s); char isStaticType(); -public: - static std::shared_ptr makeType(Cache *, const std::string &, - const std::string &, bool = false); - static std::shared_ptr makeStatic(Cache *, const std::shared_ptr &); + Type *operator<<(Type *t); protected: Cache *cache; @@ -132,30 +134,23 @@ using TypePtr = std::shared_ptr; template struct fmt::formatter< T, std::enable_if_t::value, char>> - : fmt::ostream_formatter {}; - -template -struct fmt::formatter< - T, - std::enable_if_t< - std::is_convertible>::value, char>> : fmt::formatter { - char presentation = 'd'; + char presentation = 'b'; constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) { auto it = ctx.begin(), end = ctx.end(); - if (it != end && (*it == 'p' || *it == 'd' || *it == 'D')) + if (it != end && (*it == 'a' || *it == 'b' || *it == 'c')) presentation = *it++; return it; } template auto format(const T &p, FormatContext &ctx) const -> decltype(ctx.out()) { - if (presentation == 'p') - return fmt::format_to(ctx.out(), "{}", p ? p->debugString(0) : ""); - else if (presentation == 'd') - return fmt::format_to(ctx.out(), "{}", p ? p->debugString(1) : ""); + if (presentation == 'a') + return fmt::format_to(ctx.out(), "{}", p.debugString(0)); + else if (presentation == 'b') + return fmt::format_to(ctx.out(), "{}", p.debugString(1)); else - return fmt::format_to(ctx.out(), "{}", p ? p->debugString(2) : ""); + return fmt::format_to(ctx.out(), "{}", p.debugString(2)); } }; diff --git a/codon/parser/ast/types/union.cpp b/codon/parser/ast/types/union.cpp index b567e2ea..c520384b 100644 --- a/codon/parser/ast/types/union.cpp +++ b/codon/parser/ast/types/union.cpp @@ -10,15 +10,18 @@ namespace codon::ast::types { -UnionType::UnionType(Cache *cache) : RecordType(cache, "Union", "Union") { - for (size_t i = 0; i < 256; i++) +UnionType::UnionType(Cache *cache) : ClassType(cache, "Union", "Union") { + isTuple = true; + for (size_t i = 0; i < MAX_UNION; i++) pendingTypes.emplace_back( std::make_shared(cache, LinkType::Generic, i, 0, nullptr)); } UnionType::UnionType(Cache *cache, const std::vector &generics, const std::vector &pendingTypes) - : RecordType(cache, "Union", "Union", generics), pendingTypes(pendingTypes) {} + : ClassType(cache, "Union", "Union", generics), pendingTypes(pendingTypes) { + isTuple = true; +} int UnionType::unify(Type *typ, Unification *us) { if (typ->getUnion()) { @@ -27,13 +30,13 @@ int UnionType::unify(Type *typ, Unification *us) { for (size_t i = 0; i < pendingTypes.size(); i++) if (pendingTypes[i]->unify(tr->pendingTypes[i].get(), us) == -1) return -1; - return RecordType::unify(typ, us); + return ClassType::unify(typ, us); } else if (!isSealed()) { return tr->unify(this, us); } else if (!tr->isSealed()) { if (tr->pendingTypes[0]->getLink() && tr->pendingTypes[0]->getLink()->kind == LinkType::Unbound) - return RecordType::unify(tr.get(), us); + return ClassType::unify(tr, us); return -1; } // Do not hard-unify if we have unbounds @@ -46,7 +49,7 @@ int UnionType::unify(Type *typ, Unification *us) { return -1; int s1 = 2, s = 0; for (size_t i = 0; i < u1.size(); i++) { - if ((s = u1[i]->unify(u2[i].get(), us)) == -1) + if ((s = u1[i]->unify(u2[i], us)) == -1) return -1; s1 += s; } @@ -58,7 +61,7 @@ int UnionType::unify(Type *typ, Unification *us) { } TypePtr UnionType::generalize(int atLevel) { - auto r = RecordType::generalize(atLevel); + auto r = ClassType::generalize(atLevel); auto p = pendingTypes; for (auto &t : p) t = t->generalize(atLevel); @@ -69,7 +72,7 @@ TypePtr UnionType::generalize(int atLevel) { TypePtr UnionType::instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) { - auto r = RecordType::instantiate(atLevel, unboundCount, cache); + auto r = ClassType::instantiate(atLevel, unboundCount, cache); auto p = pendingTypes; for (auto &t : p) t = t->instantiate(atLevel, unboundCount, cache); @@ -80,65 +83,57 @@ TypePtr UnionType::instantiate(int atLevel, int *unboundCount, std::string UnionType::debugString(char mode) const { if (mode == 2) - return this->RecordType::debugString(mode); - if (!generics[0].type->getRecord()) - return this->RecordType::debugString(mode); + return this->ClassType::debugString(mode); + if (!generics[0].type->getClass()) + return this->ClassType::debugString(mode); std::set gss; - for (auto &a : generics[0].type->getRecord()->args) - gss.insert(a->debugString(mode)); - std::string s; - for (auto &i : gss) - s += "," + i; - return fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s.substr(1))); + for (auto &a : generics[0].type->getClass()->generics) + gss.insert(a.debugString(mode)); + std::string s = join(gss, " | "); + return fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s)); } -bool UnionType::canRealize() const { return isSealed() && RecordType::canRealize(); } +bool UnionType::canRealize() const { return isSealed() && ClassType::canRealize(); } std::string UnionType::realizedName() const { seqassert(canRealize(), "cannot realize {}", toString()); - std::set gss; - for (auto &a : generics[0].type->getRecord()->args) - gss.insert(a->realizedName()); - std::string s; - for (auto &i : gss) - s += "," + i; - return fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s.substr(1))); + return ClassType::realizedName(); } -std::string UnionType::realizedTypeName() const { return realizedName(); } - -void UnionType::addType(TypePtr typ) { +bool UnionType::addType(Type *typ) { seqassert(!isSealed(), "union already sealed"); - if (this == typ.get()) - return; + if (this == typ) + return true; if (auto tu = typ->getUnion()) { if (tu->isSealed()) { - for (auto &t : tu->generics[0].type->getRecord()->args) - addType(t); + for (auto &t : tu->generics[0].type->getClass()->generics) + if (!addType(t.type.get())) + return false; } else { for (auto &t : tu->pendingTypes) { if (t->getLink() && t->getLink()->kind == LinkType::Unbound) break; - else - addType(t); + else if (!addType(t.get())) + return false; } } + return true; } else { // Find first pending generic to which we can attach this! Unification us; for (auto &t : pendingTypes) if (auto l = t->getLink()) { if (l->kind == LinkType::Unbound) { - t->unify(typ.get(), &us); - return; + t->unify(typ, &us); + return true; } } - E(error::Error::UNION_TOO_BIG, this); + return false; } } -bool UnionType::isSealed() const { return generics[0].type->getRecord() != nullptr; } +bool UnionType::isSealed() const { return generics[0].type->getClass() != nullptr; } void UnionType::seal() { seqassert(!isSealed(), "union already sealed"); @@ -149,18 +144,20 @@ void UnionType::seal() { if (pendingTypes[i]->getLink() && pendingTypes[i]->getLink()->kind == LinkType::Unbound) break; - std::vector typeSet(pendingTypes.begin(), pendingTypes.begin() + i); - auto t = cache->typeCtx->instantiateTuple(typeSet); + std::vector typeSet; + for (size_t j = 0; j < i; j++) + typeSet.push_back(pendingTypes[j].get()); + auto t = tv.instantiateType(tv.generateTuple(typeSet.size()), typeSet); Unification us; generics[0].type->unify(t.get(), &us); } -std::vector UnionType::getRealizationTypes() { +std::vector UnionType::getRealizationTypes() { seqassert(canRealize(), "cannot realize {}", debugString(1)); - std::map unionTypes; - for (auto &u : generics[0].type->getRecord()->args) - unionTypes[u->realizedName()] = u; - std::vector r; + std::map unionTypes; + for (auto &u : generics[0].type->getClass()->generics) + unionTypes[u.type->realizedName()] = u.type.get(); + std::vector r; r.reserve(unionTypes.size()); for (auto &[_, t] : unionTypes) r.emplace_back(t); diff --git a/codon/parser/ast/types/union.h b/codon/parser/ast/types/union.h index ddbcb6b1..ec78fbcc 100644 --- a/codon/parser/ast/types/union.h +++ b/codon/parser/ast/types/union.h @@ -11,7 +11,9 @@ namespace codon::ast::types { -struct UnionType : public RecordType { +struct UnionType : public ClassType { + static const int MAX_UNION = 256; + std::vector pendingTypes; explicit UnionType(Cache *cache); @@ -28,16 +30,13 @@ struct UnionType : public RecordType { bool canRealize() const override; std::string debugString(char mode) const override; std::string realizedName() const override; - std::string realizedTypeName() const override; bool isSealed() const; - std::shared_ptr getUnion() override { - return std::static_pointer_cast(shared_from_this()); - } + UnionType *getUnion() override { return this; } - void addType(TypePtr typ); + bool addType(Type *); void seal(); - std::vector getRealizationTypes(); + std::vector getRealizationTypes(); }; } // namespace codon::ast::types diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index 644a46ad..1b37e97b 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -10,21 +10,21 @@ #include "codon/cir/util/irtools.h" #include "codon/parser/common.h" #include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/ctx.h" #include "codon/parser/visitors/typecheck/typecheck.h" namespace codon::ast { -Cache::Cache(std::string argv0) - : generatedSrcInfoCount(0), unboundCount(256), varCount(0), age(0), - argv0(std::move(argv0)), typeCtx(nullptr), codegenCtx(nullptr), isJit(false), - jitCell(0), pythonExt(false), pyModule(nullptr) {} +Cache::Cache(std::string argv0) : argv0(std::move(argv0)) { + this->_nodes = new std::vector>(); + typeCtx = std::make_shared(this, ".root"); +} std::string Cache::getTemporaryVar(const std::string &prefix, char sigil) { - return fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix, - ++varCount); + auto n = fmt::format("{}{}_{}", sigil ? fmt::format("{}_", sigil) : "", prefix, + ++varCount); + return n; } std::string Cache::rev(const std::string &s) { @@ -35,13 +35,6 @@ std::string Cache::rev(const std::string &s) { return ""; } -void Cache::addGlobal(const std::string &name, ir::Var *var) { - if (!in(globals, name)) { - // LOG("[global] {}", name); - globals[name] = {false, var}; - } -} - SrcInfo Cache::generateSrcInfo() { return {FILE_GENERATED, generatedSrcInfoCount, generatedSrcInfoCount++, 0}; } @@ -61,127 +54,127 @@ std::string Cache::getContent(const SrcInfo &info) { return s.substr(col, len); } -types::ClassTypePtr Cache::findClass(const std::string &name) const { +Cache::Class *Cache::getClass(types::ClassType *type) { + auto name = type->name; + return in(classes, name); +} + +std::string Cache::getMethod(types::ClassType *typ, const std::string &member) { + if (auto cls = getClass(typ)) { + if (auto t = in(cls->methods, member)) + return *t; + } + seqassertn(false, "cannot find '{}' in '{}'", member, typ->toString()); + return ""; +} + +types::ClassType *Cache::findClass(const std::string &name) const { auto f = typeCtx->find(name); - if (f && f->kind == TypecheckItem::Type) - return f->type->getClass(); + if (f && f->isType()) + return f->getType()->getClass()->generics[0].getType()->getClass(); return nullptr; } -types::FuncTypePtr Cache::findFunction(const std::string &name) const { +types::FuncType *Cache::findFunction(const std::string &name) const { auto f = typeCtx->find(name); - if (f && f->type && f->kind == TypecheckItem::Func) + if (f && f->type && f->isFunc()) return f->type->getFunc(); f = typeCtx->find(name + ":0"); - if (f && f->type && f->kind == TypecheckItem::Func) + if (f && f->type && f->isFunc()) return f->type->getFunc(); return nullptr; } -types::FuncTypePtr Cache::findMethod(types::ClassType *typ, const std::string &member, - const std::vector &args) { - auto e = std::make_shared(typ->name); - e->type = typ->getClass(); - seqassertn(e->type, "not a class"); - int oldAge = typeCtx->age; - typeCtx->age = 99999; - auto f = TypecheckVisitor(typeCtx).findBestMethod(e->type->getClass(), member, args); - typeCtx->age = oldAge; +types::FuncType *Cache::findMethod(types::ClassType *typ, const std::string &member, + const std::vector &args) { + auto f = TypecheckVisitor(typeCtx).findBestMethod(typ, member, args); return f; } -ir::types::Type *Cache::realizeType(types::ClassTypePtr type, +ir::types::Type *Cache::realizeType(types::ClassType *type, const std::vector &generics) { - auto e = std::make_shared(type->name); - e->type = type; - type = typeCtx->instantiateGeneric(type, generics)->getClass(); auto tv = TypecheckVisitor(typeCtx); - if (auto rtv = tv.realize(type)) { + if (auto rtv = tv.realize(tv.instantiateType(type, castVectorPtr(generics)))) { return classes[rtv->getClass()->name] - .realizations[rtv->getClass()->realizedTypeName()] + .realizations[rtv->getClass()->realizedName()] ->ir; } return nullptr; } -ir::Func *Cache::realizeFunction(types::FuncTypePtr type, +ir::Func *Cache::realizeFunction(types::FuncType *type, const std::vector &args, const std::vector &generics, - const types::ClassTypePtr &parentClass) { - auto e = std::make_shared(type->ast->name); - e->type = type; - type = typeCtx->instantiate(type, parentClass)->getFunc(); - if (args.size() != type->getArgTypes().size() + 1) + types::ClassType *parentClass) { + auto tv = TypecheckVisitor(typeCtx); + auto t = tv.instantiateType(type, parentClass); + if (args.size() != t->size() + 1) return nullptr; types::Type::Unification undo; - if (type->getRetType()->unify(args[0].get(), &undo) < 0) { + if (t->getRetType()->unify(args[0].get(), &undo) < 0) { undo.undo(); return nullptr; } for (int gi = 1; gi < args.size(); gi++) { undo = types::Type::Unification(); - if (type->getArgTypes()[gi - 1]->unify(args[gi].get(), &undo) < 0) { + if ((*t)[gi - 1]->unify(args[gi].get(), &undo) < 0) { undo.undo(); return nullptr; } } if (!generics.empty()) { - if (generics.size() != type->funcGenerics.size()) + if (generics.size() != t->funcGenerics.size()) return nullptr; for (int gi = 0; gi < generics.size(); gi++) { undo = types::Type::Unification(); - if (type->funcGenerics[gi].type->unify(generics[gi].get(), &undo) < 0) { + if (t->funcGenerics[gi].type->unify(generics[gi].get(), &undo) < 0) { undo.undo(); return nullptr; } } } - int oldAge = typeCtx->age; - typeCtx->age = 99999; - auto tv = TypecheckVisitor(typeCtx); ir::Func *f = nullptr; - if (auto rtv = tv.realize(type)) { + if (auto rtv = tv.realize(t.get())) { auto pr = pendingRealizations; // copy it as it might be modified for (auto &fn : pr) - TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); - f = functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; + TranslateVisitor(codegenCtx).translateStmts(clone(functions[fn.first].ast)); + f = functions[rtv->getFunc()->ast->getName()].realizations[rtv->realizedName()]->ir; } - typeCtx->age = oldAge; return f; } ir::types::Type *Cache::makeTuple(const std::vector &types) { auto tv = TypecheckVisitor(typeCtx); - auto t = typeCtx->instantiateTuple(types); - return realizeType(t, types); + auto t = tv.instantiateType(tv.generateTuple(types.size()), castVectorPtr(types)); + return realizeType(t->getClass(), types); } ir::types::Type *Cache::makeFunction(const std::vector &types) { auto tv = TypecheckVisitor(typeCtx); seqassertn(!types.empty(), "types must have at least one argument"); + std::vector tt; + for (size_t i = 1; i < types.size(); i++) + tt.emplace_back(types[i].get()); const auto &ret = types[0]; - auto argType = typeCtx->instantiateTuple( - std::vector(types.begin() + 1, types.end())); - auto t = typeCtx->find("Function"); - seqassertn(t && t->type, "cannot find 'Function'"); - return realizeType(t->type->getClass(), {argType, ret}); + auto argType = tv.instantiateType(tv.generateTuple(types.size() - 1), tt); + auto ft = realizeType(tv.getStdLibType("Function")->getClass(), {argType, ret}); + return ft; } ir::types::Type *Cache::makeUnion(const std::vector &types) { auto tv = TypecheckVisitor(typeCtx); - - auto argType = typeCtx->instantiateTuple(types); - auto t = typeCtx->find("Union"); - seqassertn(t && t->type, "cannot find 'Union'"); - return realizeType(t->type->getClass(), {argType}); + auto argType = + tv.instantiateType(tv.generateTuple(types.size()), castVectorPtr(types)); + return realizeType(tv.getStdLibType("Union")->getClass(), {argType}); } void Cache::parseCode(const std::string &code) { - auto node = ast::parseCode(this, "", code, /*startLine=*/0); + auto nodeOrErr = ast::parseCode(this, "", code, /*startLine=*/0); + if (nodeOrErr) + throw exc::ParserException(nodeOrErr.takeError()); auto sctx = imports[MAIN_IMPORT].ctx; - node = ast::SimplifyVisitor::apply(sctx, node, "", 99999); - node = ast::TypecheckVisitor::apply(this, node); + auto node = ast::TypecheckVisitor::apply(sctx, *nodeOrErr); for (auto &[name, p] : globals) if (p.first && !p.second) { p.second = name == VAR_ARGV ? codegenCtx->getModule()->getArgVar() @@ -189,15 +182,16 @@ void Cache::parseCode(const std::string &code) { SrcInfo(), nullptr, true, false, name); codegenCtx->add(ast::TranslateItem::Var, name, p.second); } - ast::TranslateVisitor(codegenCtx).transform(node); + ast::TranslateVisitor(codegenCtx).translateStmts(node); } -std::vector Cache::mergeC3(std::vector> &seqs) { +std::vector> +Cache::mergeC3(std::vector> &seqs) { // Reference: https://www.python.org/download/releases/2.3/mro/ - std::vector result; + std::vector> result; for (size_t i = 0;; i++) { bool found = false; - ExprPtr cand = nullptr; + std::shared_ptr cand = nullptr; for (auto &seq : seqs) { if (seq.empty()) continue; @@ -207,7 +201,7 @@ std::vector Cache::mergeC3(std::vector> &seqs) { if (!s.empty()) { bool in = false; for (size_t j = 1; j < s.size(); j++) { - if ((in |= (seq[0]->getTypeName() == s[j]->getTypeName()))) + if ((in |= (seq[0]->is(s[j]->getClass()->name)))) break; } if (in) { @@ -216,7 +210,7 @@ std::vector Cache::mergeC3(std::vector> &seqs) { } } if (!nothead) { - cand = seq[0]; + cand = std::dynamic_pointer_cast(seq[0]); break; } } @@ -224,9 +218,9 @@ std::vector Cache::mergeC3(std::vector> &seqs) { return result; if (!cand) return {}; - result.push_back(clone(cand)); + result.push_back(cand); for (auto &s : seqs) - if (!s.empty() && cand->getTypeName() == s[0]->getTypeName()) { + if (!s.empty() && cand->is(s[0]->getClass()->name)) { s.erase(s.begin()); } } @@ -235,325 +229,41 @@ std::vector Cache::mergeC3(std::vector> &seqs) { /** * Generate Python bindings for Cython-like access. - * - * TODO: this function is total mess. Needs refactoring. */ void Cache::populatePythonModule() { + using namespace ast; + if (!pythonExt) return; - - LOG_USER("[py] ====== module generation ======="); - -#define N std::make_shared - if (!pyModule) pyModule = std::make_shared(); - using namespace ast; - int oldAge = typeCtx->age; - typeCtx->age = 99999; - - auto realizeIR = [&](const types::FuncTypePtr &fn, - const std::vector &generics = {}) -> ir::Func * { - auto fnType = typeCtx->instantiate(fn); - types::Type::Unification u; - for (size_t i = 0; i < generics.size(); i++) - fnType->getFunc()->funcGenerics[i].type->unify(generics[i].get(), &u); - fnType = TypecheckVisitor(typeCtx).realize(fnType); - if (!fnType) - return nullptr; - - auto pr = pendingRealizations; // copy it as it might be modified - for (auto &fn : pr) - TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); - return functions[fn->ast->name].realizations[fnType->realizedName()]->ir; - }; - - const std::string pyWrap = "std.internal.python._PyWrap"; + LOG_USER("[py] ====== module generation ======="); + auto tv = TypecheckVisitor(typeCtx); auto clss = classes; // needs copy as below fns can mutate this - for (const auto &[cn, c] : clss) { - if (c.module.empty()) { - if (!in(c.methods, "__to_py__") || !in(c.methods, "__from_py__")) - continue; - - LOG_USER("[py] Cythonizing {}", cn); - ir::PyType py{rev(cn), c.ast->getDocstr()}; - - auto tc = typeCtx->forceFind(cn)->type; - if (!tc->canRealize()) - compilationError(fmt::format("cannot realize '{}' for Python export", rev(cn))); - tc = TypecheckVisitor(typeCtx).realize(tc); - seqassertn(tc, "cannot realize '{}'", cn); - - // 1. Replace to_py / from_py with _PyWrap.wrap_to_py/from_py - if (auto ofnn = in(c.methods, "__to_py__")) { - auto fnn = overloads[*ofnn].begin()->name; // default first overload! - auto &fna = functions[fnn].ast; - fna->getFunction()->suite = N(N( - N(pyWrap + ".wrap_to_py:0"), N(fna->args[0].name))); - } - if (auto ofnn = in(c.methods, "__from_py__")) { - auto fnn = overloads[*ofnn].begin()->name; // default first overload! - auto &fna = functions[fnn].ast; - fna->getFunction()->suite = - N(N(N(pyWrap + ".wrap_from_py:0"), - N(fna->args[0].name), N(cn))); - } - for (auto &n : std::vector{"__from_py__", "__to_py__"}) { - auto fnn = overloads[*in(c.methods, n)].begin()->name; - ir::Func *oldIR = nullptr; - if (!functions[fnn].realizations.empty()) - oldIR = functions[fnn].realizations.begin()->second->ir; - functions[fnn].realizations.clear(); - auto tf = TypecheckVisitor(typeCtx).realize(functions[fnn].type); - seqassertn(tf, "cannot re-realize '{}'", fnn); - if (oldIR) { - std::vector args; - for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) { - args.push_back(module->Nr(*it)); - } - ir::cast(oldIR)->setBody(ir::util::series( - ir::util::call(functions[fnn].realizations.begin()->second->ir, args))); - } - } - for (auto &[rn, r] : functions[pyWrap + ".py_type:0"].realizations) { - if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) { - py.typePtrHook = r->ir; - break; - } - } - - // 2. Handle methods - auto methods = c.methods; - for (const auto &[n, ofnn] : methods) { - auto canonicalName = overloads[ofnn].back().name; - if (overloads[ofnn].size() == 1 && - functions[canonicalName].ast->hasAttr("autogenerated")) - continue; - auto fna = functions[canonicalName].ast; - bool isMethod = fna->hasAttr(Attr::Method); - bool isProperty = fna->hasAttr(Attr::Property); - - std::string call = pyWrap + ".wrap_multiple"; - bool isMagic = false; - if (startswith(n, "__") && endswith(n, "__")) { - auto m = n.substr(2, n.size() - 4); - if (m == "new" && c.ast->hasAttr(Attr::Tuple)) - m = "init"; - if (auto i = in(classes[pyWrap].methods, "wrap_magic_" + m)) { - call = *i; - isMagic = true; - } - } - if (isProperty) - call = pyWrap + ".wrap_get"; - - auto fnName = call + ":0"; - seqassertn(in(functions, fnName), "bad name"); - auto generics = std::vector{tc}; - if (isProperty) { - generics.push_back( - std::make_shared(this, rev(canonicalName))); - } else if (!isMagic) { - generics.push_back(std::make_shared(this, n)); - generics.push_back(std::make_shared(this, (int)isMethod)); - } - auto f = realizeIR(functions[fnName].type, generics); - if (!f) - continue; - - LOG_USER("[py] {} -> {} ({}; {})", n, call, isMethod, isProperty); - if (isProperty) { - py.getset.push_back({rev(canonicalName), "", f, nullptr}); - } else if (n == "__repr__") { - py.repr = f; - } else if (n == "__add__") { - py.add = f; - } else if (n == "__iadd__") { - py.iadd = f; - } else if (n == "__sub__") { - py.sub = f; - } else if (n == "__isub__") { - py.isub = f; - } else if (n == "__mul__") { - py.mul = f; - } else if (n == "__imul__") { - py.imul = f; - } else if (n == "__mod__") { - py.mod = f; - } else if (n == "__imod__") { - py.imod = f; - } else if (n == "__divmod__") { - py.divmod = f; - } else if (n == "__pow__") { - py.pow = f; - } else if (n == "__ipow__") { - py.ipow = f; - } else if (n == "__neg__") { - py.neg = f; - } else if (n == "__pos__") { - py.pos = f; - } else if (n == "__abs__") { - py.abs = f; - } else if (n == "__bool__") { - py.bool_ = f; - } else if (n == "__invert__") { - py.invert = f; - } else if (n == "__lshift__") { - py.lshift = f; - } else if (n == "__ilshift__") { - py.ilshift = f; - } else if (n == "__rshift__") { - py.rshift = f; - } else if (n == "__irshift__") { - py.irshift = f; - } else if (n == "__and__") { - py.and_ = f; - } else if (n == "__iand__") { - py.iand = f; - } else if (n == "__xor__") { - py.xor_ = f; - } else if (n == "__ixor__") { - py.ixor = f; - } else if (n == "__or__") { - py.or_ = f; - } else if (n == "__ior__") { - py.ior = f; - } else if (n == "__int__") { - py.int_ = f; - } else if (n == "__float__") { - py.float_ = f; - } else if (n == "__floordiv__") { - py.floordiv = f; - } else if (n == "__ifloordiv__") { - py.ifloordiv = f; - } else if (n == "__truediv__") { - py.truediv = f; - } else if (n == "__itruediv__") { - py.itruediv = f; - } else if (n == "__index__") { - py.index = f; - } else if (n == "__matmul__") { - py.matmul = f; - } else if (n == "__imatmul__") { - py.imatmul = f; - } else if (n == "__len__") { - py.len = f; - } else if (n == "__getitem__") { - py.getitem = f; - } else if (n == "__setitem__") { - py.setitem = f; - } else if (n == "__contains__") { - py.contains = f; - } else if (n == "__hash__") { - py.hash = f; - } else if (n == "__call__") { - py.call = f; - } else if (n == "__str__") { - py.str = f; - } else if (n == "__iter__") { - py.iter = f; - } else if (n == "__del__") { - py.del = f; - } else if (n == "__init__" || (c.ast->hasAttr(Attr::Tuple) && n == "__new__")) { - py.init = f; - } else { - py.methods.push_back(ir::PyFunction{ - n, fna->getDocstr(), f, - fna->hasAttr(Attr::Method) ? ir::PyFunction::Type::METHOD - : ir::PyFunction::Type::CLASS, - // always use FASTCALL for now; works even for 0- or 1- arg methods - 2}); - py.methods.back().keywords = true; - } - } - - for (auto &m : py.methods) { - if (in(std::set{"__lt__", "__le__", "__eq__", "__ne__", "__gt__", - "__ge__"}, - m.name)) { - py.cmp = realizeIR( - typeCtx->forceFind(pyWrap + ".wrap_cmp:0")->type->getFunc(), {tc}); - break; - } - } - - if (c.realizations.size() != 1) - compilationError(fmt::format("cannot pythonize generic class '{}'", cn)); - auto &r = c.realizations.begin()->second; - py.type = realizeType(r->type); - for (auto &[mn, mt] : r->fields) { - /// TODO: handle PyMember for tuples - // Generate getters & setters - auto generics = std::vector{ - tc, std::make_shared(this, mn)}; - auto gf = realizeIR(functions[pyWrap + ".wrap_get:0"].type, generics); - ir::Func *sf = nullptr; - if (!c.ast->hasAttr(Attr::Tuple)) - sf = realizeIR(functions[pyWrap + ".wrap_set:0"].type, generics); - py.getset.push_back({mn, "", gf, sf}); - LOG_USER("[py] {}: {} . {}", "member", cn, mn); - } + for (const auto &[cn, _] : clss) { + auto py = tv.cythonizeClass(cn); + if (!py.name.empty()) pyModule->types.push_back(py); - } } // Handle __iternext__ wrappers - auto cin = "_PyWrap.IterWrap"; - for (auto &[cn, cr] : classes[cin].realizations) { - LOG_USER("[py] iterfn: {}", cn); - ir::PyType py{cn, ""}; - auto tc = cr->type; - for (auto &[rn, r] : functions[pyWrap + ".py_type:0"].realizations) { - if (r->type->funcGenerics[0].type->unify(tc.get(), nullptr) >= 0) { - py.typePtrHook = r->ir; - break; - } - } - - auto &methods = classes[cin].methods; - for (auto &n : std::vector{"_iter", "_iternext"}) { - auto fnn = overloads[methods[n]].begin()->name; - auto &fna = functions[fnn]; - auto ft = typeCtx->instantiate(fna.type, tc->getClass()); - auto rtv = TypecheckVisitor(typeCtx).realize(ft); - auto f = - functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; - if (n == "_iter") - py.iter = f; - else - py.iternext = f; - } - py.type = cr->ir; + for (auto &[cn, _] : classes[CYTHON_ITER].realizations) { + auto py = tv.cythonizeIterator(cn); pyModule->types.push_back(py); } -#undef N auto fns = functions; // needs copy as below fns can mutate this - for (const auto &[fn, f] : fns) { - if (f.isToplevel) { - std::string call = pyWrap + ".wrap_multiple"; - auto fnName = call + ":0"; - seqassertn(in(functions, fnName), "bad name"); - auto generics = std::vector{ - typeCtx->forceFind(".toplevel")->type, - std::make_shared(this, rev(f.ast->name)), - std::make_shared(this, 0)}; - if (auto ir = realizeIR(functions[fnName].type, generics)) { - LOG_USER("[py] {}: {}", "toplevel", fn); - pyModule->functions.push_back(ir::PyFunction{rev(fn), f.ast->getDocstr(), ir, - ir::PyFunction::Type::TOPLEVEL, - int(f.ast->args.size())}); - pyModule->functions.back().keywords = true; - } - } + for (const auto &[fn, _] : fns) { + auto py = tv.cythonizeFunction(fn); + if (!py.name.empty()) + pyModule->functions.push_back(py); } // Handle pending realizations! auto pr = pendingRealizations; // copy it as it might be modified for (auto &fn : pr) - TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); - typeCtx->age = oldAge; + TranslateVisitor(codegenCtx).translateStmts(clone(functions[fn.first].ast)); } } // namespace codon::ast diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 3af25792..02aa8c38 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -21,24 +21,30 @@ #define STDLIB_INTERNAL_MODULE "internal" #define TYPE_TUPLE "Tuple" -#define TYPE_KWTUPLE "KwTuple.N" #define TYPE_TYPEVAR "TypeVar" #define TYPE_CALLABLE "Callable" -#define TYPE_PARTIAL "Partial.N" #define TYPE_OPTIONAL "Optional" #define TYPE_SLICE "std.internal.types.slice.Slice" -#define FN_UNWRAP "std.internal.types.optional.unwrap" +#define FN_UNWRAP "std.internal.types.optional.unwrap.0:0" +#define TYPE_TYPE "type" +#define FN_DISPATCH_SUFFIX ":dispatch" +#define VAR_USED_SUFFIX ":used" +#define FN_SETTER_SUFFIX ":set_" +#define VAR_CLASS_TOPLEVEL ":toplevel" #define VAR_ARGV "__argv__" +#define MAX_ERRORS 5 +#define MAX_TUPLE 2048 #define MAX_INT_WIDTH 10000 #define MAX_REALIZATION_DEPTH 200 #define MAX_STATIC_ITER 1024 +#define CYTHON_PYWRAP "std.internal.python._PyWrap" +#define CYTHON_ITER "_PyWrap.IterWrap" + namespace codon::ast { /// Forward declarations -struct SimplifyContext; -class SimplifyVisitor; struct TypeContext; struct TranslateContext; @@ -48,7 +54,7 @@ struct TranslateContext; * checking) assumes that previous stages populated this structure correctly. * Implemented to avoid bunch of global objects. */ -struct Cache : public std::enable_shared_from_this { +struct Cache { /// Stores a count for each identifier (name) seen in the code. /// Used to generate unique identifier for each name in the code (e.g. Foo -> Foo.2). std::unordered_map identifierCount; @@ -57,29 +63,28 @@ struct Cache : public std::enable_shared_from_this { std::unordered_map reverseIdentifierLookup; /// Number of code-generated source code positions. Used to generate the next unique /// source-code position information. - int generatedSrcInfoCount; + int generatedSrcInfoCount = 0; /// Number of unbound variables so far. Used to generate the next unique unbound /// identifier. - int unboundCount; + int unboundCount = 256; /// Number of auto-generated variables so far. Used to generate the next unique /// variable name in getTemporaryVar() below. - int varCount; - /// Stores the count of imported files. Used to track class method ages - /// and to prevent using extended methods before they were seen. - int age; + int varCount = 0; + /// Scope counter. Each conditional block gets a new scope ID. + int blockCount = 1; /// Holds module import data. - struct Import { + struct Module { + /// Relative module name (e.g., `foo.bar`) + std::string name; /// Absolute filename of an import. std::string filename; - /// Import simplify context. - std::shared_ptr ctx; + /// Import typechecking context. + std::shared_ptr ctx; /// Unique import variable for checking already loaded imports. std::string importVar; /// File content (line:col indexable) std::vector content; - /// Relative module name (e.g., `foo.bar`) - std::string moduleName; /// Set if loaded at toplevel bool loadedAtToplevel = true; }; @@ -92,8 +97,9 @@ struct Cache : public std::enable_shared_from_this { ir::Module *module = nullptr; /// Table of imported files that maps an absolute filename to a Import structure. - /// By convention, the key of the Codon's standard library is "". - std::unordered_map imports; + /// By convention, the key of the Codon's standard library is ":stdlib:", + /// and the main module is "". + std::unordered_map imports; /// Set of unique (canonical) global identifiers for marking such variables as global /// in code-generation step and in JIT. @@ -101,10 +107,13 @@ struct Cache : public std::enable_shared_from_this { /// Stores class data for each class (type) in the source code. struct Class { + /// Module information + std::string module; + /// Generic (unrealized) class template AST. - std::shared_ptr ast; + ClassStmt *ast = nullptr; /// Non-simplified AST. Used for base class instantiation. - std::shared_ptr originalAst; + ClassStmt *originalAst = nullptr; /// Class method lookup table. Each non-canonical name points /// to a root function name of a corresponding method. @@ -118,6 +127,12 @@ struct Cache : public std::enable_shared_from_this { types::TypePtr type; /// Base class name (if available) std::string baseClass; + Expr *typeExpr; + + ClassField(const std::string &name, const types::TypePtr &type, + const std::string &baseClass, Expr *typeExpr = nullptr) + : name(name), type(type), baseClass(baseClass), typeExpr(typeExpr) {} + types::Type *getType() const { return type.get(); } }; /// A list of class' ClassField instances. List is needed (instead of map) because /// the order of the fields matters. @@ -129,7 +144,7 @@ struct Cache : public std::enable_shared_from_this { /// A class realization. struct ClassRealization { /// Realized class type. - types::ClassTypePtr type; + std::shared_ptr type; /// A list of field names and realization's realized field types. std::vector> fields; /// IR type pointer. @@ -139,7 +154,7 @@ struct Cache : public std::enable_shared_from_this { struct VTable { // Maps {base, thunk signature} to {thunk realization, thunk ID} std::map, - std::pair> + std::pair, size_t>> table; codon::ir::Var *ir = nullptr; }; @@ -147,6 +162,8 @@ struct Cache : public std::enable_shared_from_this { std::unordered_map vtables; /// Realization ID size_t id = 0; + + types::ClassType *getType() const { return type.get(); } }; /// Realization lookup table that maps a realized class name to the corresponding /// ClassRealization instance. @@ -157,97 +174,89 @@ struct Cache : public std::enable_shared_from_this { /// List of virtual method names std::unordered_set virtuals; /// MRO - std::vector mro; + std::vector> mro; /// List of statically inherited classes. std::vector staticParentClasses; - /// Module information - std::string module; - - Class() : ast(nullptr), originalAst(nullptr), rtti(false) {} + bool hasRTTI() const { return rtti; } }; /// Class lookup table that maps a canonical class identifier to the corresponding /// Class instance. std::unordered_map classes; size_t classRealizationCnt = 0; + Class *getClass(types::ClassType *); + struct Function { + /// Module information + std::string module; + std::string rootName; /// Generic (unrealized) function template AST. - std::shared_ptr ast; + FunctionStmt *ast; + /// Unrealized function type. + std::shared_ptr type; + /// Non-simplified AST. - std::shared_ptr origAst; + FunctionStmt *origAst = nullptr; + bool isToplevel = false; /// A function realization. struct FunctionRealization { /// Realized function type. - types::FuncTypePtr type; + std::shared_ptr type; /// Realized function AST (stored here for later realization in code generations /// stage). - std::shared_ptr ast; + FunctionStmt *ast; /// IR function pointer. ir::Func *ir; + /// Resolved captures + std::vector captures; + + types::FuncType *getType() const { return type.get(); } }; /// Realization lookup table that maps a realized function name to the corresponding /// FunctionRealization instance. - std::unordered_map> realizations; - - /// Unrealized function type. - types::FuncTypePtr type; + std::unordered_map> realizations = + {}; + std::set captures = {}; + std::vector> captureMappings = {}; - /// Module information - std::string rootName = ""; - bool isToplevel = false; - - Function() - : ast(nullptr), origAst(nullptr), type(nullptr), rootName(""), - isToplevel(false) {} + types::FuncType *getType() const { return type.get(); } }; /// Function lookup table that maps a canonical function identifier to the /// corresponding Function instance. std::unordered_map functions; - struct Overload { - /// Canonical name of an overload (e.g. Foo.__init__.1). - std::string name; - /// Overload age (how many class extension were seen before a method definition). - /// Used to prevent the usage of an overload before it was defined in the code. - /// TODO: I have no recollection of how this was supposed to work. Most likely - /// it does not work at all... - int age; - }; /// Maps a "root" name of each function to the list of names of the function - /// overloads. - std::unordered_map> overloads; + /// overloads (canonical names). + std::unordered_map> overloads; /// Pointer to the later contexts needed for IR API access. - std::shared_ptr typeCtx; - std::shared_ptr codegenCtx; + std::shared_ptr typeCtx = nullptr; + std::shared_ptr codegenCtx = nullptr; /// Set of function realizations that are to be translated to IR. std::set> pendingRealizations; - /// Mapping of partial record names to function pointers and corresponding masks. - std::unordered_map>> - partials; /// Custom operators std::unordered_map>> + std::pair>> customBlockStmts; std::unordered_map> + std::function> customExprStmts; /// Plugin-added import paths std::vector pluginImportPaths; /// Set if the Codon is running in JIT mode. - bool isJit; - int jitCell; + bool isJit = false; + int jitCell = 0; - std::unordered_map> replacements; std::unordered_map generatedTuples; - std::vector errors; + std::vector> generatedTupleNames = {{}}; + ParserErrors errors; /// Set if Codon operates in Python compatibility mode (e.g., with Python numerics) bool pythonCompat = false; @@ -259,7 +268,7 @@ struct Cache : public std::enable_shared_from_this { /// Return a uniquely named temporary variable of a format /// "{sigil}_{prefix}{counter}". A sigil should be a non-lexable symbol. - std::string getTemporaryVar(const std::string &prefix = "", char sigil = '.'); + std::string getTemporaryVar(const std::string &prefix = "", char sigil = '%'); /// Get the non-canonical version of a canonical name. std::string rev(const std::string &s); @@ -267,46 +276,37 @@ struct Cache : public std::enable_shared_from_this { SrcInfo generateSrcInfo(); /// Get file contents at the given location. std::string getContent(const SrcInfo &info); - /// Register a global identifier. - void addGlobal(const std::string &name, ir::Var *var = nullptr); /// Realization API. /// Find a class with a given canonical name and return a matching types::Type pointer /// or a nullptr if a class is not found. /// Returns an _uninstantiated_ type. - types::ClassTypePtr findClass(const std::string &name) const; + types::ClassType *findClass(const std::string &name) const; /// Find a function with a given canonical name and return a matching types::Type /// pointer or a nullptr if a function is not found. /// Returns an _uninstantiated_ type. - types::FuncTypePtr findFunction(const std::string &name) const; + types::FuncType *findFunction(const std::string &name) const; /// Find the canonical name of a class method. - std::string getMethod(const types::ClassTypePtr &typ, const std::string &member) { - if (auto m = in(classes, typ->name)) { - if (auto t = in(m->methods, member)) - return *t; - } - seqassertn(false, "cannot find '{}' in '{}'", member, typ->toString()); - return ""; - } + std::string getMethod(types::ClassType *typ, const std::string &member); /// Find the class method in a given class type that best matches the given arguments. /// Returns an _uninstantiated_ type. - types::FuncTypePtr findMethod(types::ClassType *typ, const std::string &member, - const std::vector &args); + types::FuncType *findMethod(types::ClassType *typ, const std::string &member, + const std::vector &args); /// Given a class type and the matching generic vector, instantiate the type and /// realize it. - ir::types::Type *realizeType(types::ClassTypePtr type, + ir::types::Type *realizeType(types::ClassType *type, const std::vector &generics = {}); /// Given a function type and function arguments, instantiate the type and /// realize it. The first argument is the function return type. /// You can also pass function generics if a function has one (e.g. T in def /// foo[T](...)). If a generic is used as an argument, it will be auto-deduced. Pass /// only if a generic cannot be deduced from the provided args. - ir::Func *realizeFunction(types::FuncTypePtr type, + ir::Func *realizeFunction(types::FuncType *type, const std::vector &args, const std::vector &generics = {}, - const types::ClassTypePtr &parentClass = nullptr); + types::ClassType *parentClass = nullptr); ir::types::Type *makeTuple(const std::vector &types); ir::types::Type *makeFunction(const std::vector &types); @@ -314,10 +314,53 @@ struct Cache : public std::enable_shared_from_this { void parseCode(const std::string &code); - static std::vector mergeC3(std::vector> &); + static std::vector> + mergeC3(std::vector> &); std::shared_ptr pyModule = nullptr; void populatePythonModule(); + +private: + std::vector> *_nodes; + +public: + /// Convenience method that constructs a node with the visitor's source location. + template Tn *N(Ts &&...args) { + _nodes->emplace_back(std::make_unique(std::forward(args)...)); + Tn *t = (Tn *)(_nodes->back().get()); + t->cache = this; + return t; + } + template Tn *NS(const ASTNode *srcInfo, Ts &&...args) { + _nodes->emplace_back(std::make_unique(std::forward(args)...)); + Tn *t = (Tn *)(_nodes->back().get()); + t->cache = this; + t->setSrcInfo(srcInfo->getSrcInfo()); + return t; + } + +public: + std::unordered_map _timings; + + struct CTimer { + Cache *c; + Timer t; + std::string name; + CTimer(Cache *c, std::string name) : c(c), name(std::move(name)), t(Timer("")) {} + ~CTimer() { + c->_timings[name] += t.elapsed(); + t.logged = true; + } + }; + + template + std::vector castVectorPtr(std::vector> v) { + std::vector r; + r.reserve(v.size()); + for (const auto &i : v) + r.emplace_back(i.get()); + return r; + } }; } // namespace codon::ast diff --git a/codon/parser/common.cpp b/codon/parser/common.cpp index 962b61f7..9bc55abf 100644 --- a/codon/parser/common.cpp +++ b/codon/parser/common.cpp @@ -101,18 +101,32 @@ std::string escapeFStringBraces(const std::string &str, int start, int len) { return t; } int findStar(const std::string &s) { + bool start = false; int i = 0; - for (; i < s.size(); i++) - if (s[i] == ' ' || s[i] == ')') - break; + for (; i < s.size(); i++) { + if (s[i] == '(') + return i + 1; + if (!isspace(s[i])) + return i; + // if (start && (s[i] == '\n' || s[i] == ' ' || s[i] == ')')) + // break; + } return i; } +bool in(const std::string &m, const std::string &item) { + auto f = m.find(item); + return f != std::string::npos; +} size_t startswith(const std::string &str, const std::string &prefix) { + if (prefix.empty()) + return true; return (str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix) ? prefix.size() : 0; } size_t endswith(const std::string &str, const std::string &suffix) { + if (suffix.empty()) + return true; return (str.size() >= suffix.size() && str.substr(str.size() - suffix.size()) == suffix) ? suffix.size() @@ -130,13 +144,6 @@ void rtrim(std::string &str) { .base(), str.end()); } -int trimStars(std::string &str) { - int stars = 0; - for (; stars < str.size() && str[stars] == '*'; stars++) - ; - str = str.substr(stars); - return stars; -} bool isdigit(const std::string &str) { return std::all_of(str.begin(), str.end(), ::isdigit); } @@ -253,6 +260,11 @@ ImportFile getRoot(const std::string argv0, const std::vector &plug ext = ".py"; seqassertn((root.empty() || startswith(s, root)) && endswith(s, ext), "bad path substitution: {}, {}", s, root); + // LOG("{} -> {} {}", s, root, ext); + // Find toplevel enclosing import! + // for (auto &x: ctx->cache->imports) { + // if (substr(x.module, ) + // } auto module = s.substr(root.size() + 1, s.size() - root.size() - ext.size() - 1); std::replace(module.begin(), module.end(), '/', '.'); return ImportFile{(!isStdLib && root == module0Root) ? ImportFile::PACKAGE @@ -307,9 +319,11 @@ std::shared_ptr getImportFile(const std::string &argv0, } auto module0Root = llvm::sys::path::parent_path(getAbsolutePath(module0)).str(); - return paths.empty() ? nullptr - : std::make_shared( - getRoot(argv0, plugins, module0Root, paths[0])); + auto file = paths.empty() ? nullptr + : std::make_shared( + getRoot(argv0, plugins, module0Root, paths[0])); + + return file; } } // namespace codon::ast diff --git a/codon/parser/common.h b/codon/parser/common.h index 2b0a15d6..bca0d9cf 100644 --- a/codon/parser/common.h +++ b/codon/parser/common.h @@ -16,9 +16,10 @@ #include "codon/util/common.h" -#define CAST(s, T) dynamic_cast(s.get()) - namespace codon { +namespace ir { +class Attribute; +} namespace ast { @@ -41,13 +42,23 @@ size_t endswith(const std::string &str, const std::string &suffix); void ltrim(std::string &str); /// Trims whitespace at the end of the string. void rtrim(std::string &str); -/// Removes leading stars in front of the string and returns the number of such stars. -int trimStars(std::string &str); /// True if a string only contains digits. bool isdigit(const std::string &str); /// Combine items separated by a delimiter into a string. +/// Combine items separated by a delimiter into a string. +template std::string join(const T &items, const std::string &delim = " ") { + std::string s; + bool first = true; + for (const auto &i : items) { + if (!first) + s += delim; + s += i; + first = false; + } + return s; +} template -std::string join(const T &items, const std::string &delim = " ", size_t start = 0, +std::string join(const T &items, const std::string &delim, size_t start, size_t end = (1ull << 31)) { std::string s; if (end > items.size()) @@ -58,11 +69,12 @@ std::string join(const T &items, const std::string &delim = " ", size_t start = } /// Combine items separated by a delimiter into a string. template -std::string combine(const std::vector &items, const std::string &delim = " ") { +std::string combine(const std::vector &items, const std::string &delim = " ", + const int indent = -1) { std::string s; for (int i = 0; i < items.size(); i++) if (items[i]) - s += (i ? delim : "") + items[i]->toString(); + s += (i ? delim : "") + items[i]->toString(indent); return s; } template @@ -104,33 +116,42 @@ const V *in(const std::unordered_map &m, const U &item) { auto f = m.find(item); return f != m.end() ? &(f->second) : nullptr; } +/// @return True if an item is found in an unordered_map m. +template +V *in(std::unordered_map &m, const U &item) { + auto f = m.find(item); + return f != m.end() ? &(f->second) : nullptr; +} /// @return vector c transformed by the function f. template auto vmap(const std::vector &c, F &&f) { std::vector::type> ret; std::transform(std::begin(c), std::end(c), std::inserter(ret, std::end(ret)), f); return ret; } +/// @return True if an item is found in an string m. +bool in(const std::string &m, const std::string &item); /// AST utilities -/// Clones a pointer even if it is a nullptr. -template auto clone(const std::shared_ptr &t) { - return t ? t->clone() : nullptr; +template T clone(const T &t, bool clean = false) { return t.clone(clean); } + +template +typename std::remove_const::type *clone(T *t, bool clean = false) { + return t ? static_cast::type *>(t->clone(clean)) + : nullptr; } -/// Clones a vector of cloneable pointer objects. -template std::vector clone(const std::vector &t) { - std::vector v; - for (auto &i : t) - v.push_back(clone(i)); - return v; +template typename std::remove_const::type *clean_clone(T *t) { + return clone(t, true); } -/// Clones a vector of cloneable objects. -template std::vector clone_nop(const std::vector &t) { - std::vector v; +/// Clones a vector of cloneable pointer objects. +template +std::vector::type> clone(const std::vector &t, + bool clean = false) { + std::vector::type> v; for (auto &i : t) - v.push_back(i.clone()); + v.push_back(clone(i, clean)); return v; } @@ -164,5 +185,14 @@ std::shared_ptr getImportFile(const std::string &argv0, const std::string &module0 = "", const std::vector &plugins = {}); +template class SetInScope { + T *t; + T origVal; + +public: + SetInScope(T *t, const T &val) : t(t), origVal(*t) { *t = val; } + ~SetInScope() { *t = origVal; } +}; + } // namespace ast } // namespace codon diff --git a/codon/parser/ctx.h b/codon/parser/ctx.h index c33feea5..6129c1f4 100644 --- a/codon/parser/ctx.h +++ b/codon/parser/ctx.h @@ -43,7 +43,7 @@ template class Context : public std::enable_shared_from_this srcInfos; + std::vector nodeStack; public: explicit Context(std::string filename) : filename(std::move(filename)) { @@ -74,6 +74,11 @@ template class Context : public std::enable_shared_from_thissecond.front() : nullptr; } + /// Return all objects that share a common identifier or nullptr if it does not exist. + virtual std::list *find_all(const std::string &name) { + auto it = map.find(name); + return it != map.end() ? &(it->second) : nullptr; + } /// Add a new block (i.e. adds a stack level). virtual void addBlock() { stack.push_front(std::list()); } /// Remove the top-most block and all variables it holds. @@ -83,6 +88,12 @@ template class Context : public std::enable_shared_from_this class Context : public std::enable_shared_from_this class Context : public std::enable_shared_from_this 1); + return nodeStack[nodeStack.size() - 2]; + } + SrcInfo getSrcInfo() const { return nodeStack.back()->getSrcInfo(); } + size_t getStackSize() const { return stack.size(); } }; } // namespace codon::ast diff --git a/codon/parser/match.cpp b/codon/parser/match.cpp new file mode 100644 index 00000000..a4727ccc --- /dev/null +++ b/codon/parser/match.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#include "codon/parser/match.h" + +namespace codon::matcher { + +match_zero_or_more_t MAny() { return match_zero_or_more_t(); } + +match_startswith_t MStarts(std::string s) { return match_startswith_t{std::move(s)}; } + +match_endswith_t MEnds(std::string s) { return match_endswith_t{std::move(s)}; } + +match_contains_t MContains(std::string s) { return match_contains_t{std::move(s)}; } + +template <> bool match(const char *c, const char *d) { + return std::string(c) == std::string(d); +} + +template <> bool match(const char *c, std::string d) { return std::string(c) == d; } + +template <> bool match(std::string c, const char *d) { return std::string(d) == c; } + +template <> bool match(double &a, double b) { return abs(a - b) < __FLT_EPSILON__; } + +template <> bool match(std::string s, match_startswith_t m) { + return m.s.size() <= s.size() && s.substr(0, m.s.size()) == m.s; +} + +template <> bool match(std::string s, match_endswith_t m) { + return m.s.size() <= s.size() && s.substr(s.size() - m.s.size(), m.s.size()) == m.s; +} + +template <> bool match(std::string s, match_contains_t m) { + return s.find(m.s) != std::string::npos; +} + +template <> bool match(const char *s, match_startswith_t m) { + return match(std::string(s), m); +} + +template <> bool match(const char *s, match_endswith_t m) { + return match(std::string(s), m); +} + +template <> bool match(const char *s, match_contains_t m) { + return match(std::string(s), m); +} + +} // namespace codon::matcher diff --git a/codon/parser/match.h b/codon/parser/match.h new file mode 100644 index 00000000..c51ea346 --- /dev/null +++ b/codon/parser/match.h @@ -0,0 +1,171 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#pragma once + +#include + +#include "codon/cir/base.h" + +namespace codon::matcher { + +template struct match_t { + std::tuple args; + std::function fn; + match_t(MA... args, std::function fn) + : args(std::tuple(args...)), fn(fn) {} +}; + +template struct match_or_t { + std::tuple args; + match_or_t(MA... args) : args(std::tuple(args...)) {} +}; + +struct match_ignore_t {}; + +struct match_zero_or_more_t {}; + +struct match_startswith_t { + std::string s; +}; + +struct match_endswith_t { + std::string s; +}; + +struct match_contains_t { + std::string s; +}; + +template match_t M(TA... args) { + return match_t(args..., nullptr); +} + +template +match_t MCall(TA... args, std::function fn) { + return match_t(args..., fn); +} + +template match_t MVar(TA... args, T &tp) { + return match_t(args..., [&tp](T &t) { tp = t; }); +} + +template match_t MVar(TA... args, T *&tp) { + return match_t(args..., [&tp](T &t) { tp = &t; }); +} + +template match_or_t MOr(TA... args) { + return match_or_t(args...); +} + +match_zero_or_more_t MAny(); + +match_startswith_t MStarts(std::string s); + +match_endswith_t MEnds(std::string s); + +match_contains_t MContains(std::string s); + +////////////////////////////////////////////////////////////////////////////// + +template bool match(T t, M m) { + if constexpr (std::is_same_v) + return t == m; + return false; +} + +template bool match(T &t, match_ignore_t) { return true; } + +template bool match(T &t, match_zero_or_more_t) { return true; } + +template <> bool match(const char *c, const char *d); + +template <> bool match(const char *c, std::string d); + +template <> bool match(std::string c, const char *d); + +template <> bool match(double &a, double b); + +template <> bool match(std::string s, match_startswith_t m); + +template <> bool match(std::string s, match_endswith_t m); + +template <> bool match(std::string s, match_contains_t m); + +template <> bool match(const char *s, match_startswith_t m); + +template <> bool match(const char *s, match_endswith_t m); + +template <> bool match(const char *s, match_contains_t m); + +template bool match_help(T &t, TM m) { + if constexpr (i == std::tuple_size::value) { + return i == std::tuple_size::value; + } else if constexpr (i < std::tuple_size::value) { + if constexpr (std::is_same_v(m.args))>, + match_zero_or_more_t>) { + return true; + } + return match(std::get(t.match_members()), std::get(m.args)) && + match_help(t, m); + } else { + return false; + } +} + +template +bool match_or_help(T &t, match_or_t m) { + if constexpr (i >= 0 && i < std::tuple_size::value) { + return match(t, std::get(m.args)) || match_or_help(t, m); + } else { + return false; + } +} + +template bool match(TM &t, match_or_t m) { + return match_or_help<0, TM, TA...>(t, m); +} + +template bool match(TM *t, match_or_t m) { + return match_or_help<0, TM *, TA...>(t, m); +} + +template +bool match(T &t, match_t m) { + if constexpr (std::is_pointer_v) { + TM *tm = ir::cast(t); + if (!tm) + return false; + if constexpr (sizeof...(TA) == 0) { + if (m.fn) + m.fn(*tm); + return true; + } else { + auto r = match_help<0>(*tm, m); + if (r && m.fn) + m.fn(*tm); + return r; + } + } else { + if constexpr (!std::is_same_v) + return false; + if constexpr (sizeof...(TA) == 0) { + if (m.fn) + m.fn(t); + return true; + } else { + auto r = match_help<0>(t, m); + if (r && m.fn) + m.fn(t); + return r; + } + } +} + +template +bool match(T *t, match_t m) { + return match(t, m); +} + +} // namespace codon::matcher + +#define M_ matcher::match_ignore_t() diff --git a/codon/parser/peg/grammar.peg b/codon/parser/peg/grammar.peg index 176248bd..1805fa57 100644 --- a/codon/parser/peg/grammar.peg +++ b/codon/parser/peg/grammar.peg @@ -14,55 +14,85 @@ PREAMBLE { #define V1 VS[1] #define V2 VS[2] #define ac std::any_cast - #define ac_expr std::any_cast - #define ac_stmt std::any_cast + #define ac_expr std::any_cast + #define ac_stmt std::any_cast #define SemVals peg::SemanticValues + #define aste(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) + #define asts(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) + + template + T *setSI(ASTNode *n, const codon::SrcInfo &s) { + n->setSrcInfo(s); + return (T*)n; + } + template auto vmap(const peg::SemanticValues &c, F &&f) { return vmap(static_cast&>(c), f); } - template auto ast(Tsv &s, Ts &&...args) { - auto t = make_shared(std::forward(args)...); - t->setSrcInfo(s); - return std::static_pointer_cast(t); - } - auto chain(peg::SemanticValues &VS, const codon::SrcInfo &LOC) { - auto b = ac_expr(V0); + Expr *chain(const codon::ast::ParseContext &CTX, peg::SemanticValues &VS, const codon::SrcInfo &LOC) { + Expr *b = ac_expr(V0); for (int i = 1; i < VS.size(); i++) - b = ast(LOC, b, VS.token_to_string(i - 1), ac_expr(VS[i])); + b = aste(Binary, LOC, b, VS.token_to_string(i - 1), ac_expr(VS[i])); return b; } - auto wrap_tuple(peg::SemanticValues &VS, const codon::SrcInfo &LOC) { + Expr *wrap_tuple(const codon::ast::ParseContext &CTX, peg::SemanticValues &VS, const codon::SrcInfo &LOC) { if (VS.size() == 1 && VS.tokens.empty()) return ac_expr(V0); - return ast(LOC, VS.transform()); + return aste(Tuple, LOC, VS.transform()); } } + program <- (statements (_ EOL)* / (_ EOL)*) !. { if (VS.empty()) - return ast(LOC); + return asts(Suite, LOC); return ac_stmt(V0); } -fstring <- star_expressions _ (':' format_spec)? _ !. { - return make_pair(ac_expr(V0), VS.size() == 1 ? "" : ac(V1)); + +fstring <- fstring_prefix _ fstring_tail? _ !. { + StringExpr::FormatSpec fs {"", "", ""}; + auto [e, t] = ac>(V0); + fs.text = t; + if (VS.size() > 1) { + auto [c, s] = ac>(V1); + fs.conversion = c; + fs.spec = s; + } + return make_pair(e, fs); +} +fstring_prefix <- star_expressions _ '='? { + auto text = VS.sv(); + return make_pair(ac_expr(V0), + string(!VS.sv().empty() && VS.sv().back() == '=' ? VS.sv() : "")); } +fstring_tail <- + / fstring_conversion fstring_spec? { + return make_pair(ac(V0), VS.size() > 1 ? ac(V1) : ""); + } + / fstring_spec { return make_pair(std::string(), ac(V0)); } +fstring_conversion <- "!" ("s" / "r" / "a") { return string(VS.sv().substr(1)); } +fstring_spec <- ':' format_spec { return ac(V0); } # Macros list(c, e) <- e (_ c _ e)* tlist(c, e) <- e (_ c _ e)* (_ )? statements <- ((_ EOL)* statement)+ { - return ast(LOC, VS.transform()); + auto s = asts(Suite, LOC, VS.transform()); + cast(s)->flatten(); + return s; } statement <- SAMEDENT compound_stmt / SAMEDENT simple_stmt simple_stmt <- tlist(';', small_stmt) _ EOL { - return ast(LOC, VS.transform()); + auto s = asts(Suite, LOC, VS.transform()); + cast(s)->flatten(); + return s; } small_stmt <- / assignment - / 'pass' &(SPACE / ';' / EOL) { return any(ast(LOC)); } - / 'break' &(SPACE / ';' / EOL) { return any(ast(LOC)); } - / 'continue' &(SPACE / ';' / EOL) { return any(ast(LOC)); } + / 'pass' &(SPACE / ';' / EOL) { return any(asts(Suite, LOC)); } + / 'break' &(SPACE / ';' / EOL) { return any(asts(Break, LOC)); } + / 'continue' &(SPACE / ';' / EOL) { return any(asts(Continue, LOC)); } / global_stmt / nonlocal_stmt / yield_stmt &(SPACE / ';' / EOL) @@ -72,25 +102,26 @@ small_stmt <- / raise_stmt &(SPACE / ';' / EOL) / print_stmt / import_stmt - / expressions &(_ ';' / _ EOL) { return any(ast(LOC, ac_expr(V0))); } + / expressions &(_ ';' / _ EOL) { return any(asts(Expr, LOC, ac_expr(V0))); } / custom_small_stmt assignment <- / id _ ':' _ expression (_ '=' _ star_expressions)? { - return ast(LOC, + return asts(Assign, LOC, ac_expr(V0), VS.size() > 2 ? ac_expr(V2) : nullptr, ac_expr(V1) ); } / (star_targets _ (!'==' '=') _)+ star_expressions !(_ '=') { - vector stmts; - for (int i = int(VS.size()) - 2; i >= 0; i--) - stmts.push_back(ast(LOC, ac_expr(VS[i]), ac_expr(VS[i + 1]))); - return ast(LOC, move(stmts)); + vector stmts; + for (int i = int(VS.size()) - 2; i >= 0; i--) { + auto a = asts(Assign, LOC, ac_expr(VS[i]), ac_expr(VS[i + 1])); + stmts.push_back(a); + } + return asts(Suite, LOC, std::move(stmts)); } / star_expression _ augassign '=' ^ _ star_expressions { - return ast(LOC, - ac_expr(V0), ast(LOC, clone(ac_expr(V0)), ac(V1), ac_expr(V2), true) - ); + auto a = asts(Assign, LOC, ac_expr(V0), aste(Binary, LOC, clone(ac_expr(V0)), ac(V1), ac_expr(V2), true)); + return a; } augassign <- < '+' / '-' / '**' / '*' / '@' / '//' / '/' / '%' / '&' / '|' / '^' / '<<' / '>>' @@ -98,46 +129,47 @@ augassign <- < return VS.token_to_string(); } global_stmt <- 'global' SPACE tlist(',', NAME) { - return ast(LOC, - vmap(VS, [&](const any &i) { return ast(LOC, ac(i), false); }) + return asts(Suite, LOC, + vmap(VS, [&](const any &i) { return asts(Global, LOC, ac(i), false); }) ); } nonlocal_stmt <- 'nonlocal' SPACE tlist(',', NAME) { - return ast(LOC, - vmap(VS, [&](const any &i) { return ast(LOC, ac(i), true); }) + return asts(Suite, LOC, + vmap(VS, [&](const any &i) { return asts(Global, LOC, ac(i), true); }) ); } yield_stmt <- - / 'yield' SPACE 'from' SPACE expression { return ast(LOC, ac_expr(V0)); } + / 'yield' SPACE 'from' SPACE expression { return asts(YieldFrom, LOC, ac_expr(V0)); } / 'yield' (SPACE expressions)? { - return ast(LOC, !VS.empty() ? ac_expr(V0) : nullptr); + return asts(Yield, LOC, !VS.empty() ? ac_expr(V0) : nullptr); } assert_stmt <- 'assert' SPACE expression (_ ',' _ expression)? { - return ast(LOC, ac_expr(V0), VS.size() > 1 ? ac_expr(V1) : nullptr); + return asts(Assert, LOC, ac_expr(V0), VS.size() > 1 ? ac_expr(V1) : nullptr); } # TODO: do targets as in Python del_stmt <- 'del' SPACE tlist(',', expression) { - return ast(LOC, - vmap(VS, [&](const any &i) { return ast(LOC, ac_expr(i)); }) + return asts(Suite, LOC, + vmap(VS, [&](const any &i) { return asts(Del, LOC, ac_expr(i)); }) ); } return_stmt <- 'return' (SPACE expressions)? { - return ast(LOC, !VS.empty() ? ac_expr(V0) : nullptr); -} -# TODO: raise expression 'from' expression -raise_stmt <- 'raise' (SPACE expression)? { - return ast(LOC, !VS.empty() ? ac_expr(V0) : nullptr); + return asts(Return, LOC, !VS.empty() ? ac_expr(V0) : nullptr); } +raise_stmt <- + / 'raise' SPACE expression (SPACE 'from' SPACE expression)? { + return asts(Throw, LOC, ac_expr(V0), VS.size() > 1 ? ac_expr(V1) : nullptr); + } + / 'raise' { return asts(Throw, LOC, nullptr); } print_stmt <- / 'print' SPACE star_expression (_ ',' _ star_expression)* (_ <','>)? { - return ast(LOC, VS.transform(), !VS.tokens.empty()); + return asts(Print, LOC, VS.transform(), !VS.tokens.empty()); } - / 'print' _ &EOL { return ast(LOC, vector{}, false); } + / 'print' _ &EOL { return asts(Print, LOC, vector{}, false); } import_stmt <- import_name / import_from import_name <- 'import' SPACE list(',', as_name) { - return ast(LOC, - vmap(VS.transform>(), [&](const pair &i) { - return ast(LOC, i.first, nullptr, vector{}, nullptr, i.second); + return asts(Suite, LOC, + vmap(VS.transform>(), [&](const pair &i) { + return asts(Import, LOC, i.first, nullptr, vector{}, nullptr, i.second); }) ); } @@ -146,22 +178,22 @@ as_name <- dot_name (SPACE 'as' SPACE NAME)? { } import_from <- / 'from' SPACE (_ <'.'>)* (_ dot_name)? SPACE 'import' SPACE '*' { - return ast(LOC, - VS.size() == 1 ? ac_expr(V0) : nullptr, ast(LOC, "*"), vector{}, + return asts(Import, LOC, + VS.size() == 1 ? ac_expr(V0) : nullptr, aste(Id, LOC, "*"), vector{}, nullptr, "", int(VS.tokens.size()) ); } / 'from' SPACE (_ <'.'>)* (_ dot_name)? SPACE 'import' SPACE (from_as_parens / from_as_items) { auto f = VS.size() == 2 ? ac_expr(V0) : nullptr; - return ast(LOC, + return asts(Suite, LOC, vmap( ac(VS.size() == 2 ? V1 : V0), [&](const any &i) { auto p = ac>(i); - auto t = ac, ExprPtr, bool>>(p.first); - return ast(LOC, - f, get<0>(t), move(get<1>(t)), get<2>(t), p.second, int(VS.tokens.size()), get<3>(t) + auto t = ac, Expr*, bool>>(p.first); + return asts(Import, LOC, + f, get<0>(t), std::move(get<1>(t)), get<2>(t), p.second, int(VS.tokens.size()), get<3>(t) ); } ) @@ -180,17 +212,17 @@ from_id <- return tuple( ac_expr(V0), ac(V1).transform(), - VS.size() > 2 ? ac_expr(V2) : ast(LOC, "NoneType"), + VS.size() > 2 ? ac_expr(V2) : aste(Id, LOC, "NoneType"), true ); } - / dot_name { return tuple(ac_expr(V0), vector{}, (ExprPtr)nullptr, true); } + / dot_name { return tuple(ac_expr(V0), vector{}, (Expr*)nullptr, true); } dot_name <- id (_ '.' _ NAME)* { if (VS.size() == 1) return ac_expr(V0); - auto dot = ast(LOC, ac_expr(V0), ac(V1)); + auto dot = aste(Dot, LOC, ac_expr(V0), ac(V1)); for (int i = 2; i < VS.size(); i++) - dot = ast(LOC, dot, ac(VS[i])); + dot = aste(Dot, LOC, dot, ac(VS[i])); return dot; } from_params <- '(' _ tlist(',', from_param)? _ ')' { return VS; } @@ -213,46 +245,37 @@ compound_stmt <- if_stmt <- ('if' SPACE named_expression _ ':' _ suite) (SAMEDENT 'elif' SPACE named_expression _ ':' _ suite)* (SAMEDENT 'else' _ ':' _ suite)? { - shared_ptr stmt = ast(LOC, nullptr, nullptr); - IfStmt *p = (IfStmt*)stmt.get(); - for (int i = 0; i < VS.size(); i += 2) { - if (i == VS.size() - 1) { - p->elseSuite = ac_stmt(VS[i]); - } else { - if (i) { - p->elseSuite = ast(LOC, nullptr, nullptr); - p = (IfStmt*)(p->elseSuite.get()); - } - p->cond = ac_expr(VS[i]); - p->ifSuite = ac_stmt(VS[i + 1]); - } + Stmt *lastElse = VS.size() % 2 == 0 ? nullptr : SuiteStmt::wrap(ac_stmt(VS.back())); + for (size_t i = VS.size() - bool(lastElse); i-- > 0; ) { + lastElse = asts(If, LOC, ac_expr(VS[i - 1]), SuiteStmt::wrap(ac_stmt(VS[i])), SuiteStmt::wrap(lastElse)); + i--; } - return stmt; + return lastElse; } while_stmt <- ('while' SPACE named_expression _ ':' _ suite) (SAMEDENT 'else' (SPACE 'not' SPACE 'break')* _ ':' _ suite)? { - return ast(LOC, + return asts(While, LOC, ac_expr(V0), ac_stmt(V1), VS.size() > 2 ? ac_stmt(V2) : nullptr ); } for <- decorator? for_stmt { if (VS.size() > 1) { - auto s = dynamic_pointer_cast(ac_stmt(V1)); - s->decorator = ac_expr(V0); - return static_pointer_cast(s); + auto s = (ForStmt*)(ac_stmt(V1)); + s->setDecorator(ac_expr(V0)); + return (Stmt*)s; } return ac_stmt(V0); } for_stmt <- ('for' SPACE star_targets) (SPACE 'in' SPACE star_expressions _ ':' _ suite) (SAMEDENT 'else' (SPACE 'not' SPACE 'break')* _ ':' _ suite)? { - return ast(LOC, + return asts(For, LOC, ac_expr(V0), ac_expr(V1), ac_stmt(V2), VS.size() > 3 ? ac_stmt(VS[3]) : nullptr ); } with_stmt <- 'with' SPACE (with_parens_item / with_item) _ ':' _ suite { - return ast(LOC, - ac(V0).transform>(), ac_stmt(V1) + return asts(With, LOC, + ac(V0).transform>(), ac_stmt(V1) ); } with_parens_item <- '(' _ tlist(',', as_item) _ ')' { return VS; } @@ -261,47 +284,51 @@ as_item <- / expression SPACE 'as' SPACE id &(_ (',' / ')' / ':')) { return pair(ac_expr(V0), ac_expr(V1)); } - / expression { return pair(ac_expr(V0), (ExprPtr)nullptr); } + / expression { return pair(ac_expr(V0), (Expr*)nullptr); } # TODO: else block? try_stmt <- - / ('try' _ ':' _ suite) - excepts - (SAMEDENT 'finally' _ ':' _ suite)? { - return ast(LOC, + / ('try' _ ':' _ suite) excepts else_finally? { + std::pair ef {nullptr, nullptr}; + if (VS.size() > 2) ef = ac>(V2); + return asts(Try, LOC, ac_stmt(V0), - ac(V1).transform(), - VS.size() > 2 ? ac_stmt(V2): nullptr + ac(V1).transform(), + ef.first, ef.second ); } / ('try' _ ':' _ suite) (SAMEDENT 'finally' _ ':' _ suite)? { - return ast(LOC, - ac_stmt(V0), vector{}, VS.size() > 1 ? ac_stmt(V1): nullptr + return asts(Try, LOC, + ac_stmt(V0), vector{}, nullptr, + VS.size() > 1 ? ac_stmt(V1) : nullptr ); } +else_finally <- + / SAMEDENT 'else' _ ':' _ suite + SAMEDENT 'finally' _ ':' _ suite { return std::pair(ac_stmt(V0), ac_stmt(V1)); } + / SAMEDENT 'else' _ ':' _ suite { return std::pair(ac_stmt(V0), nullptr); } + / SAMEDENT 'finally' _ ':' _ suite { return std::pair(nullptr, ac_stmt(V0)); } excepts <- (SAMEDENT except_block)+ { return VS; } except_block <- / 'except' SPACE expression (SPACE 'as' SPACE NAME)? _ ':' _ suite { if (VS.size() == 3) - return TryStmt::Catch{ac(V1), ac_expr(V0), ac_stmt(V2)}; + return setSI(CTX.cache->N(ac(V1), ac_expr(V0), ac_stmt(V2)), LOC); else - return TryStmt::Catch{"", ac_expr(V0), ac_stmt(V1)}; + return setSI(CTX.cache->N("", ac_expr(V0), ac_stmt(V1)), LOC); } - / 'except' _ ':' _ suite { return TryStmt::Catch{"", nullptr, ac_stmt(V0)}; } + / 'except' _ ':' _ suite { return setSI(CTX.cache->N("", nullptr, ac_stmt(V0)), LOC); } function <- / extern_decorators function_def (_ EOL)+ &INDENT extern (_ EOL)* &DEDENT { - auto fn = dynamic_pointer_cast(ac_stmt(V1)); - fn->decorators = ac>(V0); - fn->suite = ast(LOC, ast(LOC, ac(V2))); - fn->parseDecorators(); - return static_pointer_cast(fn); + auto fn = (FunctionStmt*)(ac_stmt(V1)); + fn->setDecorators(ac>(V0)); + fn->setSuite(SuiteStmt::wrap(asts(Expr, LOC, aste(String, LOC, ac(V2))))); + return (Stmt*)fn; } / decorators? function_def _ suite { - auto fn = dynamic_pointer_cast(ac_stmt(VS.size() > 2 ? V1 : V0)); + auto fn = (FunctionStmt*)(ac_stmt(VS.size() > 2 ? V1 : V0)); if (VS.size() > 2) - fn->decorators = ac>(V0); - fn->suite = ac_stmt(VS.size() > 2 ? V2 : V1); - fn->parseDecorators(); - return static_pointer_cast(fn); + fn->setDecorators(ac>(V0)); + fn->setSuite(SuiteStmt::wrap(ac_stmt(VS.size() > 2 ? V2 : V1))); + return (Stmt*)fn; } extern <- (empty_line* EXTERNDENT (!EOL .)* EOL empty_line*)+ { return string(VS.sv()); @@ -312,7 +339,7 @@ function_def <- auto params = ac(V2).transform(); for (auto &p: ac>(V1)) params.push_back(p); - return ast(LOC, + return asts(Function, LOC, ac(V0), VS.size() == 4 ? ac_expr(VS[3]) : nullptr, params, @@ -320,7 +347,7 @@ function_def <- ); } / 'def' SPACE NAME _ params (_ '->' _ expression)? _ ':' { - return ast(LOC, + return asts(Function, LOC, ac(V0), VS.size() == 3 ? ac_expr(VS[2]) : nullptr, ac(V1).transform(), @@ -343,22 +370,22 @@ generics <- '[' _ tlist(',', param) _ ']' { for (auto &p: VS) { auto v = ac(p); v.status = Param::Generic; - if (!v.type) v.type = ast(LOC, "type"); + if (!v.type) v.type = aste(Id, LOC, "type"); params.push_back(v); } return params; } decorators <- decorator+ { - return VS.transform(); + return VS.transform(); } decorator <- ('@' _ !(('llvm' / 'python') _ EOL) named_expression _ EOL SAMEDENT) { return ac_expr(V0); } extern_decorators <- / decorators? ('@' _ <'llvm'/'python'> _ EOL SAMEDENT) decorators? { - vector vs{ast(LOC, VS.token_to_string())}; + vector vs{aste(Id, LOC, VS.token_to_string())}; for (auto &v: VS) { - auto nv = ac>(v); + auto nv = ac>(v); vs.insert(vs.end(), nv.begin(), nv.end()); } return vs; @@ -366,59 +393,59 @@ extern_decorators <- class <- decorators? class_def { if (VS.size() == 2) { auto fn = ac_stmt(V1); - dynamic_pointer_cast(fn)->decorators = ac>(V0); - dynamic_pointer_cast(fn)->parseDecorators(); + cast(fn)->setDecorators(ac>(V0)); return fn; } return ac_stmt(V0); } base_class_args <- '(' _ tlist(',', expression)? _ ')' { - return VS.transform(); + return VS.transform(); } class_args <- - / generics _ base_class_args { return make_pair(ac>(V0), ac>(V1)); } - / generics { return make_pair(ac>(V0), vector{}); } - / base_class_args { return make_pair(vector{}, ac>(V0)); } + / generics _ base_class_args { return make_pair(ac>(V0), ac>(V1)); } + / generics { return make_pair(ac>(V0), vector{}); } + / base_class_args { return make_pair(vector{}, ac>(V0)); } class_def <- 'class' SPACE NAME _ class_args? _ ':' _ suite { vector generics; - vector baseClasses; + vector baseClasses; if (VS.size() == 3) - std::tie(generics, baseClasses) = ac, vector>>(V1); + std::tie(generics, baseClasses) = ac, vector>>(V1); vector args; - auto suite = make_shared(); - auto s = const_cast(ac_stmt(VS.size() == 3 ? V2 : V1)->getSuite()); + auto suite = (SuiteStmt*)(asts(Suite, LOC)); + auto s = cast(ac_stmt(VS.size() == 3 ? V2 : V1)); seqassertn(s, "not a suite"); - for (auto &i: s->stmts) { - if (auto a = const_cast(i->getAssign())) - if (a->lhs->getId()) { - args.push_back(Param(a->getSrcInfo(), a->lhs->getId()->value, move(a->type), move(a->rhs))); + for (auto *i: *s) { + if (auto a = cast(i)) + if (auto ei = cast(a->getLhs())) { + args.push_back(Param(a->getSrcInfo(), ei->getValue(), a->getTypeExpr(), a->getRhs())); continue; } - suite->stmts.push_back(i); + suite->addStmt(i); } + suite->flatten(); for (auto &p: generics) args.push_back(p); - return ast(LOC, - ac(V0), move(args), suite, vector{}, baseClasses + return asts(Class, LOC, + ac(V0), std::move(args), suite, vector{}, baseClasses ); } match_stmt <- 'match' SPACE expression _ ':' (_ EOL)+ &INDENT (SAMEDENT case)+ (_ EOL)* &DEDENT { - return ast(LOC, ac_expr(V0), VS.transform(1)); + return asts(Match, LOC, ac_expr(V0), VS.transform(1)); } case <- / 'case' SPACE expression SPACE 'if' SPACE pipe _ ':' _ suite { - return MatchStmt::MatchCase{ac_expr(V0), ac_expr(V1), ac_stmt(V2)}; + return MatchCase{ac_expr(V0), ac_expr(V1), ac_stmt(V2)}; } / 'case' SPACE expression _ ':' _ suite { - return MatchStmt::MatchCase{ac_expr(V0), nullptr, ac_stmt(V1)}; + return MatchCase{ac_expr(V0), nullptr, ac_stmt(V1)}; } custom_stmt <- / NAME SPACE expression _ ':' _ suite { - return ast(LOC, ac(V0), ac_expr(V1), ac_stmt(V2)); + return asts(Custom, LOC, ac(V0), ac_expr(V1), ac_stmt(V2)); } / NAME _ ':' _ suite { - return ast(LOC, ac(V0), nullptr, ac_stmt(V2)); + return asts(Custom, LOC, ac(V0), nullptr, ac_stmt(V2)); } custom_stmt__PREDICATE { auto kwd = ac(V0); @@ -426,7 +453,7 @@ custom_stmt__PREDICATE { } custom_small_stmt <- NAME SPACE expressions { - return any(ast(LOC, ac(V0), ac_expr(V1), nullptr)); + return any(asts(Custom, LOC, ac(V0), ac_expr(V1), nullptr)); } custom_small_stmt__PREDICATE { auto kwd = ac(V0); @@ -438,61 +465,64 @@ custom_small_stmt__PREDICATE { # (2) Expressions ######################################################################################## -expressions <- tlist(',', expression) { return wrap_tuple(VS, LOC); } +expressions <- tlist(',', expression) { return wrap_tuple(CTX, VS, LOC); } expression <- / lambdef { return ac_expr(V0); } - / disjunction SPACE 'if' SPACE disjunction SPACE 'else' SPACE expression { - return ast(LOC, ac_expr(V1), ac_expr(V0), ac_expr(V2)); + / disjunction SPACE? 'if' SPACE? disjunction SPACE? 'else' SPACE? expression { + return aste(If, LOC, ac_expr(V1), ac_expr(V0), ac_expr(V2)); } / pipe { return ac_expr(V0); } -# TODO: make it more pythonic lambdef <- - / 'lambda' SPACE list(',', NAME) _ ':' _ expression { - return ast(LOC, - VS.transform(0, VS.size() - 1), ac_expr(VS.back()) + / 'lambda' SPACE lparams _ ':' _ expression { + return aste(Lambda, LOC, + ac(V0).transform(), ac_expr(V1) ); } / 'lambda' _ ':' _ expression { - return ast(LOC, vector{}, ac_expr(VS.back())); + return aste(Lambda, LOC, vector{}, ac_expr(V0)); + } +lparams <- tlist(',', lparam)? { return VS; } +lparam <- param_name (_ '=' _ expression)? { + return Param(LOC, ac(V0), nullptr, VS.size() > 1 ? ac_expr(V1) : nullptr); } pipe <- / disjunction (_ <'|>' / '||>'> _ disjunction)+ { - vector v; + vector v; for (int i = 0; i < VS.size(); i++) - v.push_back(PipeExpr::Pipe{i ? VS.token_to_string(i - 1) : "", ac_expr(VS[i])}); - return ast(LOC, move(v)); + v.push_back(Pipe{i ? VS.token_to_string(i - 1) : "", ac_expr(VS[i])}); + return aste(Pipe, LOC, std::move(v)); } / disjunction { return ac_expr(V0); } disjunction <- - / conjunction (SPACE 'or' SPACE conjunction)+ { - auto b = ast(LOC, ac_expr(V0), "||", ac_expr(V1)); + / conjunction (SPACE? 'or' SPACE? conjunction)+ { + auto b = aste(Binary, LOC, ac_expr(V0), "||", ac_expr(V1)); for (int i = 2; i < VS.size(); i++) - b = ast(LOC, b, "||", ac_expr(VS[i])); + b = aste(Binary, LOC, b, "||", ac_expr(VS[i])); return b; } / conjunction { return ac_expr(V0); } conjunction <- - / inversion (SPACE 'and' SPACE inversion)+ { - auto b = ast(LOC, ac_expr(V0), "&&", ac_expr(V1)); + / inversion (SPACE? 'and' SPACE? inversion)+ { + auto b = aste(Binary, LOC, ac_expr(V0), "&&", ac_expr(V1)); for (int i = 2; i < VS.size(); i++) - b = ast(LOC, b, "&&", ac_expr(VS[i])); + b = aste(Binary, LOC, b, "&&", ac_expr(VS[i])); return b; } / inversion { return ac_expr(V0); } inversion <- - / 'not' SPACE inversion { return ast(LOC, "!", ac_expr(V0)); } + / 'not' SPACE inversion { return aste(Unary, LOC, "!", ac_expr(V0)); } / comparison { return ac_expr(V0); } comparison <- bitwise_or compare_op_bitwise_or* { if (VS.size() == 1) { return ac_expr(V0); } else if (VS.size() == 2) { - auto p = ac>(V1); - return ast(LOC, ac_expr(V0), p.first, p.second); + auto p = ac>(V1); + return aste(Binary, LOC, ac_expr(V0), p.first, p.second); } else { - vector> v{pair(string(), ac_expr(V0))}; - auto vp = VS.transform>(1); + vector> v{pair(string(), ac_expr(V0))}; + auto vp = VS.transform>(1); v.insert(v.end(), vp.begin(), vp.end()); - return ast(LOC, move(v)); + return aste(ChainBinary, LOC, std::move(v)); } } compare_op_bitwise_or <- @@ -508,184 +538,190 @@ compare_op_bitwise_or <- / _ <'==' / '!=' / '<=' / '<' / '>=' / '>'> _ bitwise_or { return pair(VS.token_to_string(), ac_expr(V0)); } -bitwise_or <- bitwise_xor (_ <'|'> _ bitwise_xor)* { return chain(VS, LOC); } -bitwise_xor <- bitwise_and (_ <'^'> _ bitwise_and)* { return chain(VS, LOC); } -bitwise_and <- shift_expr (_ <'&'> _ shift_expr )* { return chain(VS, LOC); } -shift_expr <- sum (_ <'<<' / '>>'> _ sum )* { return chain(VS, LOC); } -sum <- term (_ <'+' / '-'> _ term)* { return chain(VS, LOC); } -term <- factor (_ <'*' / '//' / '/' / '%' / '@'> _ factor)* { return chain(VS, LOC); } +bitwise_or <- bitwise_xor (_ <'|'> _ bitwise_xor)* { return chain(CTX, VS, LOC); } +bitwise_xor <- bitwise_and (_ <'^'> _ bitwise_and)* { return chain(CTX, VS, LOC); } +bitwise_and <- shift_expr (_ <'&'> _ shift_expr )* { return chain(CTX, VS, LOC); } +shift_expr <- sum (_ <'<<' / '>>'> _ sum )* { return chain(CTX, VS, LOC); } +sum <- term (_ <'+' / '-'> _ term)* { return chain(CTX, VS, LOC); } +term <- factor (_ <'*' / '//' / '/' / '%' / '@'> _ factor)* { return chain(CTX, VS, LOC); } factor <- / <'+' / '-' / '~'> _ factor { - return ast(LOC, VS.token_to_string(), ac_expr(V0)); + return aste(Unary, LOC, VS.token_to_string(), ac_expr(V0)); } / power { return ac_expr(V0); } power <- / primary _ <'**'> _ factor { - return ast(LOC, ac_expr(V0), "**", ac_expr(V1)); + return aste(Binary, LOC, ac_expr(V0), "**", ac_expr(V1)); } / primary { return ac_expr(V0); } primary <- atom (_ primary_tail)* { - auto e = ac(V0); + auto e = ac(V0); for (int i = 1; i < VS.size(); i++) { auto p = ac>(VS[i]); if (p.first == 0) - e = ast(LOC, e, ac(p.second)); + e = aste(Dot, LOC, e, ac(p.second)); else if (p.first == 1) - e = ast(LOC, e, ac_expr(p.second)); + e = aste(Call, LOC, e, ac_expr(p.second)); else if (p.first == 2) - e = ast(LOC, e, ac>(p.second)); + e = aste(Call, LOC, e, ac>(p.second)); else - e = ast(LOC, e, ac_expr(p.second)); + e = aste(Index, LOC, e, ac_expr(p.second)); } return e; } primary_tail <- / '.' _ NAME { return pair(0, V0); } / genexp { return pair(1, V0); } - / arguments { return pair(2, VS.size() ? V0 : any(vector{})); } + / arguments { return pair(2, VS.size() ? V0 : any(vector{})); } / slices { return pair(3, V0); } -slices <- '[' _ tlist(',', slice) _ ']' { return wrap_tuple(VS, LOC); } +slices <- '[' _ tlist(',', slice) _ ']' { return wrap_tuple(CTX, VS, LOC); } slice <- / slice_part _ ':' _ slice_part (_ ':' _ slice_part)? { - return ast(LOC, + return aste(Slice, LOC, ac_expr(V0), ac_expr(V1), VS.size() > 2 ? ac_expr(V2) : nullptr ); } / expression { return ac_expr(V0); } -slice_part <- expression? { return VS.size() ? V0 : make_any(nullptr); } +slice_part <- expression? { return VS.size() ? V0 : make_any(nullptr); } atom <- / STRING (SPACE STRING)* { - return ast(LOC, VS.transform>()); + auto e = aste(String, LOC, VS.transform()); + return e; } / id { return ac_expr(V0); } - / 'True' { return ast(LOC, true); } - / 'False' { return ast(LOC, false);} - / 'None' { return ast(LOC); } + / 'True' { return aste(Bool, LOC, true); } + / 'False' { return aste(Bool, LOC, false);} + / 'None' { return aste(None, LOC); } / INT _ '...' _ INT { - return ast(LOC, - ast(LOC, ac(V0)), ast(LOC, ac(V1)) + return aste(Range, LOC, + aste(Int, LOC, ac(V0)), aste(Int, LOC, ac(V1)) ); } / FLOAT NAME? { - return ast(LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); + return aste(Float, LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); } / INT NAME? { - return ast(LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); + return aste(Int, LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); } / parentheses { return ac_expr(V0); } - / '...' { return ast(LOC); } + / '...' { return aste(Ellipsis, LOC); } parentheses <- ( tuple / yield / named / genexp / listexpr / listcomp / dict / set / dictcomp / setcomp ) tuple <- - / '(' _ ')' { return ast(LOC, VS.transform()); } - / '(' _ tlist(',', star_named_expression) _ ')' { return wrap_tuple(VS, LOC); } -yield <- '(' _ 'yield' _ ')' { return ast(LOC); } + / '(' _ ')' { return aste(Tuple, LOC, VS.transform()); } + / '(' _ tlist(',', star_named_expression) _ ')' { return wrap_tuple(CTX, VS, LOC); } +yield <- '(' _ 'yield' _ ')' { return aste(Yield, LOC); } named <- '(' _ named_expression _ ')' genexp <- '(' _ named_expression SPACE for_if_clauses _ ')' { - return ast(LOC, - GeneratorExpr::Generator, ac_expr(V0), ac(V1).transform() + return aste(Generator, + LOC, CTX.cache, GeneratorExpr::Generator, ac_expr(V0), ac>(V1) ); } listexpr <- '[' _ tlist(',', star_named_expression)? _ ']' { - return ast(LOC, VS.transform()); + return aste(List, LOC, VS.transform()); } listcomp <- '[' _ named_expression SPACE for_if_clauses _ ']' { - return ast(LOC, - GeneratorExpr::ListGenerator, - ac_expr(V0), - ac(V1).transform() + return aste(Generator, + LOC, CTX.cache, GeneratorExpr::ListGenerator, ac_expr(V0), ac>(V1) ); } set <- '{' _ tlist(',', star_named_expression) _ '}' { - return ast(LOC, VS.transform()); + return aste(Set, LOC, VS.transform()); } setcomp <- '{' _ named_expression SPACE for_if_clauses _ '}' { - return ast(LOC, - GeneratorExpr::SetGenerator, - ac_expr(V0), - ac(V1).transform() + return aste(Generator, + LOC, CTX.cache, GeneratorExpr::SetGenerator, ac_expr(V0), ac>(V1) ); } dict <- '{' _ tlist(',', double_starred_kvpair)? _ '}' { - return ast(LOC, VS.transform()); + return aste(Dict, LOC, VS.transform()); } dictcomp <- '{' _ kvpair SPACE for_if_clauses _ '}' { - auto p = ac(V0); - return ast(LOC, - p->getTuple()->items[0], p->getTuple()->items[1], - ac(V1).transform() + auto p = ac(V0); + return aste(Generator, + LOC, CTX.cache, (*cast(p))[0], (*cast(p))[1], ac>(V1) ); } double_starred_kvpair <- / '**' _ bitwise_or { - return ast(LOC, ac_expr(V0)); + return aste(KeywordStar, LOC, ac_expr(V0)); } - / kvpair { return ac(V0); } + / kvpair { return ac(V0); } kvpair <- expression _ ':' _ expression { - return ast(LOC, std::vector{ac_expr(V0), ac_expr(V1)}); + return aste(Tuple, LOC, std::vector{ac_expr(V0), ac_expr(V1)}); +} +for_if_clauses <- for_if_clause (SPACE for_if_clause)* { + std::vector v = ac>(V0); + auto tail = VS.transform>(1); + for (auto &t: tail) + v.insert(v.end(), t.begin(), t.end()); + return v; } -for_if_clauses <- for_if_clause (SPACE for_if_clause)* { return VS; } for_if_clause <- 'for' SPACE star_targets SPACE 'in' SPACE disjunction - (SPACE 'if' SPACE disjunction)* { - return GeneratorBody{ac_expr(V0), ac_expr(V1), VS.transform(2)}; + (SPACE? 'if' SPACE? disjunction)* { + std::vector v{asts(For, LOC, ac_expr(V0), ac_expr(V1), nullptr)}; + auto tail = VS.transform(2); + for (auto &t: tail) + v.push_back(asts(If, LOC, t, nullptr)); + return v; } -star_targets <- tlist(',', star_target) { return wrap_tuple(VS, LOC); } +star_targets <- tlist(',', star_target) { return wrap_tuple(CTX, VS, LOC); } star_target <- - / '*' _ !'*' star_target { return ast(LOC, ac_expr(V0)); } + / '*' _ !'*' star_target { return aste(Star, LOC, ac_expr(V0)); } / star_parens { return ac_expr(V0); } / primary { return ac_expr(V0); } star_parens <- - / '(' _ tlist(',', star_target) _ ')' { return wrap_tuple(VS, LOC); } - / '[' _ tlist(',', star_target) _ ']' { return wrap_tuple(VS, LOC); } + / '(' _ tlist(',', star_target) _ ')' { return wrap_tuple(CTX, VS, LOC); } + / '[' _ tlist(',', star_target) _ ']' { return wrap_tuple(CTX, VS, LOC); } -star_expressions <- tlist(',', star_expression) { return wrap_tuple(VS, LOC); } +star_expressions <- tlist(',', star_expression) { return wrap_tuple(CTX, VS, LOC); } star_expression <- - / '*' _ bitwise_or { return ast(LOC, ac_expr(V0)); } + / '*' _ bitwise_or { return aste(Star, LOC, ac_expr(V0)); } / expression { return ac_expr(V0); } star_named_expression <- - / '*' _ bitwise_or { return ast(LOC, ac_expr(V0)); } + / '*' _ bitwise_or { return aste(Star, LOC, ac_expr(V0)); } / named_expression { return ac_expr(V0); } named_expression <- / NAME _ ':=' _ ^ expression { - return ast(LOC, ast(LOC, ac(V0)), ac_expr(V1)); + return aste(Assign, LOC, aste(Id, LOC, ac(V0)), ac_expr(V1)); } / expression !(_ ':=') { return ac_expr(V0); } arguments <- '(' _ tlist(',', args)? _ ')' { - vector result; + vector result; for (auto &v: VS) - for (auto &i: ac>(v)) + for (auto &i: ac>(v)) result.push_back(i); return result; } args <- (simple_args (_ ',' _ kwargs)? / kwargs) { - auto args = ac>(V0); + auto args = ac>(V0); if (VS.size() > 1) { - auto v = ac>(V1); + auto v = ac>(V1); args.insert(args.end(), v.begin(), v.end()); } return args; } simple_args <- list(',', (starred_expression / named_expression !(_ '='))) { - return vmap(VS, [](auto &i) { return CallExpr::Arg(ac_expr(i)); }); + return vmap(VS, [](auto &i) { return CallArg(ac_expr(i)); }); } -starred_expression <- '*' _ expression { return ast(LOC, ac_expr(V0)); } +starred_expression <- '*' _ expression { return aste(Star, LOC, ac_expr(V0)); } kwargs <- / list(',', kwarg_or_starred) _ ',' _ list(',', kwarg_or_double_starred) { - return VS.transform(); + return VS.transform(); } - / list(',', kwarg_or_starred) { return VS.transform(); } - / list(',', kwarg_or_double_starred) { return VS.transform(); } + / list(',', kwarg_or_starred) { return VS.transform(); } + / list(',', kwarg_or_double_starred) { return VS.transform(); } kwarg_or_starred <- - / NAME _ '=' _ expression { return CallExpr::Arg(LOC, ac(V0), ac_expr(V1)); } - / starred_expression { return CallExpr::Arg(ac_expr(V0)); } + / NAME _ '=' _ expression { return CallArg(LOC, ac(V0), ac_expr(V1)); } + / starred_expression { return CallArg(ac_expr(V0)); } kwarg_or_double_starred <- - / NAME _ '=' _ expression { return CallExpr::Arg(LOC, ac(V0), ac_expr(V1)); } + / NAME _ '=' _ expression { return CallArg(LOC, ac(V0), ac_expr(V1)); } / '**' _ expression { - return CallExpr::Arg(ast(LOC, ac_expr(V0))); + return CallArg(aste(KeywordStar, LOC, ac_expr(V0))); } -id <- NAME { return ast(LOC, ac(V0)); } +id <- NAME { return aste(Id, LOC, ac(V0)); } INT <- (BININT / HEXINT / DECINT) { return string(VS.sv()); } BININT <- <'0' [bB] [0-1] ('_'* [0-1])*> HEXINT <- <'0' [xX] [0-9a-fA-F] ('_'? [0-9a-fA-F])*> @@ -697,25 +733,25 @@ NAME <- / keyword [a-zA-Z_0-9]+ { return string(VS.sv()); } / !keyword <[a-zA-Z_] [a-zA-Z_0-9]*> { return VS.token_to_string(); } STRING <- { - auto p = pair( + auto p = StringExpr::String( ac(VS.size() > 1 ? V1 : V0), VS.size() > 1 ? ac(V0) : "" ); - if (p.second != "r" && p.second != "R") { - p.first = unescape(p.first); + if (p.prefix != "r" && p.prefix != "R") { + p.value = unescape(p.value); } else { - p.second = ""; + p.prefix = ""; } return p; } STRING__PREDICATE { - auto p = pair( + auto p = StringExpr::String( ac(VS.size() > 1 ? V1 : V0), VS.size() > 1 ? ac(V0) : "" ); - if (p.second != "r" && p.second != "R") + if (p.prefix != "r" && p.prefix != "R") try { - p.first = unescape(p.first); + p.value = unescape(p.value); } catch (std::invalid_argument &e) { MSG = "invalid code in a string"; return false; @@ -777,6 +813,6 @@ EXTERNDENT__PREDICATE { > # https://docs.python.org/3/library/string.html#formatspec -format_spec <- ([<>=^] / [^{}] [<>=^])? [+-]? 'z'? '#'? '0'? [0-9]* [_,]* ('.' [0-9]+)? [bcdeEfFgGnosxX%]? { +format_spec <- ([^{}] [<>=^] / [<>=^])? [+- ]? 'z'? '#'? '0'? [0-9]* [_,]* ('.' [0-9]+)? [bcdeEfFgGnosxX%]? { return string(VS.sv()); } diff --git a/codon/parser/peg/openmp.peg b/codon/parser/peg/openmp.peg index fb0a31cf..61a426fe 100644 --- a/codon/parser/peg/openmp.peg +++ b/codon/parser/peg/openmp.peg @@ -10,34 +10,42 @@ PREAMBLE { #define V0 VS[0] #define V1 VS[1] #define ac std::any_cast + + template + T *setSI(ASTNode *n, const codon::SrcInfo &s) { + n->setSrcInfo(s); + return (T*)n; + } + #define ast(T, s, ...) setSI(CTX.cache->N(__VA_ARGS__), s) } pragma <- "omp"? _ "parallel"? _ (clause _)* { - vector v; + vector v; for (auto &i: VS) { - auto vi = ac>(i); + auto vi = ac>(i); v.insert(v.end(), vi.begin(), vi.end()); } return v; } clause <- / "schedule" _ "(" _ schedule_kind (_ "," _ int)? _ ")" { - vector v{{"schedule", make_shared(ac(V0))}}; + // CTX; + vector v{{"schedule", ast(String, LOC, ac(V0))}}; if (VS.size() > 1) - v.push_back({"chunk_size", make_shared(ac(V1))}); + v.push_back({"chunk_size", ast(Int, LOC, ac(V1))}); return v; } / "num_threads" _ "(" _ int _ ")" { - return vector{{"num_threads", make_shared(ac(V0))}}; + return vector{{"num_threads", ast(Int, LOC, ac(V0))}}; } / "ordered" { - return vector{{"ordered", make_shared(true)}}; + return vector{{"ordered", ast(Bool, LOC, true)}}; } / "collapse" { - return vector{{"collapse", make_shared(ac(V0))}}; + return vector{{"collapse", ast(Int, LOC, ac(V0))}}; } / "gpu" { - return vector{{"gpu", make_shared(true)}}; + return vector{{"gpu", ast(Bool, LOC, true)}}; } schedule_kind <- ("static" / "dynamic" / "guided" / "auto" / "runtime") { return VS.token_to_string(); @@ -46,7 +54,7 @@ int <- [1-9] [0-9]* { return stoi(VS.token_to_string()); } # ident <- [a-zA-Z_] [a-zA-Z_0-9]* { -# return make_shared(VS.token_to_string()); +# return ast(VS.token_to_string()); # } ~SPACE <- [ \t]+ ~_ <- SPACE* diff --git a/codon/parser/peg/peg.cpp b/codon/parser/peg/peg.cpp index ac9a04b5..0d268458 100644 --- a/codon/parser/peg/peg.cpp +++ b/codon/parser/peg/peg.cpp @@ -53,15 +53,16 @@ std::shared_ptr initParser() { } template -T parseCode(Cache *cache, const std::string &file, const std::string &code, - int line_offset, int col_offset, const std::string &rule) { +llvm::Expected parseCode(Cache *cache, const std::string &file, + const std::string &code, int line_offset, int col_offset, + const std::string &rule) { Timer t(""); t.logged = true; // Initialize if (!grammar) grammar = initParser(); - std::vector> errors; + std::vector errors; auto log = [&](size_t line, size_t col, const std::string &msg, const std::string &) { size_t ed = msg.size(); if (startswith(msg, "syntax error, unexpected")) { @@ -69,7 +70,7 @@ T parseCode(Cache *cache, const std::string &file, const std::string &code, if (i != std::string::npos) ed = i; } - errors.emplace_back(line, col, msg.substr(0, ed)); + errors.emplace_back(msg.substr(0, ed), file, line, col); }; T result; auto ctx = std::make_any(cache, 0, line_offset, col_offset); @@ -79,33 +80,27 @@ T parseCode(Cache *cache, const std::string &file, const std::string &code, if (!ret) r.error_info.output_log(log, code.c_str(), code.size()); totalPeg += t.elapsed(); - exc::ParserException ex; - if (!errors.empty()) { - for (auto &e : errors) - ex.track(fmt::format("{}", std::get<2>(e)), - SrcInfo(file, std::get<0>(e), std::get<1>(e), 0)); - throw ex; - return T(); - } + + if (!errors.empty()) + return llvm::make_error(errors); return result; } -StmtPtr parseCode(Cache *cache, const std::string &file, const std::string &code, - int line_offset) { - return parseCode(cache, file, code + "\n", line_offset, 0, "program"); +llvm::Expected parseCode(Cache *cache, const std::string &file, + const std::string &code, int line_offset) { + return parseCode(cache, file, code + "\n", line_offset, 0, "program"); } -std::pair parseExpr(Cache *cache, const std::string &code, - const codon::SrcInfo &offset) { +llvm::Expected> +parseExpr(Cache *cache, const std::string &code, const codon::SrcInfo &offset) { auto newCode = code; ltrim(newCode); rtrim(newCode); - auto e = parseCode>( + return parseCode>( cache, offset.file, newCode, offset.line, offset.col, "fstring"); - return e; } -StmtPtr parseFile(Cache *cache, const std::string &file) { +llvm::Expected parseFile(Cache *cache, const std::string &file) { std::vector lines; std::string code; if (file == "-") { @@ -116,7 +111,8 @@ StmtPtr parseFile(Cache *cache, const std::string &file) { } else { std::ifstream fin(file); if (!fin) - E(error::Error::COMPILER_NO_FILE, SrcInfo(), file); + return llvm::make_error(error::Error::COMPILER_NO_FILE, + SrcInfo(), file); for (std::string line; getline(fin, line);) { lines.push_back(line); code += line + "\n"; @@ -143,27 +139,24 @@ std::shared_ptr initOpenMPParser() { return g; } -std::vector parseOpenMP(Cache *cache, const std::string &code, - const codon::SrcInfo &loc) { +llvm::Expected> parseOpenMP(Cache *cache, const std::string &code, + const codon::SrcInfo &loc) { if (!ompGrammar) ompGrammar = initOpenMPParser(); - std::vector> errors; + std::vector errors; auto log = [&](size_t line, size_t col, const std::string &msg, const std::string &) { - errors.emplace_back(line, col, msg); + errors.emplace_back(fmt::format("openmp: {}", msg), loc.file, loc.line, loc.col); }; - std::vector result; + std::vector result; auto ctx = std::make_any(cache, 0, 0, 0); auto r = (*ompGrammar)["pragma"].parse_and_get_value(code.c_str(), code.size(), ctx, result, "", log); auto ret = r.ret && r.len == code.size(); if (!ret) r.error_info.output_log(log, code.c_str(), code.size()); - exc::ParserException ex; - if (!errors.empty()) { - ex.track(fmt::format("openmp {}", std::get<2>(errors[0])), loc); - throw ex; - } + if (!errors.empty()) + return llvm::make_error(errors); return result; } diff --git a/codon/parser/peg/peg.h b/codon/parser/peg/peg.h index edd8e7a3..75715485 100644 --- a/codon/parser/peg/peg.h +++ b/codon/parser/peg/peg.h @@ -13,18 +13,18 @@ namespace codon::ast { /// Parse a Seq code block with the appropriate file and position offsets. -StmtPtr parseCode(Cache *cache, const std::string &file, const std::string &code, - int line_offset = 0); +llvm::Expected parseCode(Cache *cache, const std::string &file, + const std::string &code, int line_offset = 0); /// Parse a Seq code expression. -/// @return pair of ExprPtr and a string indicating format specification +/// @return pair of Expr * and a format specification /// (empty if not available). -std::pair parseExpr(Cache *cache, const std::string &code, - const codon::SrcInfo &offset); +llvm::Expected> +parseExpr(Cache *cache, const std::string &code, const codon::SrcInfo &offset); /// Parse a Seq file. -StmtPtr parseFile(Cache *cache, const std::string &file); +llvm::Expected parseFile(Cache *cache, const std::string &file); /// Parse a OpenMP clause. -std::vector parseOpenMP(Cache *cache, const std::string &code, - const codon::SrcInfo &loc); +llvm::Expected> parseOpenMP(Cache *cache, const std::string &code, + const codon::SrcInfo &loc); } // namespace codon::ast diff --git a/codon/parser/visitors/doc/doc.cpp b/codon/parser/visitors/doc/doc.cpp index 39687d8d..507e90eb 100644 --- a/codon/parser/visitors/doc/doc.cpp +++ b/codon/parser/visitors/doc/doc.cpp @@ -9,6 +9,7 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" +#include "codon/parser/match.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/format/format.h" @@ -16,6 +17,9 @@ using fmt::format; namespace codon::ast { +using namespace error; +using namespace matcher; + // clang-format off std::string json_escape(const std::string &str) { std::string r; @@ -96,24 +100,30 @@ std::shared_ptr DocVisitor::apply(const std::string &argv0, shared->j = std::make_shared(); auto stdlib = getImportFile(argv0, STDLIB_INTERNAL_MODULE, "", true, ""); - auto ast = ast::parseFile(shared->cache, stdlib->path); - auto core = + auto astOrErr = ast::parseFile(shared->cache, stdlib->path); + if (!astOrErr) + throw exc::ParserException(astOrErr.takeError()); + auto coreOrErr = ast::parseCode(shared->cache, stdlib->path, "from internal.core import *"); + if (!coreOrErr) + throw exc::ParserException(coreOrErr.takeError()); shared->modules[""]->setFilename(stdlib->path); shared->modules[""]->add("__py_numerics__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__py_extension__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__debug__", std::make_shared(shared->itemID++)); shared->modules[""]->add("__apple__", std::make_shared(shared->itemID++)); - DocVisitor(shared->modules[""]).transformModule(std::move(core)); - DocVisitor(shared->modules[""]).transformModule(std::move(ast)); + DocVisitor(shared->modules[""]).transformModule(*coreOrErr); + DocVisitor(shared->modules[""]).transformModule(*astOrErr); auto ctx = std::make_shared(shared); for (auto &f : files) { auto path = getAbsolutePath(f); ctx->setFilename(path); LOG("-> parsing {}", path); - auto ast = ast::parseFile(shared->cache, path); - DocVisitor(ctx).transformModule(std::move(ast)); + auto astOrErr = ast::parseFile(shared->cache, path); + if (!astOrErr) + throw exc::ParserException(astOrErr.takeError()); + DocVisitor(ctx).transformModule(*astOrErr); } shared->cache = nullptr; @@ -127,18 +137,18 @@ std::shared_ptr DocContext::find(const std::string &s) const { return i; } -std::string getDocstr(const StmtPtr &s) { - if (auto se = s->getExpr()) - if (auto e = se->expr->getString()) +std::string getDocstr(Stmt *s) { + if (auto se = cast(s)) + if (auto e = cast(se->getExpr())) return e->getValue(); return ""; } -std::vector DocVisitor::flatten(StmtPtr stmt, std::string *docstr, bool deep) { - std::vector stmts; - if (auto s = stmt->getSuite()) { - for (int i = 0; i < (deep ? s->stmts.size() : 1); i++) { - for (auto &x : flatten(std::move(s->stmts[i]), i ? nullptr : docstr, deep)) +std::vector DocVisitor::flatten(Stmt *stmt, std::string *docstr, bool deep) { + std::vector stmts; + if (auto s = cast(stmt)) { + for (int i = 0; i < (deep ? s->size() : 1); i++) { + for (auto &x : flatten((*s)[i], i ? nullptr : docstr, deep)) stmts.push_back(std::move(x)); } } else { @@ -149,7 +159,7 @@ std::vector DocVisitor::flatten(StmtPtr stmt, std::string *docstr, bool return stmts; } -std::shared_ptr DocVisitor::transform(const ExprPtr &expr) { +std::shared_ptr DocVisitor::transform(Expr *expr) { if (!expr) return std::make_shared(); DocVisitor v(ctx); @@ -159,7 +169,7 @@ std::shared_ptr DocVisitor::transform(const ExprPtr &expr) { return v.resultExpr; } -std::string DocVisitor::transform(const StmtPtr &stmt) { +std::string DocVisitor::transform(Stmt *stmt) { if (!stmt) return ""; DocVisitor v(ctx); @@ -168,7 +178,7 @@ std::string DocVisitor::transform(const StmtPtr &stmt) { return v.resultStmt; } -void DocVisitor::transformModule(StmtPtr stmt) { +void DocVisitor::transformModule(Stmt *stmt) { std::vector children; std::string docstr; @@ -178,7 +188,7 @@ void DocVisitor::transformModule(StmtPtr stmt) { auto id = transform(s); if (id.empty()) continue; - if (i < (flat.size() - 1) && CAST(s, AssignStmt)) { + if (i < (flat.size() - 1) && cast(s)) { auto ds = getDocstr(flat[i + 1]); if (!ds.empty()) ctx->shared->j->get(id)->set("doc", ds); @@ -196,29 +206,30 @@ void DocVisitor::transformModule(StmtPtr stmt) { } void DocVisitor::visit(IntExpr *expr) { - resultExpr = std::make_shared(expr->value); + auto [value, _] = expr->getRawData(); + resultExpr = std::make_shared(value); } void DocVisitor::visit(IdExpr *expr) { - auto i = ctx->find(expr->value); + auto i = ctx->find(expr->getValue()); if (!i) - error("unknown identifier {}", expr->value); - resultExpr = std::make_shared(*i ? std::to_string(*i) : expr->value); + E(Error::CUSTOM, expr->getSrcInfo(), "unknown identifier {}", expr->getValue()); + resultExpr = std::make_shared(*i ? std::to_string(*i) : expr->getValue()); } void DocVisitor::visit(IndexExpr *expr) { std::vector> v; - v.push_back(transform(expr->expr)); - if (auto tp = CAST(expr->index, TupleExpr)) { - if (auto l = tp->items[0]->getList()) { - for (auto &e : l->items) + v.push_back(transform(expr->getExpr())); + if (auto tp = cast(expr->getIndex())) { + if (auto l = cast((*tp)[0])) { + for (auto *e : *l) v.push_back(transform(e)); - v.push_back(transform(tp->items[1])); + v.push_back(transform((*tp)[1])); } else - for (auto &e : tp->items) + for (auto *e : *tp) v.push_back(transform(e)); } else { - v.push_back(transform(expr->index)); + v.push_back(transform(expr->getIndex())); } resultExpr = std::make_shared(v); } @@ -233,21 +244,21 @@ bool isValidName(const std::string &s) { void DocVisitor::visit(FunctionStmt *stmt) { int id = ctx->shared->itemID++; - ctx->add(stmt->name, std::make_shared(id)); + ctx->add(stmt->getName(), std::make_shared(id)); auto j = std::make_shared(std::unordered_map{ - {"kind", "function"}, {"name", stmt->name}}); + {"kind", "function"}, {"name", stmt->getName()}}); j->set("pos", jsonify(stmt->getSrcInfo())); std::vector> args; std::vector generics; - for (auto &a : stmt->args) - if (a.status != Param::Normal) { + for (auto &a : *stmt) + if (!a.isValue()) { ctx->add(a.name, std::make_shared(0)); generics.push_back(a.name); a.status = Param::Generic; } - for (auto &a : stmt->args) - if (a.status == Param::Normal) { + for (auto &a : *stmt) + if (a.isValue()) { auto j = std::make_shared(); j->set("name", a.name); if (a.type) @@ -259,16 +270,16 @@ void DocVisitor::visit(FunctionStmt *stmt) { } j->set("generics", std::make_shared(generics)); bool isLLVM = false; - for (auto &d : stmt->decorators) - if (auto e = d->getId()) { - j->set("attrs", std::make_shared(e->value, "")); - isLLVM |= (e->value == "llvm"); + for (auto &d : stmt->getDecorators()) + if (auto e = cast(d)) { + j->set("attrs", std::make_shared(e->getValue(), "")); + isLLVM |= (e->getValue() == "llvm"); } - if (stmt->ret) - j->set("return", transform(stmt->ret)); + if (stmt->getReturn()) + j->set("return", transform(stmt->getReturn())); j->set("args", std::make_shared(args)); std::string docstr; - flatten(std::move(stmt->suite), &docstr); + flatten(stmt->getSuite(), &docstr); for (auto &g : generics) ctx->remove(g); if (!docstr.empty() && !isLLVM) @@ -280,36 +291,35 @@ void DocVisitor::visit(FunctionStmt *stmt) { void DocVisitor::visit(ClassStmt *stmt) { std::vector generics; auto j = std::make_shared(std::unordered_map{ - {"name", stmt->name}, + {"name", stmt->getName()}, {"kind", "class"}, {"type", stmt->isRecord() ? "type" : "class"}}); int id = ctx->shared->itemID++; bool isExtend = false; - for (auto &d : stmt->decorators) - if (auto e = d->getId()) - isExtend |= (e->value == "extend"); + for (auto &d : stmt->getDecorators()) + if (auto e = cast(d)) + isExtend |= (e->getValue() == "extend"); if (isExtend) { j->set("type", "extension"); - auto i = ctx->find(stmt->name); + auto i = ctx->find(stmt->getName()); j->set("parent", std::to_string(*i)); generics = ctx->shared->generics[*i]; } else { - ctx->add(stmt->name, std::make_shared(id)); + ctx->add(stmt->getName(), std::make_shared(id)); } std::vector> args; - for (auto &a : stmt->args) - if (a.status != Param::Normal) { - a.status = Param::Generic; + for (const auto &a : *stmt) + if (!a.isValue()) { generics.push_back(a.name); } ctx->shared->generics[id] = generics; for (auto &g : generics) ctx->add(g, std::make_shared(0)); - for (auto &a : stmt->args) - if (a.status == Param::Normal) { + for (const auto &a : *stmt) + if (a.isValue()) { auto ja = std::make_shared(); ja->set("name", a.name); if (a.type) @@ -322,13 +332,13 @@ void DocVisitor::visit(ClassStmt *stmt) { std::string docstr; std::vector members; - for (auto &f : flatten(std::move(stmt->suite), &docstr)) { - if (auto ff = CAST(f, FunctionStmt)) { + for (auto &f : flatten(stmt->getSuite(), &docstr)) { + if (auto ff = cast(f)) { auto i = transform(f); if (i != "") members.push_back(i); - if (isValidName(ff->name)) - ctx->remove(ff->name); + if (isValidName(ff->getName())) + ctx->remove(ff->getName()); } } for (auto &g : generics) @@ -346,25 +356,27 @@ std::shared_ptr DocVisitor::jsonify(const codon::SrcInfo &s) { } void DocVisitor::visit(ImportStmt *stmt) { - if (stmt->from && (stmt->from->isId("C") || stmt->from->isId("python"))) { + if (match(stmt->getFrom(), M(MOr("C", "python")))) { int id = ctx->shared->itemID++; std::string name, lib; - if (auto i = stmt->what->getId()) - name = i->value; - else if (auto d = stmt->what->getDot()) - name = d->member, lib = FormatVisitor::apply(d->expr); + if (auto i = cast(stmt->getWhat())) + name = i->getValue(); + else if (auto d = cast(stmt->getWhat())) + name = d->getMember(), lib = FormatVisitor::apply(d->getExpr()); else seqassert(false, "invalid C import statement"); ctx->add(name, std::make_shared(id)); - name = stmt->as.empty() ? name : stmt->as; + name = stmt->getAs().empty() ? name : stmt->getAs(); auto j = std::make_shared(std::unordered_map{ - {"name", name}, {"kind", "function"}, {"extern", stmt->from->getId()->value}}); + {"name", name}, + {"kind", "function"}, + {"extern", cast(stmt->getFrom())->getValue()}}); j->set("pos", jsonify(stmt->getSrcInfo())); std::vector> args; - if (stmt->ret) - j->set("return", transform(stmt->ret)); - for (auto &a : stmt->args) { + if (stmt->getReturnType()) + j->set("return", transform(stmt->getReturnType())); + for (const auto &a : stmt->getArgs()) { auto ja = std::make_shared(); ja->set("name", a.name); ja->set("type", transform(a.type)); @@ -378,22 +390,29 @@ void DocVisitor::visit(ImportStmt *stmt) { } std::vector dirs; // Path components - Expr *e = stmt->from.get(); - if (e) { - while (auto d = e->getDot()) { - dirs.push_back(d->member); - e = d->expr.get(); + Expr *e = stmt->getFrom(); + while (auto d = cast(e)) { + while (auto d = cast(e)) { + dirs.push_back(d->getMember()); + e = d->getExpr(); } - if (!e->getId() || !stmt->args.empty() || stmt->ret || - (stmt->what && !stmt->what->getId())) - error("invalid import statement"); + if (!cast(e) || !stmt->getArgs().empty() || stmt->getReturnType() || + (stmt->getWhat() && !cast(stmt->getWhat()))) + E(Error::CUSTOM, stmt->getSrcInfo(), "invalid import statement"); // We have an empty stmt->from in "from .. import". - if (!e->getId()->value.empty()) - dirs.push_back(e->getId()->value); + if (!cast(e)->getValue().empty()) + dirs.push_back(cast(e)->getValue()); } + auto ee = cast(e); + if (!ee || !stmt->getArgs().empty() || stmt->getReturnType() || + (stmt->getWhat() && !cast(stmt->getWhat()))) + E(Error::CUSTOM, stmt->getSrcInfo(), "invalid import statement"); + // We have an empty stmt->from in "from .. import". + if (!ee->getValue().empty()) + dirs.push_back(ee->getValue()); // Handle dots (e.g. .. in from ..m import x). - seqassert(stmt->dots >= 0, "negative dots in ImportStmt"); - for (size_t i = 1; i < stmt->dots; i++) + seqassert(stmt->getDots() >= 0, "negative dots in ImportStmt"); + for (size_t i = 1; i < stmt->getDots(); i++) dirs.emplace_back(".."); std::string path; for (int i = int(dirs.size()) - 1; i >= 0; i--) @@ -401,7 +420,7 @@ void DocVisitor::visit(ImportStmt *stmt) { // Fetch the import! auto file = getImportFile(ctx->shared->argv0, path, ctx->getFilename()); if (!file) - error(stmt, "cannot locate import '{}'", path); + E(Error::CUSTOM, stmt->getSrcInfo(), "cannot locate import '{}'", path); auto ictx = ctx; auto it = ctx->shared->modules.find(file->path); @@ -409,37 +428,40 @@ void DocVisitor::visit(ImportStmt *stmt) { ctx->shared->modules[file->path] = ictx = std::make_shared(ctx->shared); ictx->setFilename(file->path); LOG("=> parsing {}", file->path); - auto tmp = parseFile(ctx->shared->cache, file->path); - DocVisitor(ictx).transformModule(std::move(tmp)); + auto tmpOrErr = parseFile(ctx->shared->cache, file->path); + if (!tmpOrErr) + throw exc::ParserException(tmpOrErr.takeError()); + DocVisitor(ictx).transformModule(*tmpOrErr); } else { ictx = it->second; } - if (!stmt->what) { + if (!stmt->getWhat()) { // TODO: implement this corner case for (auto &i : dirs) if (!ctx->find(i)) ctx->add(i, std::make_shared(ctx->shared->itemID++)); - } else if (stmt->what->isId("*")) { + } else if (isId(stmt->getWhat(), "*")) { for (auto &i : *ictx) ctx->add(i.first, i.second.front()); } else { - auto i = stmt->what->getId(); - if (auto c = ictx->find(i->value)) - ctx->add(stmt->as.empty() ? i->value : stmt->as, c); + auto i = cast(stmt->getWhat()); + if (auto c = ictx->find(i->getValue())) + ctx->add(stmt->getAs().empty() ? i->getValue() : stmt->getAs(), c); else - error(stmt, "symbol '{}' not found in {}", i->value, file->path); + E(Error::CUSTOM, stmt->getSrcInfo(), "symbol '{}' not found in {}", i->getValue(), + file->path); } } void DocVisitor::visit(AssignStmt *stmt) { - auto e = CAST(stmt->lhs, IdExpr); + auto e = cast(stmt->getLhs()); if (!e) return; int id = ctx->shared->itemID++; - ctx->add(e->value, std::make_shared(id)); + ctx->add(e->getValue(), std::make_shared(id)); auto j = std::make_shared(std::unordered_map{ - {"name", e->value}, {"kind", "variable"}}); + {"name", e->getValue()}, {"kind", "variable"}}); j->set("pos", jsonify(stmt->getSrcInfo())); ctx->shared->j->set(std::to_string(id), j); resultStmt = std::to_string(id); diff --git a/codon/parser/visitors/doc/doc.h b/codon/parser/visitors/doc/doc.h index 098c6a9d..e6bb869f 100644 --- a/codon/parser/visitors/doc/doc.h +++ b/codon/parser/visitors/doc/doc.h @@ -64,13 +64,13 @@ struct DocVisitor : public CallbackASTVisitor, std::string static std::shared_ptr apply(const std::string &argv0, const std::vector &files); - std::shared_ptr transform(const ExprPtr &e) override; - std::string transform(const StmtPtr &e) override; + std::shared_ptr transform(Expr *e) override; + std::string transform(Stmt *e) override; - void transformModule(StmtPtr stmt); + void transformModule(Stmt *stmt); std::shared_ptr jsonify(const codon::SrcInfo &s); - std::vector flatten(StmtPtr stmt, std::string *docstr = nullptr, - bool deep = true); + std::vector flatten(Stmt *stmt, std::string *docstr = nullptr, + bool deep = true); public: void visit(IntExpr *) override; diff --git a/codon/parser/visitors/format/format.cpp b/codon/parser/visitors/format/format.cpp index 20585360..f00ea370 100644 --- a/codon/parser/visitors/format/format.cpp +++ b/codon/parser/visitors/format/format.cpp @@ -10,51 +10,60 @@ using fmt::format; namespace codon { namespace ast { +std::string FormatVisitor::anchor_root(const std::string &s) const { + return fmt::format("{}", s, s); +} + +std::string FormatVisitor::anchor(const std::string &s) const { + return fmt::format("{}", s, s); +} + FormatVisitor::FormatVisitor(bool html, Cache *cache) : renderType(false), renderHTML(html), indent(0), cache(cache) { if (renderHTML) { - header = "\n"; + header = "\n"; header += "
\n"; footer = "\n
"; - nl = "
"; - typeStart = ""; - typeEnd = ""; - nodeStart = ""; - nodeEnd = ""; - exprStart = ""; - exprEnd = ""; - commentStart = ""; - commentEnd = ""; - keywordStart = ""; - keywordEnd = ""; - space = " "; + nl = "
"; + typeStart = ""; + typeEnd = ""; + nodeStart = ""; + nodeEnd = ""; + stmtStart = ""; + stmtEnd = ""; + exprStart = ""; + exprEnd = ""; + commentStart = ""; + commentEnd = ""; + literalStart = ""; + literalEnd = ""; + keywordStart = ""; + keywordEnd = ""; + space = " "; renderType = true; } else { space = " "; } } -std::string FormatVisitor::transform(const ExprPtr &expr) { - return transform(expr.get()); -} - -std::string FormatVisitor::transform(const Expr *expr) { +std::string FormatVisitor::transform(Expr *expr) { FormatVisitor v(renderHTML, cache); if (expr) - const_cast(expr)->accept(v); + expr->accept(v); return v.result; } -std::string FormatVisitor::transform(const StmtPtr &stmt) { - return transform(stmt.get(), 0); -} +std::string FormatVisitor::transform(Stmt *stmt) { return transform(stmt, 0); } std::string FormatVisitor::transform(Stmt *stmt, int indent) { FormatVisitor v(renderHTML, cache); v.indent = this->indent + indent; if (stmt) stmt->accept(v); - return (stmt && stmt->getSuite() ? "" : pad(indent)) + v.result + newline(); + if (v.result.empty()) + return ""; + return fmt::format("{}{}{}{}{}", stmtStart, cast(stmt) ? "" : pad(indent), + v.result, stmtEnd, newline()); } std::string FormatVisitor::pad(int indent) const { @@ -70,109 +79,97 @@ std::string FormatVisitor::keyword(const std::string &s) const { return fmt::format("{}{}{}", keywordStart, s, keywordEnd); } +std::string FormatVisitor::literal(const std::string &s) const { + return fmt::format("{}{}{}", literalStart, s, literalEnd); +} + /*************************************************************************************/ void FormatVisitor::visit(NoneExpr *expr) { result = renderExpr(expr, "None"); } void FormatVisitor::visit(BoolExpr *expr) { - result = renderExpr(expr, "{}", expr->value ? "True" : "False"); + result = renderExpr(expr, "{}", literal(expr->getValue() ? "True" : "False")); } void FormatVisitor::visit(IntExpr *expr) { - result = renderExpr(expr, "{}{}", expr->value, expr->suffix); + auto [value, suffix] = expr->getRawData(); + result = renderExpr(expr, "{}{}", literal(value), suffix); } void FormatVisitor::visit(FloatExpr *expr) { - result = renderExpr(expr, "{}{}", expr->value, expr->suffix); + auto [value, suffix] = expr->getRawData(); + result = renderExpr(expr, "{}{}", literal(value), suffix); } void FormatVisitor::visit(StringExpr *expr) { - result = renderExpr(expr, "\"{}\"", escape(expr->getValue())); + result = + renderExpr(expr, "{}", literal(fmt::format("\"{}\"", escape(expr->getValue())))); } void FormatVisitor::visit(IdExpr *expr) { - result = renderExpr(expr, "{}", expr->value); + result = renderExpr(expr, "{}", + expr->getType() && expr->getType()->getFunc() + ? anchor(expr->getValue()) + : expr->getValue()); } void FormatVisitor::visit(StarExpr *expr) { - result = renderExpr(expr, "*{}", transform(expr->what)); + result = renderExpr(expr, "*{}", transform(expr->getExpr())); } void FormatVisitor::visit(KeywordStarExpr *expr) { - result = renderExpr(expr, "**{}", transform(expr->what)); + result = renderExpr(expr, "**{}", transform(expr->getExpr())); } void FormatVisitor::visit(TupleExpr *expr) { - result = renderExpr(expr, "({})", transform(expr->items)); + result = renderExpr(expr, "({})", transformItems(*expr)); } void FormatVisitor::visit(ListExpr *expr) { - result = renderExpr(expr, "[{}]", transform(expr->items)); + result = renderExpr(expr, "[{}]", transformItems(*expr)); } void FormatVisitor::visit(InstantiateExpr *expr) { - result = renderExpr(expr, "{}[{}]", transform(expr->typeExpr), - transform(expr->typeParams)); + result = + renderExpr(expr, "{}⟦{}⟧", transform(expr->getExpr()), transformItems(*expr)); } void FormatVisitor::visit(SetExpr *expr) { - result = renderExpr(expr, "{{{}}}", transform(expr->items)); + result = renderExpr(expr, "{{{}}}", transformItems(*expr)); } void FormatVisitor::visit(DictExpr *expr) { std::vector s; - for (auto &i : expr->items) - s.push_back(fmt::format("{}: {}", transform(i->getTuple()->items[0]), - transform(i->getTuple()->items[1]))); + for (auto *i : *expr) { + auto t = cast(i); + s.push_back(fmt::format("{}: {}", transform((*t)[0]), transform((*t)[1]))); + } result = renderExpr(expr, "{{{}}}", join(s, ", ")); } void FormatVisitor::visit(GeneratorExpr *expr) { - std::string s; - for (auto &i : expr->loops) { - std::string cond; - for (auto &k : i.conds) - cond += fmt::format(" if {}", transform(k)); - s += fmt::format("for {} in {}{}", i.vars->toString(), i.gen->toString(), cond); - } - if (expr->kind == GeneratorExpr::ListGenerator) - result = renderExpr(expr, "[{} {}]", transform(expr->expr), s); - else if (expr->kind == GeneratorExpr::SetGenerator) - result = renderExpr(expr, "{{{} {}}}", transform(expr->expr), s); - else - result = renderExpr(expr, "({} {})", transform(expr->expr), s); -} - -void FormatVisitor::visit(DictGeneratorExpr *expr) { - std::string s; - for (auto &i : expr->loops) { - std::string cond; - for (auto &k : i.conds) - cond += fmt::format(" if {}", transform(k)); - - s += fmt::format("for {} in {}{}", i.vars->toString(), i.gen->toString(), cond); - } - result = - renderExpr(expr, "{{{}: {} {}}}", transform(expr->key), transform(expr->expr), s); + // seqassert(false, "not implemented"); + result = "GENERATOR_IMPL"; } void FormatVisitor::visit(IfExpr *expr) { - result = renderExpr(expr, "{} if {} else {}", transform(expr->ifexpr), - transform(expr->cond), transform(expr->elsexpr)); + result = renderExpr(expr, "({} {} {} {} {})", transform(expr->getIf()), keyword("if"), + transform(expr->getCond()), keyword("else"), + transform(expr->getElse())); } void FormatVisitor::visit(UnaryExpr *expr) { - result = renderExpr(expr, "{}{}", expr->op, transform(expr->expr)); + result = renderExpr(expr, "{}{}", expr->getOp(), transform(expr->getExpr())); } void FormatVisitor::visit(BinaryExpr *expr) { - result = renderExpr(expr, "({} {} {})", transform(expr->lexpr), expr->op, - transform(expr->rexpr)); + result = renderExpr(expr, "({} {} {})", transform(expr->getLhs()), expr->getOp(), + transform(expr->getRhs())); } void FormatVisitor::visit(PipeExpr *expr) { std::vector items; - for (auto &l : expr->items) { + for (const auto &l : *expr) { if (!items.size()) items.push_back(transform(l.expr)); else @@ -182,239 +179,268 @@ void FormatVisitor::visit(PipeExpr *expr) { } void FormatVisitor::visit(IndexExpr *expr) { - result = renderExpr(expr, "{}[{}]", transform(expr->expr), transform(expr->index)); + result = renderExpr(expr, "{}[{}]", transform(expr->getExpr()), + transform(expr->getIndex())); } void FormatVisitor::visit(CallExpr *expr) { std::vector args; - for (auto &i : expr->args) { + for (auto &i : *expr) { if (i.name == "") args.push_back(transform(i.value)); else - args.push_back(fmt::format("{}: {}", i.name, transform(i.value))); + args.push_back(fmt::format("{}={}", i.name, transform(i.value))); } - result = renderExpr(expr, "{}({})", transform(expr->expr), join(args, ", ")); + result = renderExpr(expr, "{}({})", transform(expr->getExpr()), join(args, ", ")); } void FormatVisitor::visit(DotExpr *expr) { - result = renderExpr(expr, "{} . {}", transform(expr->expr), expr->member); + result = renderExpr(expr, "{}○{}", transform(expr->getExpr()), expr->getMember()); } void FormatVisitor::visit(SliceExpr *expr) { std::string s; - if (expr->start) - s += transform(expr->start); + if (expr->getStart()) + s += transform(expr->getStart()); s += ":"; - if (expr->stop) - s += transform(expr->stop); + if (expr->getStop()) + s += transform(expr->getStop()); s += ":"; - if (expr->step) - s += transform(expr->step); + if (expr->getStep()) + s += transform(expr->getStep()); result = renderExpr(expr, "{}", s); } void FormatVisitor::visit(EllipsisExpr *expr) { result = renderExpr(expr, "..."); } void FormatVisitor::visit(LambdaExpr *expr) { - result = renderExpr(expr, "{} {}: {}", keyword("lambda"), join(expr->vars, ", "), - transform(expr->expr)); + std::vector s; + for (const auto &v : *expr) + s.emplace_back(v.getName()); + result = renderExpr(expr, "{} {}: {}", keyword("lambda"), join(s, ", "), + transform(expr->getExpr())); } -void FormatVisitor::visit(YieldExpr *expr) { result = renderExpr(expr, "(yield)"); } +void FormatVisitor::visit(YieldExpr *expr) { + result = renderExpr(expr, "(" + keyword("yield") + ")"); +} void FormatVisitor::visit(StmtExpr *expr) { std::string s; - for (int i = 0; i < expr->stmts.size(); i++) - s += format("{}{}", pad(2), transform(expr->stmts[i].get(), 2)); - result = renderExpr(expr, "({}{}{}{}{})", newline(), s, newline(), pad(2), - transform(expr->expr)); + for (auto *i : *expr) + s += format("{}{}", pad(2), transform(i, 2)); + result = renderExpr(expr, "《{}{}{}{}{}》", newline(), s, newline(), pad(2), + transform(expr->getExpr())); } void FormatVisitor::visit(AssignExpr *expr) { - result = renderExpr(expr, "({} := {})", transform(expr->var), transform(expr->expr)); + result = renderExpr(expr, "({} := {})", transform(expr->getVar()), + transform(expr->getExpr())); } void FormatVisitor::visit(SuiteStmt *stmt) { - for (int i = 0; i < stmt->stmts.size(); i++) - result += transform(stmt->stmts[i]); + for (auto *s : *stmt) + result += transform(s); } void FormatVisitor::visit(BreakStmt *stmt) { result = keyword("break"); } void FormatVisitor::visit(ContinueStmt *stmt) { result = keyword("continue"); } -void FormatVisitor::visit(ExprStmt *stmt) { result = transform(stmt->expr); } +void FormatVisitor::visit(ExprStmt *stmt) { result = transform(stmt->getExpr()); } void FormatVisitor::visit(AssignStmt *stmt) { - if (stmt->type) { - result = fmt::format("{}: {} = {}", transform(stmt->lhs), transform(stmt->type), - transform(stmt->rhs)); + if (stmt->getTypeExpr()) { + result = fmt::format("{}: {} = {}", transform(stmt->getLhs()), + transform(stmt->getTypeExpr()), transform(stmt->getRhs())); } else { - result = fmt::format("{} = {}", transform(stmt->lhs), transform(stmt->rhs)); + result = + fmt::format("{} = {}", transform(stmt->getLhs()), transform(stmt->getRhs())); } } void FormatVisitor::visit(AssignMemberStmt *stmt) { - result = fmt::format("{}.{} = {}", transform(stmt->lhs), stmt->member, - transform(stmt->rhs)); + result = fmt::format("{}○{} = {}", transform(stmt->getLhs()), stmt->getMember(), + transform(stmt->getRhs())); } void FormatVisitor::visit(DelStmt *stmt) { - result = fmt::format("{} {}", keyword("del"), transform(stmt->expr)); + result = fmt::format("{} {}", keyword("del"), transform(stmt->getExpr())); } void FormatVisitor::visit(PrintStmt *stmt) { - result = fmt::format("{} {}", keyword("print"), transform(stmt->items)); + result = fmt::format("{} {}", keyword("print"), transformItems(*stmt)); } void FormatVisitor::visit(ReturnStmt *stmt) { result = fmt::format("{}{}", keyword("return"), - stmt->expr ? " " + transform(stmt->expr) : ""); + stmt->getExpr() ? " " + transform(stmt->getExpr()) : ""); } void FormatVisitor::visit(YieldStmt *stmt) { result = fmt::format("{}{}", keyword("yield"), - stmt->expr ? " " + transform(stmt->expr) : ""); + stmt->getExpr() ? " " + transform(stmt->getExpr()) : ""); } void FormatVisitor::visit(AssertStmt *stmt) { - result = fmt::format("{} {}", keyword("assert"), transform(stmt->expr)); + result = fmt::format("{} {}", keyword("assert"), transform(stmt->getExpr())); } void FormatVisitor::visit(WhileStmt *stmt) { - result = fmt::format("{} {}:{}{}", keyword("while"), transform(stmt->cond), newline(), - transform(stmt->suite.get(), 1)); + result = fmt::format("{} {}:{}{}", keyword("while"), transform(stmt->getCond()), + newline(), transform(stmt->getSuite(), 1)); } void FormatVisitor::visit(ForStmt *stmt) { - result = fmt::format("{} {} {} {}:{}{}", keyword("for"), transform(stmt->var), - keyword("in"), transform(stmt->iter), newline(), - transform(stmt->suite.get(), 1)); + result = fmt::format("{} {} {} {}:{}{}", keyword("for"), transform(stmt->getVar()), + keyword("in"), transform(stmt->getIter()), newline(), + transform(stmt->getSuite(), 1)); } void FormatVisitor::visit(IfStmt *stmt) { - result = fmt::format("{} {}:{}{}{}", keyword("if"), transform(stmt->cond), newline(), - transform(stmt->ifSuite.get(), 1), - stmt->elseSuite ? format("{}:{}{}", keyword("else"), newline(), - transform(stmt->elseSuite.get(), 1)) + result = fmt::format("{} {}:{}{}{}", keyword("if"), transform(stmt->getCond()), + newline(), transform(stmt->getIf(), 1), + stmt->getElse() ? format("{}:{}{}", keyword("else"), newline(), + transform(stmt->getElse(), 1)) : ""); } void FormatVisitor::visit(MatchStmt *stmt) { std::string s; - for (auto &c : stmt->cases) - s += fmt::format("{}{}{}{}:{}{}", pad(1), keyword("case"), transform(c.pattern), - c.guard ? " " + (keyword("case") + " " + transform(c.guard)) : "", - newline(), transform(c.suite.get(), 2)); - result = - fmt::format("{} {}:{}{}", keyword("match"), transform(stmt->what), newline(), s); + for (const auto &c : *stmt) + s += fmt::format( + "{}{}{}{}:{}{}", pad(1), keyword("case"), transform(c.getPattern()), + c.getGuard() ? " " + (keyword("case") + " " + transform(c.getGuard())) : "", + newline(), transform(c.getSuite(), 2)); + result = fmt::format("{} {}:{}{}", keyword("match"), transform(stmt->getExpr()), + newline(), s); } void FormatVisitor::visit(ImportStmt *stmt) { - auto as = stmt->as.empty() ? "" : fmt::format(" {} {} ", keyword("as"), stmt->as); - if (!stmt->what) - result += fmt::format("{} {}{}", keyword("import"), transform(stmt->from), as); + auto as = + stmt->getAs().empty() ? "" : fmt::format(" {} {} ", keyword("as"), stmt->getAs()); + if (!stmt->getWhat()) + result += fmt::format("{} {}{}", keyword("import"), transform(stmt->getFrom()), as); else - result += fmt::format("{} {} {} {}{}", keyword("from"), transform(stmt->from), - keyword("import"), transform(stmt->what), as); + result += fmt::format("{} {} {} {}{}", keyword("from"), transform(stmt->getFrom()), + keyword("import"), transform(stmt->getWhat()), as); } void FormatVisitor::visit(TryStmt *stmt) { std::vector catches; - for (auto &c : stmt->catches) { - catches.push_back( - fmt::format("{} {}{}:{}{}", keyword("catch"), transform(c.exc), - c.var == "" ? "" : fmt::format("{} {}", keyword("as"), c.var), - newline(), transform(c.suite.get(), 1))); + for (auto *c : *stmt) { + catches.push_back(fmt::format( + "{} {}{}:{}{}", keyword("except"), transform(c->getException()), + c->getVar() == "" ? "" : fmt::format("{} {}", keyword("as"), c->getVar()), + newline(), transform(c->getSuite(), 1))); } - result = - fmt::format("{}:{}{}{}{}", keyword("try"), newline(), - transform(stmt->suite.get(), 1), fmt::join(catches, ""), - stmt->finally ? fmt::format("{}:{}{}", keyword("finally"), newline(), - transform(stmt->finally.get(), 1)) - : ""); + result = fmt::format("{}:{}{}{}{}", keyword("try"), newline(), + transform(stmt->getSuite(), 1), fmt::join(catches, ""), + stmt->getFinally() + ? fmt::format("{}:{}{}", keyword("finally"), newline(), + transform(stmt->getFinally(), 1)) + : ""); } void FormatVisitor::visit(GlobalStmt *stmt) { - result = fmt::format("{} {}", keyword("global"), stmt->var); + result = fmt::format("{} {}", keyword("global"), stmt->getVar()); } void FormatVisitor::visit(ThrowStmt *stmt) { - result = fmt::format("{} {}", keyword("raise"), transform(stmt->expr)); + result = fmt::format("{} {}{}", keyword("raise"), transform(stmt->getExpr()), + stmt->getFrom() ? fmt::format(" {} {}", keyword("from"), + transform(stmt->getFrom())) + : ""); } void FormatVisitor::visit(FunctionStmt *fstmt) { if (cache) { - if (in(cache->functions, fstmt->name)) { - if (!cache->functions[fstmt->name].realizations.empty()) { - for (auto &real : cache->functions[fstmt->name].realizations) { - if (real.first != fstmt->name) { - result += transform(real.second->ast.get(), 0); + if (in(cache->functions, fstmt->getName())) { + if (!cache->functions[fstmt->getName()].realizations.empty()) { + result += fmt::format("
# {}", + fmt::format("{} {}", keyword("def"), fstmt->getName())); + for (auto &real : cache->functions[fstmt->getName()].realizations) { + auto fa = real.second->ast; + auto ft = real.second->type; + std::vector attrs; + for (const auto &a : fa->getDecorators()) + attrs.push_back(fmt::format("@{}", transform(a))); + if (auto a = fa->getAttribute(Attr::Module)) + if (!a->value.empty()) + attrs.push_back(fmt::format("@module:{}", a->value)); + if (auto a = fa->getAttribute(Attr::ParentClass)) + if (!a->value.empty()) + attrs.push_back(fmt::format("@parent:{}", a->value)); + std::vector args; + for (size_t i = 0, j = 0; i < fa->size(); i++) { + auto &a = (*fa)[i]; + if (a.isValue()) { + args.push_back(fmt::format( + "{}: {}{}", a.getName(), anchor((*ft)[j++]->realizedName()), + a.getDefault() ? fmt::format("={}", transform(a.getDefault())) : "")); + } } + auto body = transform(fa->getSuite(), 1); + auto name = fmt::format("{}", anchor_root(fa->getName())); + result += fmt::format( + "{}{}{}{} {}({}){}:{}{}", newline(), pad(), + attrs.size() ? join(attrs, newline() + pad()) + newline() + pad() : "", + keyword("def"), anchor_root(name), fmt::join(args, ", "), + fmt::format(" -> {}", anchor(ft->getRetType()->realizedName())), + newline(), body.empty() ? fmt::format("{}", keyword("pass")) : body); } - return; + result += "
"; } - fstmt = cache->functions[fstmt->name].ast.get(); + return; } } - // if (cache && cache->functions.find(fstmt->name) != cache->realizationAsts.end()) - // { - // fstmt = (const FunctionStmt *)(cache->realizationAsts[fstmt->name].get()); - // } else if (cache && cache->functions[fstmt->name].realizations.size()) { - // for (auto &real : cache->functions[fstmt->name].realizations) - // result += transform(real.second.ast); - // return; - // } else if (cache) { - // fstmt = cache->functions[fstmt->name].ast.get(); - // } - - std::vector attrs; - for (auto &a : fstmt->decorators) - attrs.push_back(fmt::format("@{}", transform(a))); - if (!fstmt->attributes.module.empty()) - attrs.push_back(fmt::format("@module:{}", fstmt->attributes.parentClass)); - if (!fstmt->attributes.parentClass.empty()) - attrs.push_back(fmt::format("@parent:{}", fstmt->attributes.parentClass)); - std::vector args; - for (auto &a : fstmt->args) - args.push_back(fmt::format( - "{}{}{}", a.name, a.type ? fmt::format(": {}", transform(a.type)) : "", - a.defaultValue ? fmt::format(" = {}", transform(a.defaultValue)) : "")); - auto body = transform(fstmt->suite.get(), 1); - auto name = fmt::format("{}{}{}", typeStart, fstmt->name, typeEnd); - name = fmt::format("{}{}{}", exprStart, name, exprEnd); - result += fmt::format( - "{}{} {}({}){}:{}{}", - attrs.size() ? join(attrs, newline() + pad()) + newline() + pad() : "", - keyword("def"), name, fmt::join(args, ", "), - fstmt->ret ? fmt::format(" -> {}", transform(fstmt->ret)) : "", newline(), - body.empty() ? fmt::format("{}", keyword("pass")) : body); } void FormatVisitor::visit(ClassStmt *stmt) { - std::vector attrs; - - if (!stmt->attributes.has(Attr::Extend)) - attrs.push_back("@extend"); - if (!stmt->attributes.has(Attr::Tuple)) - attrs.push_back("@tuple"); - std::vector args; - std::string key = stmt->isRecord() ? "type" : "class"; - for (auto &a : stmt->args) - args.push_back(fmt::format("{}: {}", a.name, transform(a.type))); - result = fmt::format("{}{} {}({})", - attrs.size() ? join(attrs, newline() + pad()) + newline() + pad() - : "", - keyword(key), stmt->name, fmt::join(args, ", ")); - if (stmt->suite) - result += fmt::format(":{}{}", newline(), transform(stmt->suite.get(), 1)); + if (cache) { + if (auto cls = in(cache->classes, stmt->getName())) { + if (!cls->realizations.empty()) { + result = fmt::format( + "
# {}", + fmt::format("{} {} {}", keyword("class"), stmt->getName(), + stmt->hasAttribute(Attr::Extend) ? " +@extend" : "")); + for (auto &real : cls->realizations) { + std::vector args; + auto l = real.second->type->is(TYPE_TUPLE) + ? real.second->type->generics.size() + : real.second->fields.size(); + for (size_t i = 0; i < l; i++) { + const auto &[n, t] = real.second->fields[i]; + auto name = fmt::format("{}{}: {}{}", exprStart, n, + anchor(t->realizedName()), exprEnd); + args.push_back(name); + } + result += fmt::format("{}{}{}{} {}", newline(), pad(), + (stmt->hasAttribute(Attr::Tuple) + ? format("@tuple{}{}", newline(), pad()) + : ""), + keyword("class"), anchor_root(real.first)); + if (!args.empty()) + result += fmt::format(":{}{}{}", newline(), pad(indent + 1), + fmt::join(args, newline() + pad(indent + 1))); + } + result += "
"; + } + } + } + // if (stmt->suite) + // result += transform(stmt->suite); } void FormatVisitor::visit(YieldFromStmt *stmt) { - result = fmt::format("{} {}", keyword("yield from"), transform(stmt->expr)); + result = fmt::format("{} {}", keyword("yield from"), transform(stmt->getExpr())); } void FormatVisitor::visit(WithStmt *stmt) {} +void FormatVisitor::visit(CommentStmt *stmt) { + result = fmt::format("{}# {}{}", commentStart, stmt->getComment(), commentEnd); +} + } // namespace ast } // namespace codon diff --git a/codon/parser/visitors/format/format.h b/codon/parser/visitors/format/format.h index 66e343cc..42329f07 100644 --- a/codon/parser/visitors/format/format.h +++ b/codon/parser/visitors/format/format.h @@ -23,17 +23,22 @@ class FormatVisitor : public CallbackASTVisitor { std::string header, footer, nl; std::string typeStart, typeEnd; std::string nodeStart, nodeEnd; + std::string stmtStart, stmtEnd; std::string exprStart, exprEnd; std::string commentStart, commentEnd; std::string keywordStart, keywordEnd; + std::string literalStart, literalEnd; Cache *cache; private: template std::string renderExpr(T &&t, Ts &&...args) { - std::string s; - return fmt::format("{}{}{}{}{}{}", exprStart, s, nodeStart, fmt::format(args...), - nodeEnd, exprEnd); + std::string s = t->getType() + ? fmt::format("{}{}{}", typeStart, + anchor(t->getType()->realizedName()), typeEnd) + : ""; + return fmt::format("{}{}{}{}{}{}", exprStart, nodeStart, fmt::format(args...), + nodeEnd, s, exprEnd); } template std::string renderComment(Ts &&...args) { return fmt::format("{}{}{}", commentStart, fmt::format(args...), commentEnd); @@ -41,12 +46,14 @@ class FormatVisitor : public CallbackASTVisitor { std::string pad(int indent = 0) const; std::string newline() const; std::string keyword(const std::string &s) const; + std::string literal(const std::string &s) const; + std::string anchor_root(const std::string &s) const; + std::string anchor(const std::string &s) const; public: FormatVisitor(bool html, Cache *cache = nullptr); - std::string transform(const ExprPtr &e) override; - std::string transform(const Expr *expr); - std::string transform(const StmtPtr &stmt) override; + std::string transform(Expr *e) override; + std::string transform(Stmt *stmt) override; std::string transform(Stmt *stmt, int indent); template @@ -56,8 +63,8 @@ class FormatVisitor : public CallbackASTVisitor { return fmt::format("{}{}{}", t.header, t.transform(stmt), t.footer); } - void defaultVisit(Expr *e) override { error("cannot format {}", *e); } - void defaultVisit(Stmt *e) override { error("cannot format {}", *e); } + void defaultVisit(Expr *e) override { seqassertn(false, "cannot format {}", *e); } + void defaultVisit(Stmt *e) override { seqassertn(false, "cannot format {}", *e); } public: void visit(NoneExpr *) override; @@ -73,7 +80,6 @@ class FormatVisitor : public CallbackASTVisitor { void visit(SetExpr *) override; void visit(DictExpr *) override; void visit(GeneratorExpr *) override; - void visit(DictGeneratorExpr *) override; void visit(InstantiateExpr *expr) override; void visit(IfExpr *) override; void visit(UnaryExpr *) override; @@ -112,6 +118,7 @@ class FormatVisitor : public CallbackASTVisitor { void visit(ClassStmt *) override; void visit(YieldFromStmt *) override; void visit(WithStmt *) override; + void visit(CommentStmt *) override; public: friend std::ostream &operator<<(std::ostream &out, const FormatVisitor &c) { @@ -119,7 +126,7 @@ class FormatVisitor : public CallbackASTVisitor { } using CallbackASTVisitor::transform; - template std::string transform(const std::vector &ts) { + template std::string transformItems(const T &ts) { std::vector r; for (auto &e : ts) r.push_back(transform(e)); diff --git a/codon/parser/visitors/scoping/scoping.cpp b/codon/parser/visitors/scoping/scoping.cpp new file mode 100644 index 00000000..579d16a2 --- /dev/null +++ b/codon/parser/visitors/scoping/scoping.cpp @@ -0,0 +1,924 @@ +// Copyright (C) 2022-2023 Exaloop Inc. + +#include +#include +#include + +#include "codon/parser/ast.h" +#include "codon/parser/common.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" +#include "codon/parser/visitors/scoping/scoping.h" +#include + +#define CHECK(x) \ + { \ + if (!(x)) \ + return; \ + } +#define STOP_ERROR(...) \ + do { \ + addError(__VA_ARGS__); \ + return; \ + } while (0) + +using fmt::format; +using namespace codon::error; +using namespace codon::matcher; + +namespace codon::ast { + +llvm::Error ScopingVisitor::apply(Cache *cache, Stmt *s) { + auto c = std::make_shared(); + c->cache = cache; + c->functionScope = nullptr; + ScopingVisitor v; + v.ctx = c; + + ConditionalBlock cb(c.get(), s, 0); + if (!v.transform(s)) + return llvm::make_error(v.errors); + if (v.hasErrors()) + return llvm::make_error(v.errors); + v.processChildCaptures(); + // LOG("-> {}", s->toString(2)); + return llvm::Error::success(); +} + +bool ScopingVisitor::transform(Expr *expr) { + ScopingVisitor v(*this); + if (expr) { + v.setSrcInfo(expr->getSrcInfo()); + expr->accept(v); + if (v.hasErrors()) + errors.append(v.errors); + if (!canContinue()) + return false; + } + return true; +} + +bool ScopingVisitor::transform(Stmt *stmt) { + ScopingVisitor v(*this); + if (stmt) { + v.setSrcInfo(stmt->getSrcInfo()); + stmt->setAttribute(Attr::ExprTime, ++ctx->time); + stmt->accept(v); + if (v.hasErrors()) + errors.append(v.errors); + if (!canContinue()) + return false; + } + return true; +} + +bool ScopingVisitor::transformScope(Expr *e) { + if (e) { + ConditionalBlock c(ctx.get(), nullptr); + return transform(e); + } + return true; +} + +bool ScopingVisitor::transformScope(Stmt *s) { + if (s) { + ConditionalBlock c(ctx.get(), s); + return transform(s); + } + return true; +} + +bool ScopingVisitor::transformAdding(Expr *e, ASTNode *root) { + if (cast(e)) { + return transform(e); + } else if (auto de = cast(e)) { + if (!transform(e)) + return false; + if (!ctx->classDeduce.first.empty() && + match(de->getExpr(), M(ctx->classDeduce.first))) + ctx->classDeduce.second.insert(de->getMember()); + return true; + } else if (cast(e) || cast(e) || cast(e)) { + SetInScope s1(&(ctx->adding), true); + SetInScope s2(&(ctx->root), root); + return transform(e); + } else { + seqassert(e, "bad call to transformAdding"); + addError(Error::ASSIGN_INVALID, e); + return false; + } +} + +void ScopingVisitor::visit(IdExpr *expr) { + if (ctx->adding) + ctx->root = expr; + if (ctx->adding && ctx->tempScope) + ctx->renames.back()[expr->getValue()] = + ctx->cache->getTemporaryVar(expr->getValue()); + for (size_t i = ctx->renames.size(); i-- > 0;) + if (auto v = in(ctx->renames[i], expr->getValue())) { + expr->setValue(*v); + break; + } + if (visitName(expr->getValue(), ctx->adding, ctx->root, expr->getSrcInfo())) + expr->setAttribute(Attr::ExprDominatedUndefCheck); +} + +void ScopingVisitor::visit(DotExpr *expr) { + SetInScope s(&(ctx->adding), false); // to handle a.x, y = b + CallbackASTVisitor::visit(expr); +} + +void ScopingVisitor::visit(IndexExpr *expr) { + SetInScope s(&(ctx->adding), false); // to handle a[x], y = b + CallbackASTVisitor::visit(expr); +} + +void ScopingVisitor::visit(StringExpr *expr) { + std::vector exprs; + for (auto &p : *expr) { + if (p.prefix == "f" || p.prefix == "F") { + /// Transform an F-string + auto fstr = unpackFString(p.value); + if (!canContinue()) + return; + for (auto pf : fstr) { + if (pf.prefix.empty() && !exprs.empty() && exprs.back().prefix.empty()) { + exprs.back().value += pf.value; + } else { + exprs.emplace_back(pf); + } + } + } else if (!p.prefix.empty()) { + exprs.emplace_back(p); + } else if (!exprs.empty() && exprs.back().prefix.empty()) { + exprs.back().value += p.value; + } else { + exprs.emplace_back(p); + } + } + expr->strings = exprs; +} + +/// Split a Python-like f-string into a list: +/// `f"foo {x+1} bar"` -> `["foo ", str(x+1), " bar"] +/// Supports "{x=}" specifier (that prints the raw expression as well): +/// `f"{x+1=}"` -> `["x+1=", str(x+1)]` +std::vector +ScopingVisitor::unpackFString(const std::string &value) { + // Strings to be concatenated + std::vector items; + int braceCount = 0, braceStart = 0; + for (int i = 0; i < value.size(); i++) { + if (value[i] == '{') { + if (braceStart < i) + items.emplace_back(value.substr(braceStart, i - braceStart)); + if (!braceCount) + braceStart = i + 1; + braceCount++; + } else if (value[i] == '}') { + braceCount--; + if (!braceCount) { + std::string code = value.substr(braceStart, i - braceStart); + + auto offset = getSrcInfo(); + offset.col += i; + items.emplace_back(code, "#f"); + items.back().setSrcInfo(offset); + + auto val = parseExpr(ctx->cache, code, offset); + if (!val) { + addError(val.takeError()); + } else { + items.back().expr = val->first; + if (!transform(items.back().expr)) + return items; + items.back().format = val->second; + } + } + braceStart = i + 1; + } + } + if (braceCount > 0) + addError(Error::STR_FSTRING_BALANCE_EXTRA, getSrcInfo()); + else if (braceCount < 0) + addError(Error::STR_FSTRING_BALANCE_MISSING, getSrcInfo()); + if (braceStart != value.size()) + items.emplace_back(value.substr(braceStart, value.size() - braceStart)); + return items; +} + +void ScopingVisitor::visit(GeneratorExpr *expr) { + SetInScope s(&(ctx->tempScope), true); + ctx->renames.emplace_back(); + CHECK(transform(expr->getFinalSuite())); + ctx->renames.pop_back(); +} + +void ScopingVisitor::visit(IfExpr *expr) { + CHECK(transform(expr->getCond())); + CHECK(transformScope(expr->getIf())); + CHECK(transformScope(expr->getElse())); +} + +void ScopingVisitor::visit(BinaryExpr *expr) { + CHECK(transform(expr->getLhs())); + if (expr->getOp() == "&&" || expr->getOp() == "||") { + CHECK(transformScope(expr->getRhs())); + } else { + CHECK(transform(expr->getRhs())); + } +} + +void ScopingVisitor::visit(AssignExpr *expr) { + seqassert(cast(expr->getVar()), + "only simple assignment expression are supported"); + + SetInScope s(&(ctx->tempScope), false); + CHECK(transform(expr->getExpr())); + CHECK(transformAdding(expr->getVar(), expr)); +} + +void ScopingVisitor::visit(LambdaExpr *expr) { + auto c = std::make_shared(); + c->cache = ctx->cache; + FunctionStmt f("lambda", nullptr, {}, nullptr); + c->functionScope = &f; + c->renames = ctx->renames; + ScopingVisitor v; + c->scope.emplace_back(0, nullptr); + v.ctx = c; + for (const auto &a : *expr) { + auto [_, n] = a.getNameWithStars(); + v.visitName(n, true, expr, a.getSrcInfo()); + if (a.defaultValue) + CHECK(transform(a.defaultValue)); + } + c->scope.pop_back(); + + SuiteStmt s; + c->scope.emplace_back(0, &s); + v.transform(expr->getExpr()); + v.processChildCaptures(); + c->scope.pop_back(); + if (v.hasErrors()) + errors.append(v.errors); + if (!canContinue()) + return; + + auto b = std::make_unique(); + b->captures = c->captures; + for (const auto &n : c->captures) + ctx->childCaptures.insert(n); + for (auto &[u, v] : c->map) + b->bindings[u] = v.size(); + expr->setAttribute(Attr::Bindings, std::move(b)); +} + +// todo)) Globals/nonlocals cannot be shadowed in children scopes (as in Python) + +void ScopingVisitor::visit(AssignStmt *stmt) { + CHECK(transform(stmt->getRhs())); + CHECK(transform(stmt->getTypeExpr())); + CHECK(transformAdding(stmt->getLhs(), stmt)); +} + +void ScopingVisitor::visit(IfStmt *stmt) { + CHECK(transform(stmt->getCond())); + CHECK(transformScope(stmt->getIf())); + CHECK(transformScope(stmt->getElse())); +} + +void ScopingVisitor::visit(MatchStmt *stmt) { + CHECK(transform(stmt->getExpr())); + for (auto &m : *stmt) { + CHECK(transform(m.getPattern())); + CHECK(transform(m.getGuard())); + CHECK(transformScope(m.getSuite())); + } +} + +void ScopingVisitor::visit(WhileStmt *stmt) { + CHECK(transform(stmt->getCond())); + + std::unordered_set seen; + { + ConditionalBlock c(ctx.get(), stmt->getSuite()); + ctx->scope.back().seenVars = std::make_unique>(); + CHECK(transform(stmt->getSuite())); + seen = *(ctx->scope.back().seenVars); + } + for (auto &var : seen) + findDominatingBinding(var); + + CHECK(transformScope(stmt->getElse())); +} + +void ScopingVisitor::visit(ForStmt *stmt) { + CHECK(transform(stmt->getIter())); + CHECK(transform(stmt->getDecorator())); + for (auto &a : stmt->ompArgs) + CHECK(transform(a.value)); + + std::unordered_set seen, seenDef; + { + ConditionalBlock c(ctx.get(), stmt->getSuite()); + + ctx->scope.back().seenVars = std::make_unique>(); + CHECK(transformAdding(stmt->getVar(), stmt)); + seenDef = *(ctx->scope.back().seenVars); + + ctx->scope.back().seenVars = std::make_unique>(); + CHECK(transform(stmt->getSuite())); + seen = *(ctx->scope.back().seenVars); + } + for (auto &var : seen) + if (!in(seenDef, var)) + findDominatingBinding(var); + + CHECK(transformScope(stmt->getElse())); +} + +void ScopingVisitor::visit(ImportStmt *stmt) { + // Validate + if (stmt->getFrom()) { + Expr *e = stmt->getFrom(); + while (auto d = cast(e)) + e = d->getExpr(); + if (!isId(stmt->getFrom(), "C") && !isId(stmt->getFrom(), "python")) { + if (!cast(e)) + STOP_ERROR(Error::IMPORT_IDENTIFIER, e); + if (!stmt->getArgs().empty()) + STOP_ERROR(Error::IMPORT_FN, stmt->getArgs().front().getSrcInfo()); + if (stmt->getReturnType()) + STOP_ERROR(Error::IMPORT_FN, stmt->getReturnType()); + if (stmt->getWhat() && !cast(stmt->getWhat())) + STOP_ERROR(Error::IMPORT_IDENTIFIER, stmt->getWhat()); + } + if (stmt->isCVar() && !stmt->getArgs().empty()) + STOP_ERROR(Error::IMPORT_FN, stmt->getArgs().front().getSrcInfo()); + } + if (ctx->functionScope && stmt->getWhat() && isId(stmt->getWhat(), "*")) + STOP_ERROR(error::Error::IMPORT_STAR, stmt); + + // dylib C imports + if (stmt->getFrom() && isId(stmt->getFrom(), "C") && cast(stmt->getWhat())) + CHECK(transform(cast(stmt->getWhat())->getExpr())); + + if (stmt->getAs().empty()) { + if (stmt->getWhat()) { + CHECK(transformAdding(stmt->getWhat(), stmt)); + } else { + CHECK(transformAdding(stmt->getFrom(), stmt)); + } + } else { + visitName(stmt->getAs(), true, stmt, stmt->getSrcInfo()); + } + for (const auto &a : stmt->getArgs()) { + CHECK(transform(a.type)); + CHECK(transform(a.defaultValue)); + } + CHECK(transform(stmt->getReturnType())); +} + +void ScopingVisitor::visit(TryStmt *stmt) { + CHECK(transformScope(stmt->getSuite())); + for (auto *a : *stmt) { + CHECK(transform(a->getException())); + ConditionalBlock c(ctx.get(), a->getSuite()); + if (!a->getVar().empty()) { + auto newName = ctx->cache->getTemporaryVar(a->getVar()); + ctx->renames.push_back({{a->getVar(), newName}}); + a->var = newName; + visitName(a->getVar(), true, a, a->getException()->getSrcInfo()); + } + CHECK(transform(a->getSuite())); + if (!a->getVar().empty()) + ctx->renames.pop_back(); + } + CHECK(transform(stmt->getElse())); + CHECK(transform(stmt->getFinally())); +} + +void ScopingVisitor::visit(DelStmt *stmt) { + /// TODO + CHECK(transform(stmt->getExpr())); +} + +/// Process `global` statements. Remove them upon completion. +void ScopingVisitor::visit(GlobalStmt *stmt) { + if (!ctx->functionScope) + STOP_ERROR(Error::FN_OUTSIDE_ERROR, stmt, + stmt->isNonLocal() ? "nonlocal" : "global"); + if (in(ctx->map, stmt->getVar()) || in(ctx->captures, stmt->getVar())) + STOP_ERROR(Error::FN_GLOBAL_ASSIGNED, stmt, stmt->getVar()); + + visitName(stmt->getVar(), true, stmt, stmt->getSrcInfo()); + ctx->captures[stmt->getVar()] = stmt->isNonLocal() + ? BindingsAttribute::CaptureType::Nonlocal + : BindingsAttribute::CaptureType::Global; +} + +void ScopingVisitor::visit(YieldStmt *stmt) { + if (ctx->functionScope) + ctx->functionScope->setAttribute(Attr::IsGenerator); + CHECK(transform(stmt->getExpr())); +} + +void ScopingVisitor::visit(YieldExpr *expr) { + if (ctx->functionScope) + ctx->functionScope->setAttribute(Attr::IsGenerator); +} + +void ScopingVisitor::visit(FunctionStmt *stmt) { + // Validate + std::vector newDecorators; + for (auto &d : stmt->getDecorators()) { + if (isId(d, Attr::Attribute)) { + if (stmt->getDecorators().size() != 1) + STOP_ERROR(Error::FN_SINGLE_DECORATOR, stmt->getDecorators()[1], + Attr::Attribute); + stmt->setAttribute(Attr::Attribute); + } else if (isId(d, Attr::LLVM)) { + stmt->setAttribute(Attr::LLVM); + } else if (isId(d, Attr::Python)) { + if (stmt->getDecorators().size() != 1) + STOP_ERROR(Error::FN_SINGLE_DECORATOR, stmt->getDecorators()[1], Attr::Python); + stmt->setAttribute(Attr::Python); + } else if (isId(d, Attr::Internal)) { + stmt->setAttribute(Attr::Internal); + } else if (isId(d, Attr::HiddenFromUser)) { + stmt->setAttribute(Attr::HiddenFromUser); + } else if (isId(d, Attr::Atomic)) { + stmt->setAttribute(Attr::Atomic); + } else if (isId(d, Attr::Property)) { + stmt->setAttribute(Attr::Property); + } else if (isId(d, Attr::StaticMethod)) { + stmt->setAttribute(Attr::StaticMethod); + } else if (isId(d, Attr::ForceRealize)) { + stmt->setAttribute(Attr::ForceRealize); + } else if (isId(d, Attr::C)) { + stmt->setAttribute(Attr::C); + } else { + newDecorators.emplace_back(d); + } + } + if (stmt->hasAttribute(Attr::C)) { + for (auto &a : *stmt) { + if (a.getName().size() > 1 && a.getName()[0] == '*' && a.getName()[1] != '*') + stmt->setAttribute(Attr::CVarArg); + } + } + if (!stmt->empty() && !stmt->front().getType() && stmt->front().getName() == "self") { + stmt->setAttribute(Attr::HasSelf); + } + stmt->setDecorators(newDecorators); + if (!stmt->getReturn() && + (stmt->hasAttribute(Attr::LLVM) || stmt->hasAttribute(Attr::C))) + STOP_ERROR(Error::FN_LLVM, getSrcInfo()); + // Set attributes + std::unordered_set seenArgs; + bool defaultsStarted = false, hasStarArg = false, hasKwArg = false; + for (size_t ia = 0; ia < stmt->size(); ia++) { + auto &a = (*stmt)[ia]; + auto [stars, n] = a.getNameWithStars(); + if (stars == 2) { + if (hasKwArg) + STOP_ERROR(Error::FN_MULTIPLE_ARGS, a.getSrcInfo()); + if (a.getDefault()) + STOP_ERROR(Error::FN_DEFAULT_STARARG, a.getDefault()); + if (ia != stmt->size() - 1) + STOP_ERROR(Error::FN_LAST_KWARG, a.getSrcInfo()); + + hasKwArg = true; + } else if (stars == 1) { + if (hasStarArg) + STOP_ERROR(Error::FN_MULTIPLE_ARGS, a.getSrcInfo()); + if (a.getDefault()) + STOP_ERROR(Error::FN_DEFAULT_STARARG, a.getDefault()); + hasStarArg = true; + } + if (in(seenArgs, n)) + STOP_ERROR(Error::FN_ARG_TWICE, a.getSrcInfo(), n); + seenArgs.insert(n); + if (!a.getDefault() && defaultsStarted && !stars && a.isValue()) + STOP_ERROR(Error::FN_DEFAULT, a.getSrcInfo(), n); + defaultsStarted |= bool(a.getDefault()); + if (stmt->hasAttribute(Attr::C)) { + if (a.getDefault()) + STOP_ERROR(Error::FN_C_DEFAULT, a.getDefault(), n); + if (stars != 1 && !a.getType()) + STOP_ERROR(Error::FN_C_TYPE, a.getSrcInfo(), n); + } + } + + bool isOverload = false; + for (auto &d : stmt->getDecorators()) + if (isId(d, "overload")) { + isOverload = true; + } + if (!isOverload) + visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); + + auto c = std::make_shared(); + c->cache = ctx->cache; + c->functionScope = stmt; + if (ctx->inClass && !stmt->empty()) + c->classDeduce = {stmt->front().getName(), {}}; + c->renames = ctx->renames; + ScopingVisitor v; + c->scope.emplace_back(0); + v.ctx = c; + v.visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); + for (const auto &a : *stmt) { + auto [_, n] = a.getNameWithStars(); + v.visitName(n, true, stmt, a.getSrcInfo()); + if (a.defaultValue) + CHECK(transform(a.defaultValue)); + } + c->scope.pop_back(); + + c->scope.emplace_back(0, stmt->getSuite()); + v.transform(stmt->getSuite()); + v.processChildCaptures(); + c->scope.pop_back(); + if (v.hasErrors()) + errors.append(v.errors); + if (!canContinue()) + return; + + auto b = std::make_unique(); + b->captures = c->captures; + for (const auto &n : c->captures) { + ctx->childCaptures.insert(n); + } + for (auto &[u, v] : c->map) { + b->bindings[u] = v.size(); + } + stmt->setAttribute(Attr::Bindings, std::move(b)); + + + if (!c->classDeduce.second.empty()) { + auto s = std::make_unique(); + for (const auto &n : c->classDeduce.second) + s->values.push_back(n); + stmt->setAttribute(Attr::ClassDeduce, std::move(s)); + } +} + +void ScopingVisitor::visit(WithStmt *stmt) { + ConditionalBlock c(ctx.get(), stmt->getSuite()); + for (size_t i = 0; i < stmt->size(); i++) { + CHECK(transform((*stmt)[i])); + if (!stmt->getVars()[i].empty()) + visitName(stmt->getVars()[i], true, stmt, stmt->getSrcInfo()); + } + CHECK(transform(stmt->getSuite())); +} + +void ScopingVisitor::visit(ClassStmt *stmt) { + // @tuple(init=, repr=, eq=, order=, hash=, pickle=, container=, python=, add=, + // internal=...) + // @dataclass(...) + // @extend + + std::map tupleMagics = { + {"new", true}, {"repr", false}, {"hash", false}, + {"eq", false}, {"ne", false}, {"lt", false}, + {"le", false}, {"gt", false}, {"ge", false}, + {"pickle", true}, {"unpickle", true}, {"to_py", false}, + {"from_py", false}, {"iter", false}, {"getitem", false}, + {"len", false}, {"to_gpu", false}, {"from_gpu", false}, + {"from_gpu_new", false}, {"tuplesize", true}}; + + for (auto &d : stmt->getDecorators()) { + if (isId(d, "deduce")) { + stmt->setAttribute(Attr::ClassDeduce); + } else if (isId(d, "__notuple__")) { + stmt->setAttribute(Attr::ClassNoTuple); + } else if (isId(d, "dataclass")) { + } else if (auto c = cast(d)) { + if (isId(c->getExpr(), Attr::Tuple)) { + stmt->setAttribute(Attr::Tuple); + for (auto &m : tupleMagics) + m.second = true; + } else if (!isId(c->getExpr(), "dataclass")) { + STOP_ERROR(Error::CLASS_BAD_DECORATOR, c->getExpr()); + } else if (stmt->hasAttribute(Attr::Tuple)) { + STOP_ERROR(Error::CLASS_CONFLICT_DECORATOR, c, "dataclass", Attr::Tuple); + } + for (const auto &a : *c) { + auto b = cast(a); + if (!b) + STOP_ERROR(Error::CLASS_NONSTATIC_DECORATOR, a.getSrcInfo()); + char val = char(b->getValue()); + if (a.getName() == "init") { + tupleMagics["new"] = val; + } else if (a.getName() == "repr") { + tupleMagics["repr"] = val; + } else if (a.getName() == "eq") { + tupleMagics["eq"] = tupleMagics["ne"] = val; + } else if (a.getName() == "order") { + tupleMagics["lt"] = tupleMagics["le"] = tupleMagics["gt"] = + tupleMagics["ge"] = val; + } else if (a.getName() == "hash") { + tupleMagics["hash"] = val; + } else if (a.getName() == "pickle") { + tupleMagics["pickle"] = tupleMagics["unpickle"] = val; + } else if (a.getName() == "python") { + tupleMagics["to_py"] = tupleMagics["from_py"] = val; + } else if (a.getName() == "gpu") { + tupleMagics["to_gpu"] = tupleMagics["from_gpu"] = + tupleMagics["from_gpu_new"] = val; + } else if (a.getName() == "container") { + tupleMagics["iter"] = tupleMagics["getitem"] = val; + } else { + STOP_ERROR(Error::CLASS_BAD_DECORATOR_ARG, a.getSrcInfo()); + } + } + } else if (isId(d, Attr::Tuple)) { + if (stmt->hasAttribute(Attr::Tuple)) + STOP_ERROR(Error::CLASS_MULTIPLE_DECORATORS, d, Attr::Tuple); + stmt->setAttribute(Attr::Tuple); + for (auto &m : tupleMagics) { + m.second = true; + } + } else if (isId(d, Attr::Extend)) { + stmt->setAttribute(Attr::Extend); + if (stmt->getDecorators().size() != 1) { + STOP_ERROR( + Error::CLASS_SINGLE_DECORATOR, + stmt->getDecorators()[stmt->getDecorators().front() == d]->getSrcInfo(), + Attr::Extend); + } + } else if (isId(d, Attr::Internal)) { + stmt->setAttribute(Attr::Internal); + } else { + STOP_ERROR(Error::CLASS_BAD_DECORATOR, d); + } + } + if (stmt->hasAttribute(Attr::ClassDeduce)) + tupleMagics["new"] = false; + if (!stmt->hasAttribute(Attr::Tuple)) { + tupleMagics["init"] = tupleMagics["new"]; + tupleMagics["new"] = tupleMagics["raw"] = true; + tupleMagics["len"] = false; + } + tupleMagics["dict"] = true; + // Internal classes do not get any auto-generated members. + std::vector magics; + if (!stmt->hasAttribute(Attr::Internal)) { + for (auto &m : tupleMagics) + if (m.second) { + if (m.first == "new") + magics.insert(magics.begin(), m.first); + else + magics.push_back(m.first); + } + } + stmt->setAttribute(Attr::ClassMagic, + std::make_unique(magics)); + std::unordered_set seen; + if (stmt->hasAttribute(Attr::Extend) && !stmt->empty()) + STOP_ERROR(Error::CLASS_EXTENSION, stmt->front().getSrcInfo()); + if (stmt->hasAttribute(Attr::Extend) && + !(stmt->getBaseClasses().empty() && stmt->getStaticBaseClasses().empty())) { + STOP_ERROR(Error::CLASS_EXTENSION, stmt->getBaseClasses().empty() + ? stmt->getStaticBaseClasses().front() + : stmt->getBaseClasses().front()); + } + for (auto &a : *stmt) { + if (!a.getType() && !a.getDefault()) + STOP_ERROR(Error::CLASS_MISSING_TYPE, a.getSrcInfo(), a.getName()); + if (in(seen, a.getName())) + STOP_ERROR(Error::CLASS_ARG_TWICE, a.getSrcInfo(), a.getName()); + seen.insert(a.getName()); + } + + if (stmt->hasAttribute(Attr::Extend)) + visitName(stmt->getName()); + else + visitName(stmt->getName(), true, stmt, stmt->getSrcInfo()); + + auto c = std::make_shared(); + c->cache = ctx->cache; + c->renames = ctx->renames; + ScopingVisitor v; + c->scope.emplace_back(0); + c->inClass = true; + v.ctx = c; + for (const auto &a : *stmt) { + v.transform(a.type); + v.transform(a.defaultValue); + } + v.transform(stmt->getSuite()); + c->scope.pop_back(); + if (v.hasErrors()) + errors.append(v.errors); + if (!canContinue()) + return; + + for (auto &d : stmt->getBaseClasses()) + CHECK(transform(d)); + for (auto &d : stmt->getStaticBaseClasses()) + CHECK(transform(d)); +} + +void ScopingVisitor::processChildCaptures() { + for (auto &n : ctx->childCaptures) { + if (auto i = in(ctx->map, n.first)) { + if (i->back().binding && cast(i->back().binding)) + continue; + } + if (!findDominatingBinding(n.first)) { + ctx->captures.insert(n); // propagate! + } + } +} + +void ScopingVisitor::switchToUpdate(ASTNode *binding, const std::string &name, + bool gotUsedVar) { + if (binding && binding->hasAttribute(Attr::Bindings)) { + binding->getAttribute(Attr::Bindings)->bindings.erase(name); + } + if (binding) { + if (!gotUsedVar && binding->hasAttribute(Attr::ExprDominatedUsed)) + binding->eraseAttribute(Attr::ExprDominatedUsed); + binding->setAttribute(gotUsedVar ? Attr::ExprDominatedUsed : Attr::ExprDominated); + } + if (cast(binding)) + STOP_ERROR(error::Error::ID_INVALID_BIND, binding, name); + else if (cast(binding)) + STOP_ERROR(error::Error::ID_INVALID_BIND, binding, name); +} + +bool ScopingVisitor::visitName(const std::string &name, bool adding, ASTNode *root, + const SrcInfo &src) { + if (adding && ctx->inClass) + return false; + if (adding) { + if (auto p = in(ctx->captures, name)) { + if (*p == BindingsAttribute::CaptureType::Read) { + addError(error::Error::ASSIGN_LOCAL_REFERENCE, ctx->firstSeen[name], name, src); + return false; + } else if (root) { // global, nonlocal + switchToUpdate(root, name, false); + } + } else { + if (auto i = in(ctx->childCaptures, name)) { + if (*i != BindingsAttribute::CaptureType::Global && ctx->functionScope) { + auto newScope = std::vector{ctx->scope[0].id}; + seqassert(ctx->scope.front().suite, "invalid suite"); + if (!ctx->scope.front().suite->hasAttribute(Attr::Bindings)) + ctx->scope.front().suite->setAttribute( + Attr::Bindings, std::make_unique()); + ctx->scope.front() + .suite->getAttribute(Attr::Bindings) + ->bindings[name] = false; + auto newItem = ScopingVisitor::Context::Item(src, newScope, nullptr); + ctx->map[name].push_back(newItem); + } + } + ctx->map[name].emplace_front(src, ctx->getScope(), root); + } + } else { + if (!in(ctx->firstSeen, name)) + ctx->firstSeen[name] = src; + if (!in(ctx->map, name)) { + ctx->captures[name] = BindingsAttribute::CaptureType::Read; + } + } + if (auto val = findDominatingBinding(name)) { + // Track loop variables to dominate them later. Example: + // x = 1 + // while True: + // if x > 10: break + // x = x + 1 # x must be dominated after the loop to ensure that it gets updated + auto scope = ctx->getScope(); + for (size_t li = ctx->scope.size(); li-- > 0;) { + if (ctx->scope[li].seenVars) { + bool inside = val->scope.size() >= scope.size() && + val->scope[scope.size() - 1] == scope.back(); + if (!inside) + ctx->scope[li].seenVars->insert(name); + else + break; + } + scope.pop_back(); + } + + // Variable binding check for variables that are defined within conditional blocks + if (!val->accessChecked.empty()) { + bool checked = false; + for (size_t ai = val->accessChecked.size(); ai-- > 0;) { + auto &a = val->accessChecked[ai]; + if (a.size() <= ctx->scope.size() && + a[a.size() - 1] == ctx->scope[a.size() - 1].id) { + checked = true; + break; + } + } + if (!checked) { + seqassert(!adding, "bad branch"); + if (!(val->binding && val->binding->hasAttribute(Attr::Bindings))) { + // If the expression is not conditional, we can just do the check once + val->accessChecked.push_back(ctx->getScope()); + } + return true; + } + } + } + return false; +} + +/// Get an item from the context. Perform domination analysis for accessing items +/// defined in the conditional blocks (i.e., Python scoping). +ScopingVisitor::Context::Item * +ScopingVisitor::findDominatingBinding(const std::string &name, bool allowShadow) { + auto it = in(ctx->map, name); + if (!it || it->empty()) + return nullptr; + auto lastGood = it->begin(); + while (lastGood != it->end() && lastGood->ignore) + lastGood++; + int commonScope = int(ctx->scope.size()); + // Iterate through all bindings with the given name and find the closest binding that + // dominates the current scope. + for (auto i = it->begin(); i != it->end(); i++) { + if (i->ignore) + continue; + + bool completeDomination = i->scope.size() <= ctx->scope.size() && + i->scope.back() == ctx->scope[i->scope.size() - 1].id; + if (completeDomination) { + commonScope = i->scope.size(); + lastGood = i; + break; + } else { + seqassert(i->scope[0] == 0, "bad scoping"); + seqassert(ctx->scope[0].id == 0, "bad scoping"); + // Find the longest block prefix between the binding and the current common scope. + commonScope = std::min(commonScope, int(i->scope.size())); + while (commonScope > 0 && + i->scope[commonScope - 1] != ctx->scope[commonScope - 1].id) + commonScope--; + // if (commonScope < int(ctx->scope.size()) && commonScope != p) + // break; + lastGood = i; + } + } + seqassert(lastGood != it->end(), "corrupted scoping ({})", name); + if (!allowShadow) { // go to the end + lastGood = it->end(); + --lastGood; + int p = std::min(commonScope, int(lastGood->scope.size())); + while (p >= 0 && lastGood->scope[p - 1] != ctx->scope[p - 1].id) + p--; + commonScope = p; + } + + bool gotUsedVar = false; + if (lastGood->scope.size() != commonScope) { + // The current scope is potentially reachable by multiple bindings that are + // not dominated by a common binding. Create such binding in the scope that + // dominates (covers) all of them. + auto scope = ctx->getScope(); + auto newScope = std::vector(scope.begin(), scope.begin() + commonScope); + + // Make sure to prepend a binding declaration: `var` and `var__used__ = False` + // to the dominating scope. + for (size_t si = commonScope; si-- > 0; si--) { + if (!ctx->scope[si].suite) + continue; + if (!ctx->scope[si].suite->hasAttribute(Attr::Bindings)) + ctx->scope[si].suite->setAttribute(Attr::Bindings, + std::make_unique()); + ctx->scope[si] + .suite->getAttribute(Attr::Bindings) + ->bindings[name] = true; + auto newItem = ScopingVisitor::Context::Item( + getSrcInfo(), newScope, ctx->scope[si].suite, {lastGood->scope}); + lastGood = it->insert(++lastGood, newItem); + gotUsedVar = true; + break; + } + } else if (lastGood->binding && lastGood->binding->hasAttribute(Attr::Bindings)) { + gotUsedVar = lastGood->binding->getAttribute(Attr::Bindings) + ->bindings[name]; + } + // Remove all bindings after the dominant binding. + for (auto i = it->begin(); i != it->end(); i++) { + if (i == lastGood) + break; + switchToUpdate(i->binding, name, gotUsedVar); + i->scope = lastGood->scope; + i->ignore = true; + } + if (!gotUsedVar && lastGood->binding && + lastGood->binding->hasAttribute(Attr::Bindings)) + lastGood->binding->getAttribute(Attr::Bindings)->bindings[name] = + false; + return &(*lastGood); +} + +} // namespace codon::ast diff --git a/codon/parser/visitors/scoping/scoping.h b/codon/parser/visitors/scoping/scoping.h new file mode 100644 index 00000000..5a3fb5c2 --- /dev/null +++ b/codon/parser/visitors/scoping/scoping.h @@ -0,0 +1,186 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "codon/parser/ast.h" +#include "codon/parser/common.h" +#include "codon/parser/visitors/typecheck/ctx.h" +#include "codon/parser/visitors/visitor.h" + +namespace codon::ast { + +struct BindingsAttribute : public ir::Attribute { + static const std::string AttributeName; + + enum CaptureType { Read, Global, Nonlocal }; + std::unordered_map captures; + std::unordered_map bindings; + + std::unique_ptr clone() const override { + auto p = std::make_unique(); + p->captures = captures; + p->bindings = bindings; + return p; + } + +private: + std::ostream &doFormat(std::ostream &os) const override { return os << "Bindings"; } +}; + +class ScopingVisitor : public CallbackASTVisitor { + struct Context { + /// A pointer to the shared cache. + Cache *cache; + + /// Holds the information about current scope. + /// A scope is defined as a stack of conditional blocks + /// (i.e., blocks that might not get executed during the runtime). + /// Used mainly to support Python's variable scoping rules. + struct ScopeBlock { + int id; + // Associated SuiteStmt + Stmt *suite; + /// List of variables "seen" before their assignment within a loop. + /// Used to dominate variables that are updated within a loop. + std::unique_ptr> seenVars = nullptr; + ScopeBlock(int id, Stmt *s = nullptr) : id(id), suite(s), seenVars(nullptr) {} + }; + /// Current hierarchy of conditional blocks. + std::vector scope; + std::vector getScope() const { + std::vector result; + result.reserve(scope.size()); + for (const auto &b : scope) + result.emplace_back(b.id); + return result; + } + + struct Item : public codon::SrcObject { + std::vector scope; + ASTNode *binding = nullptr; + bool ignore = false; + + /// List of scopes where the identifier is accessible + /// without __used__ check + std::vector> accessChecked; + Item(const codon::SrcInfo &src, std::vector scope, + ASTNode *binding = nullptr, std::vector> accessChecked = {}) + : scope(std::move(scope)), binding(std::move(binding)), ignore(false), + accessChecked(std::move(accessChecked)) { + setSrcInfo(src); + } + }; + std::unordered_map> map; + + std::unordered_map captures; + std::unordered_map + childCaptures; // for functions! + std::map firstSeen; + std::pair> classDeduce; + + bool adding = false; + ASTNode *root = nullptr; + FunctionStmt *functionScope = nullptr; + bool inClass = false; + // bool isConditional = false; + + std::vector> renames = {{}}; + bool tempScope = false; + + // Time to track positions of assignments and references to them. + int64_t time = 0; + }; + std::shared_ptr ctx = nullptr; + + struct ConditionalBlock { + Context *ctx; + ConditionalBlock(Context *ctx, Stmt *s, int id = -1) : ctx(ctx) { + if (s) + seqassertn(cast(s), "not a suite"); + ctx->scope.emplace_back(id == -1 ? ctx->cache->blockCount++ : id, s); + } + ~ConditionalBlock() { + seqassertn(!ctx->scope.empty() && + (ctx->scope.back().id == 0 || ctx->scope.size() > 1), + "empty scope"); + ctx->scope.pop_back(); + } + }; + +public: + ParserErrors errors; + bool hasErrors() const { return !errors.empty(); } + bool canContinue() const { return errors.size() <= MAX_ERRORS; } + + template + void addError(error::Error e, const SrcInfo &o, const TA &...args) { + auto msg = + ErrorMessage(error::Emsg(e, args...), o.file, o.line, o.col, o.len, int(e)); + errors.addError({msg}); + } + template void addError(error::Error e, ASTNode *o, const TA &...args) { + this->addError(e, o->getSrcInfo(), args...); + } + void addError(llvm::Error &&e) { + llvm::handleAllErrors(std::move(e), [this](const error::ParserErrorInfo &e) { + this->errors.append(e.getErrors()); + }); + } + + static llvm::Error apply(Cache *, Stmt *s); + bool transform(Expr *expr) override; + bool transform(Stmt *stmt) override; + + // Can error! + bool visitName(const std::string &name, bool = false, ASTNode * = nullptr, + const SrcInfo & = SrcInfo()); + bool transformAdding(Expr *e, ASTNode *); + bool transformScope(Expr *); + bool transformScope(Stmt *); + + void visit(StringExpr *) override; + void visit(IdExpr *) override; + void visit(DotExpr *) override; + void visit(IndexExpr *) override; + void visit(GeneratorExpr *) override; + void visit(IfExpr *) override; + void visit(BinaryExpr *) override; + void visit(LambdaExpr *) override; + void visit(YieldExpr *) override; + void visit(AssignExpr *) override; + + void visit(AssignStmt *) override; + void visit(DelStmt *) override; + void visit(YieldStmt *) override; + void visit(WhileStmt *) override; + void visit(ForStmt *) override; + void visit(IfStmt *) override; + void visit(MatchStmt *) override; + void visit(ImportStmt *) override; + void visit(TryStmt *) override; + void visit(GlobalStmt *) override; + void visit(FunctionStmt *) override; + void visit(ClassStmt *) override; + void visit(WithStmt *) override; + + Context::Item *findDominatingBinding(const std::string &, bool = true); + void processChildCaptures(); + void switchToUpdate(ASTNode *binding, const std::string &, bool); + + std::vector unpackFString(const std::string &value); + + template Tn *N(Ts &&...args) { + Tn *t = ctx->cache->N(std::forward(args)...); + t->setSrcInfo(getSrcInfo()); + return t; + } +}; + +} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp deleted file mode 100644 index 8779f2db..00000000 --- a/codon/parser/visitors/simplify/access.cpp +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -void SimplifyVisitor::visit(IdExpr *expr) { - auto val = ctx->findDominatingBinding(expr->value); - - if (!val && ctx->getBase()->pyCaptures) { - ctx->getBase()->pyCaptures->insert(expr->value); - resultExpr = N(N("__pyenv__"), N(expr->value)); - return; - } else if (!val) { - E(Error::ID_NOT_FOUND, expr, expr->value); - } - - // If we are accessing an outside variable, capture it or raise an error - auto captured = checkCapture(val); - if (captured) - val = ctx->forceFind(expr->value); - - // Track loop variables to dominate them later. Example: - // x = 1 - // while True: - // if x > 10: break - // x = x + 1 # x must be dominated after the loop to ensure that it gets updated - if (ctx->getBase()->getLoop()) { - for (size_t li = ctx->getBase()->loops.size(); li-- > 0;) { - auto &loop = ctx->getBase()->loops[li]; - bool inside = val->scope.size() >= loop.scope.size() && - val->scope[loop.scope.size() - 1] == loop.scope.back(); - if (!inside) - loop.seenVars.insert(expr->value); - else - break; - } - } - - // Replace the variable with its canonical name - expr->value = val->canonicalName; - - // Mark global as "seen" to prevent later creation of local variables - // with the same name. Example: - // x = 1 - // def foo(): - // print(x) # mark x as seen - // x = 2 # so that this is an error - if (!val->isGeneric() && ctx->isOuter(val) && - in(ctx->cache->reverseIdentifierLookup, val->canonicalName) && - !in(ctx->seenGlobalIdentifiers[ctx->getBaseName()], - ctx->cache->rev(val->canonicalName))) { - ctx->seenGlobalIdentifiers[ctx->getBaseName()] - [ctx->cache->rev(val->canonicalName)] = expr->clone(); - } - - // Flag the expression as a type expression if it points to a class or a generic - if (val->isType()) - expr->markType(); - - // Variable binding check for variables that are defined within conditional blocks - if (!val->accessChecked.empty()) { - bool checked = false; - for (auto &a : val->accessChecked) { - if (a.size() <= ctx->scope.blocks.size() && - a[a.size() - 1] == ctx->scope.blocks[a.size() - 1]) { - checked = true; - break; - } - } - if (!checked) { - // Prepend access with __internal__.undef([var]__used__, "[var name]") - auto checkStmt = N(N( - N("__internal__.undef"), - N(fmt::format("{}.__used__", val->canonicalName)), - N(ctx->cache->reverseIdentifierLookup[val->canonicalName]))); - if (!ctx->isConditionalExpr) { - // If the expression is not conditional, we can just do the check once - prependStmts->push_back(checkStmt); - val->accessChecked.push_back(ctx->scope.blocks); - } else { - // Otherwise, this check must be always called - resultExpr = N(checkStmt, N(*expr)); - } - } - } -} - -/// Flatten imports. -/// @example -/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import -/// `a.B.c` -> canonical name of `c` in class `a.B` -/// `python.foo` -> internal.python._get_identifier("foo") -/// Other cases are handled during the type checking. -void SimplifyVisitor::visit(DotExpr *expr) { - // First flatten the imports: - // transform Dot(Dot(a, b), c...) to {a, b, c, ...} - std::vector chain; - Expr *root = expr; - for (; root->getDot(); root = root->getDot()->expr.get()) - chain.push_back(root->getDot()->member); - - if (auto id = root->getId()) { - // Case: a.bar.baz - chain.push_back(id->value); - std::reverse(chain.begin(), chain.end()); - auto p = getImport(chain); - - if (!p.second) { - seqassert(ctx->getBase()->pyCaptures, "unexpected py capture"); - ctx->getBase()->pyCaptures->insert(chain[0]); - resultExpr = N(N("__pyenv__"), N(chain[0])); - } else if (p.second->getModule() == "std.python") { - resultExpr = transform(N( - N(N(N("internal"), "python"), "_get_identifier"), - N(chain[p.first++]))); - } else if (p.second->getModule() == ctx->getModule() && p.first == 1) { - resultExpr = transform(N(chain[0]), true); - } else { - resultExpr = N(p.second->canonicalName); - if (p.second->isType() && p.first == chain.size()) - resultExpr->markType(); - } - for (auto i = p.first; i < chain.size(); i++) - resultExpr = N(resultExpr, chain[i]); - } else { - // Case: a[x].foo.bar - transform(expr->expr, true); - } -} - -/// Access identifiers from outside of the current function/class scope. -/// Either use them as-is (globals), capture them if allowed (nonlocals), -/// or raise an error. -bool SimplifyVisitor::checkCapture(const SimplifyContext::Item &val) { - if (!ctx->isOuter(val)) - return false; - if ((val->isType() && !val->isGeneric()) || val->isFunc()) - return false; - - // Ensure that outer variables can be captured (i.e., do not cross no-capture - // boundary). Example: - // def foo(): - // x = 1 - // class T: # <- boundary (classes cannot capture locals) - // t: int = x # x cannot be accessed - // def bar(): # <- another boundary - // # (class methods cannot capture locals except class generics) - // print(x) # x cannot be accessed - bool crossCaptureBoundary = false; - bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName(); - bool parentClassGeneric = - val->isGeneric() && !ctx->getBase()->isType() && - (ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() && - ctx->bases[ctx->bases.size() - 2].name == val->getBaseName()); - auto i = ctx->bases.size(); - for (; i-- > 0;) { - if (ctx->bases[i].name == val->getBaseName()) - break; - if (!localGeneric && !parentClassGeneric && !ctx->bases[i].captures) - crossCaptureBoundary = true; - } - - // Mark methods (class functions that access class generics) - if (parentClassGeneric) - ctx->getBase()->attributes->set(Attr::Method); - - // Ignore generics - if (parentClassGeneric || localGeneric) - return false; - - // Case: a global variable that has not been marked with `global` statement - if (val->isVar() && val->getBaseName().empty() && val->scope.size() == 1) { - val->noShadow = true; - if (!val->isStatic()) - ctx->cache->addGlobal(val->canonicalName); - return false; - } - - // Check if a real variable (not a static) is defined outside the current scope - if (crossCaptureBoundary) - E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), ctx->cache->rev(val->canonicalName)); - - // Case: a nonlocal variable that has not been marked with `nonlocal` statement - // and capturing is enabled - auto captures = ctx->getBase()->captures; - if (captures && !in(*captures, val->canonicalName)) { - // Captures are transformed to function arguments; generate new name for that - // argument - ExprPtr typ = nullptr; - if (val->isType()) - typ = N("type"); - if (auto st = val->isStatic()) - typ = N(N("Static"), - N(st == StaticValue::INT ? "int" : "str")); - auto [newName, _] = (*captures)[val->canonicalName] = { - ctx->generateCanonicalName(val->canonicalName), typ}; - ctx->cache->reverseIdentifierLookup[newName] = newName; - // Add newly generated argument to the context - std::shared_ptr newVal = nullptr; - if (val->isType()) - newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); - else - newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); - newVal->baseName = ctx->getBaseName(); - newVal->noShadow = true; - newVal->scope = ctx->getBase()->scope; - return true; - } - - // Case: a nonlocal variable that has not been marked with `nonlocal` statement - // and capturing is *not* enabled - E(Error::ID_NONLOCAL, getSrcInfo(), ctx->cache->rev(val->canonicalName)); - return false; -} - -/// Check if a access chain (a.b.c.d...) contains an import or class prefix. -std::pair -SimplifyVisitor::getImport(const std::vector &chain) { - size_t importEnd = 0; - std::string importName; - - // Find the longest prefix that corresponds to the existing import - // (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`) - SimplifyContext::Item val = nullptr, importVal = nullptr; - for (auto i = chain.size(); i-- > 0;) { - val = ctx->find(join(chain, "/", 0, i + 1)); - if (val && val->isImport()) { - importVal = val; - importName = val->importPath, importEnd = i + 1; - break; - } - } - - if (importEnd != chain.size()) { // false when a.b.c points to import itself - // Find the longest prefix that corresponds to the existing class - // (e.g., `a.b.c` -> `a.b` if there is `class a: class b:`) - std::string itemName; - size_t itemEnd = 0; - auto fctx = importName.empty() ? ctx : ctx->cache->imports[importName].ctx; - for (auto i = chain.size(); i-- > importEnd;) { - if (fctx->getModule() == "std.python" && importEnd < chain.size()) { - // Special case: importing from Python. - // Fake SimplifyItem that indicates std.python access - val = std::make_shared(SimplifyItem::Var, "", "", - fctx->getModule(), std::vector{}); - return {importEnd, val}; - } else { - val = fctx->find(join(chain, ".", importEnd, i + 1)); - if (val && i + 1 != chain.size() && val->isImport()) { - importVal = val; - importName = val->importPath; - importEnd = i + 1; - fctx = ctx->cache->imports[importName].ctx; - i = chain.size(); - continue; - } - if (val && (importName.empty() || val->isType() || !val->isConditional())) { - itemName = val->canonicalName, itemEnd = i + 1; - break; - } - } - } - if (itemName.empty() && importName.empty()) { - if (ctx->getBase()->pyCaptures) - return {1, nullptr}; - E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]); - } else if (itemName.empty()) { - if (!ctx->isStdlibLoading && endswith(importName, "__init__.codon")) { - auto import = ctx->cache->imports[importName]; - auto file = - getImportFile(ctx->cache->argv0, chain[importEnd], importName, false, - ctx->cache->module0, ctx->cache->pluginImportPaths); - if (file) { - auto s = SimplifyVisitor(import.ctx, preamble) - .transform(N(N(chain[importEnd]), nullptr)); - prependStmts->push_back(s); - return getImport(chain); - } - } - - E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], - ctx->cache->imports[importName].moduleName); - } - importEnd = itemEnd; - } - return {importEnd, val}; -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/assign.cpp b/codon/parser/visitors/simplify/assign.cpp deleted file mode 100644 index 91cdf234..00000000 --- a/codon/parser/visitors/simplify/assign.cpp +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Transform walrus (assignment) expression. -/// @example -/// `(expr := var)` -> `var = expr; var` -void SimplifyVisitor::visit(AssignExpr *expr) { - seqassert(expr->var->getId(), "only simple assignment expression are supported"); - StmtPtr s = N(clone(expr->var), expr->expr); - auto avoidDomination = false; // walruses always leak - std::swap(avoidDomination, ctx->avoidDomination); - if (ctx->isConditionalExpr) { - // Make sure to transform both suite _AND_ the expression in the same scope - ctx->enterConditionalBlock(); - transform(s); - transform(expr->var); - SuiteStmt *suite = s->getSuite(); - if (!suite) { - s = N(s); - suite = s->getSuite(); - } - ctx->leaveConditionalBlock(&suite->stmts); - } else { - s = transform(s); - transform(expr->var); - } - std::swap(avoidDomination, ctx->avoidDomination); - resultExpr = N(std::vector{s}, expr->var); -} - -/// Unpack assignments. -/// See @c transformAssignment and @c unpackAssignments for more details. -void SimplifyVisitor::visit(AssignStmt *stmt) { - std::vector stmts; - if (stmt->rhs && stmt->rhs->getBinary() && stmt->rhs->getBinary()->inPlace) { - // Update case: a += b - seqassert(!stmt->type, "invalid AssignStmt {}", stmt->toString()); - stmts.push_back(transformAssignment(stmt->lhs, stmt->rhs, nullptr, true)); - } else if (stmt->type) { - // Type case: `a: T = b, c` (no unpacking) - stmts.push_back(transformAssignment(stmt->lhs, stmt->rhs, stmt->type)); - } else { - // Normal case - unpackAssignments(stmt->lhs, stmt->rhs, stmts); - } - resultStmt = stmts.size() == 1 ? stmts[0] : N(stmts); -} - -/// Transform deletions. -/// @example -/// `del a` -> `a = type(a)()` and remove `a` from the context -/// `del a[x]` -> `a.__delitem__(x)` -void SimplifyVisitor::visit(DelStmt *stmt) { - if (auto idx = stmt->expr->getIndex()) { - resultStmt = N( - transform(N(N(idx->expr, "__delitem__"), idx->index))); - } else if (auto ei = stmt->expr->getId()) { - // Assign `a` to `type(a)()` to mark it for deletion - resultStmt = N( - transform(clone(stmt->expr)), - transform(N(N(N("type"), clone(stmt->expr))))); - resultStmt->getAssign()->setUpdate(); - - // Allow deletion *only* if the binding is dominated - auto val = ctx->find(ei->value); - if (!val) - E(Error::ID_NOT_FOUND, ei, ei->value); - if (ctx->scope.blocks != val->scope) - E(Error::DEL_NOT_ALLOWED, ei, ei->value); - ctx->remove(ei->value); - } else { - E(Error::DEL_INVALID, stmt); - } -} - -/// Transform simple assignments. -/// @example -/// `a[x] = b` -> `a.__setitem__(x, b)` -/// `a.x = b` -> @c AssignMemberStmt -/// `a: type` = b -> @c AssignStmt -/// `a = b` -> @c AssignStmt or @c UpdateStmt (see below) -StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr type, - bool mustExist) { - if (auto idx = lhs->getIndex()) { - // Case: a[x] = b - seqassert(!type, "unexpected type annotation"); - if (auto b = rhs->getBinary()) { - if (mustExist && b->inPlace && !b->rexpr->getId()) { - auto var = ctx->cache->getTemporaryVar("assign"); - seqassert(rhs->getBinary(), "not a bin"); - return transform(N( - N(N(var), idx->index), - N(N( - N(idx->expr, "__setitem__"), N(var), - N(N(idx->expr->clone(), N(var)), b->op, - b->rexpr, true))))); - } - } - return transform(N( - N(N(idx->expr, "__setitem__"), idx->index, rhs))); - } - - if (auto dot = lhs->getDot()) { - // Case: a.x = b - seqassert(!type, "unexpected type annotation"); - transform(dot->expr, true); - // If we are deducing class members, check if we can deduce a member from this - // assignment - auto deduced = ctx->getClassBase() ? ctx->getClassBase()->deducedMembers : nullptr; - if (deduced && dot->expr->isId(ctx->getBase()->selfName) && - !in(*deduced, dot->member)) - deduced->push_back(dot->member); - return N(dot->expr, dot->member, transform(rhs)); - } - - // Case: a (: t) = b - auto e = lhs->getId(); - if (!e) - E(Error::ASSIGN_INVALID, lhs); - - // Disable creation of local variables that share the name with some global if such - // global was already accessed within the current scope. Example: - // x = 1 - // def foo(): - // print(x) # x is seen here - // x = 2 # this should error - if (in(ctx->seenGlobalIdentifiers[ctx->getBaseName()], e->value)) - E(Error::ASSIGN_LOCAL_REFERENCE, - ctx->seenGlobalIdentifiers[ctx->getBaseName()][e->value], e->value); - - auto val = ctx->find(e->value); - // Make sure that existing values that cannot be shadowed (e.g. imported globals) are - // only updated - mustExist |= val && val->noShadow && !ctx->isOuter(val); - if (mustExist) { - val = ctx->findDominatingBinding(e->value); - if (val && val->isVar() && !ctx->isOuter(val)) { - auto s = N(transform(lhs, false), transform(rhs)); - if (ctx->getBase()->attributes && ctx->getBase()->attributes->has(Attr::Atomic)) - s->setAtomicUpdate(); - else - s->setUpdate(); - return s; - } else { - E(Error::ASSIGN_LOCAL_REFERENCE, e, e->value); - } - } - - transform(rhs, true); - transformType(type, false); - - // Generate new canonical variable name for this assignment and add it to the context - auto canonical = ctx->generateCanonicalName(e->value); - auto assign = N(N(canonical), rhs, type); - val = nullptr; - if (rhs && rhs->isType()) { - val = ctx->addType(e->value, canonical, lhs->getSrcInfo()); - } else { - val = ctx->addVar(e->value, canonical, lhs->getSrcInfo()); - if (auto st = getStaticGeneric(type.get())) - val->staticType = st; - if (ctx->avoidDomination) - val->avoidDomination = true; - } - // Clean up seen tags if shadowing a name - ctx->seenGlobalIdentifiers[ctx->getBaseName()].erase(e->value); - - // Register all toplevel variables as global in JIT mode - bool isGlobal = (ctx->cache->isJit && val->isGlobal() && !val->isGeneric()) || - (canonical == VAR_ARGV); - if (isGlobal && !val->isGeneric()) - ctx->cache->addGlobal(canonical); - - return assign; -} - -/// Unpack an assignment expression `lhs = rhs` into a list of simple assignment -/// expressions (e.g., `a = b`, `a.x = b`, or `a[x] = b`). -/// Handle Python unpacking rules. -/// @example -/// `(a, b) = c` -> `a = c[0]; b = c[1]` -/// `a, b = c` -> `a = c[0]; b = c[1]` -/// `[a, *x, b] = c` -> `a = c[0]; x = c[1:-1]; b = c[-1]`. -/// Non-trivial right-hand expressions are first stored in a temporary variable. -/// @example -/// `a, b = c, d + foo()` -> `assign = (c, d + foo); a = assign[0]; b = assign[1]`. -/// Each assignment is unpacked recursively to allow cases like `a, (b, c) = d`. -void SimplifyVisitor::unpackAssignments(const ExprPtr &lhs, ExprPtr rhs, - std::vector &stmts) { - std::vector leftSide; - if (auto et = lhs->getTuple()) { - // Case: (a, b) = ... - for (auto &i : et->items) - leftSide.push_back(i); - } else if (auto el = lhs->getList()) { - // Case: [a, b] = ... - for (auto &i : el->items) - leftSide.push_back(i); - } else { - // Case: simple assignment (a = b, a.x = b, or a[x] = b) - stmts.push_back(transformAssignment(clone(lhs), clone(rhs))); - return; - } - - // Prepare the right-side expression - auto srcPos = rhs->getSrcInfo(); - if (!rhs->getId()) { - // Store any non-trivial right-side expression into a variable - auto var = ctx->cache->getTemporaryVar("assign"); - ExprPtr newRhs = N(srcPos, var); - stmts.push_back(transformAssignment(newRhs, clone(rhs))); - rhs = newRhs; - } - - // Process assignments until the fist StarExpr (if any) - size_t st = 0; - for (; st < leftSide.size(); st++) { - if (leftSide[st]->getStar()) - break; - // Transformation: `leftSide_st = rhs[st]` where `st` is static integer - auto rightSide = N(srcPos, clone(rhs), N(srcPos, st)); - // Recursively process the assignment because of cases like `(a, (b, c)) = d)` - unpackAssignments(leftSide[st], rightSide, stmts); - } - // Process StarExpr (if any) and the assignments that follow it - if (st < leftSide.size() && leftSide[st]->getStar()) { - // StarExpr becomes SliceExpr (e.g., `b` in `(a, *b, c) = d` becomes `d[1:-2]`) - auto rightSide = N( - srcPos, clone(rhs), - N(srcPos, N(srcPos, st), - // this slice is either [st:] or [st:-lhs_len + st + 1] - leftSide.size() == st + 1 - ? nullptr - : N(srcPos, -leftSide.size() + st + 1), - nullptr)); - unpackAssignments(leftSide[st]->getStar()->what, rightSide, stmts); - st += 1; - // Process remaining assignments. They will use negative indices (-1, -2 etc.) - // because we do not know how big is StarExpr - for (; st < leftSide.size(); st++) { - if (leftSide[st]->getStar()) - E(Error::ASSIGN_MULTI_STAR, leftSide[st]); - rightSide = N(srcPos, clone(rhs), - N(srcPos, -int(leftSide.size() - st))); - unpackAssignments(leftSide[st], rightSide, stmts); - } - } -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/basic.cpp b/codon/parser/visitors/simplify/basic.cpp deleted file mode 100644 index 32e45020..00000000 --- a/codon/parser/visitors/simplify/basic.cpp +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// See @c transformInt -void SimplifyVisitor::visit(IntExpr *expr) { resultExpr = transformInt(expr); } - -/// See @c transformFloat -void SimplifyVisitor::visit(FloatExpr *expr) { resultExpr = transformFloat(expr); } - -/// Concatenate string sequence (e.g., `"a" "b" "c"`) into a single string. -/// Parse f-strings and custom prefix strings. -/// Also see @c transformFString -void SimplifyVisitor::visit(StringExpr *expr) { - std::vector exprs; - std::vector concat; - for (auto &p : expr->strings) { - if (p.second == "f" || p.second == "F") { - /// Transform an F-string - exprs.push_back(transformFString(p.first)); - } else if (!p.second.empty()) { - /// Custom prefix strings: - /// call `str.__prefix_[prefix]__(str, [static length of str])` - exprs.push_back( - transform(N(N("str", format("__prefix_{}__", p.second)), - N(p.first), N(p.first.size())))); - } else { - exprs.push_back(N(p.first)); - concat.push_back(p.first); - } - } - if (concat.size() == expr->strings.size()) { - /// Simple case: statically concatenate a sequence of strings without any prefix - expr->strings = {{combine2(concat, ""), ""}}; - } else if (exprs.size() == 1) { - /// Simple case: only one string in a sequence - resultExpr = std::move(exprs[0]); - } else { - /// Complex case: call `str.cat(str1, ...)` - resultExpr = transform(N(N("str", "cat"), exprs)); - } -} - -/**************************************************************************************/ - -/// Parse various integer representations depending on the integer suffix. -/// @example -/// `123u` -> `UInt[64](123)` -/// `123i56` -> `Int[56](123)` -/// `123pf` -> `int.__suffix_pf__(123)` -ExprPtr SimplifyVisitor::transformInt(IntExpr *expr) { - if (!expr->intValue) { - /// TODO: currently assumes that ints are always 64-bit. - /// Should use str constructors if available for ints with a suffix instead. - E(Error::INT_RANGE, expr, expr->value); - } - - /// Handle fixed-width integers: suffixValue is a pointer to NN if the suffix - /// is `uNNN` or `iNNN`. - std::unique_ptr suffixValue = nullptr; - if (expr->suffix.size() > 1 && (expr->suffix[0] == 'u' || expr->suffix[0] == 'i') && - isdigit(expr->suffix.substr(1))) { - try { - suffixValue = std::make_unique(std::stoi(expr->suffix.substr(1))); - } catch (...) { - } - if (suffixValue && *suffixValue > MAX_INT_WIDTH) - suffixValue = nullptr; - } - - if (expr->suffix.empty()) { - // A normal integer (int64_t) - return N(*(expr->intValue)); - } else if (expr->suffix == "u") { - // Unsigned integer: call `UInt[64](value)` - return transform(N(N(N("UInt"), N(64)), - N(*(expr->intValue)))); - } else if (suffixValue) { - // Fixed-width numbers (with `uNNN` and `iNNN` suffixes): - // call `UInt[NNN](value)` or `Int[NNN](value)` - return transform( - N(N(N(expr->suffix[0] == 'u' ? "UInt" : "Int"), - N(*suffixValue)), - N(*(expr->intValue)))); - } else { - // Custom suffix: call `int.__suffix_[suffix]__(value)` - return transform( - N(N("int", format("__suffix_{}__", expr->suffix)), - N(*(expr->intValue)))); - } -} - -/// Parse various float representations depending on the suffix. -/// @example -/// `123.4pf` -> `float.__suffix_pf__(123.4)` -ExprPtr SimplifyVisitor::transformFloat(FloatExpr *expr) { - if (!expr->floatValue) { - /// TODO: currently assumes that floats are always 64-bit. - /// Should use str constructors if available for floats with suffix instead. - E(Error::FLOAT_RANGE, expr, expr->value); - } - - if (expr->suffix.empty()) { - /// A normal float (double) - return N(*(expr->floatValue)); - } else { - // Custom suffix: call `float.__suffix_[suffix]__(value)` - return transform( - N(N("float", format("__suffix_{}__", expr->suffix)), - N(*(expr->floatValue)))); - } -} - -/// Parse a Python-like f-string into a concatenation: -/// `f"foo {x+1} bar"` -> `str.cat("foo ", str(x+1), " bar")` -/// Supports "{x=}" specifier (that prints the raw expression as well): -/// `f"{x+1=}"` -> `str.cat("x+1=", str(x+1))` -ExprPtr SimplifyVisitor::transformFString(const std::string &value) { - // Strings to be concatenated - std::vector items; - int braceCount = 0, braceStart = 0; - for (int i = 0; i < value.size(); i++) { - if (value[i] == '{') { - if (braceStart < i) - items.push_back(N(value.substr(braceStart, i - braceStart))); - if (!braceCount) - braceStart = i + 1; - braceCount++; - } else if (value[i] == '}') { - braceCount--; - if (!braceCount) { - std::string code = value.substr(braceStart, i - braceStart); - auto offset = getSrcInfo(); - offset.col += i; - if (!code.empty() && code.back() == '=') { - // Special case: f"{x=}" - code = code.substr(0, code.size() - 1); - items.push_back(N(fmt::format("{}=", code))); - } - auto [expr, format] = parseExpr(ctx->cache, code, offset); - if (!format.empty()) { - items.push_back( - N(N(expr, "__format__"), N(format))); - } else { - // Every expression is wrapped within `str` - items.push_back(N(N("str"), expr)); - } - } - braceStart = i + 1; - } - } - if (braceCount > 0) - E(Error::STR_FSTRING_BALANCE_EXTRA, getSrcInfo()); - if (braceCount < 0) - E(Error::STR_FSTRING_BALANCE_MISSING, getSrcInfo()); - if (braceStart != value.size()) - items.push_back(N(value.substr(braceStart, value.size() - braceStart))); - return transform(N(N("str", "cat"), items)); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/call.cpp b/codon/parser/visitors/simplify/call.cpp deleted file mode 100644 index 36c25d71..00000000 --- a/codon/parser/visitors/simplify/call.cpp +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Transform print statement. -/// @example -/// `print a, b` -> `print(a, b)` -/// `print a, b,` -> `print(a, b, end=' ')` -void SimplifyVisitor::visit(PrintStmt *stmt) { - std::vector args; - for (auto &i : stmt->items) - args.emplace_back("", transform(i)); - if (stmt->isInline) - args.emplace_back("end", N(" ")); - resultStmt = N(N(transform(N("print")), args)); -} - -/// Transform calls. The real stuff happens during the type checking. -/// Here just perform some sanity checks and transform some special calls -/// (see @c transformSpecialCall for details). -void SimplifyVisitor::visit(CallExpr *expr) { - transform(expr->expr, true); - if ((resultExpr = transformSpecialCall(expr->expr, expr->args))) - return; - - for (auto &i : expr->args) { - if (auto el = i.value->getEllipsis()) { - if (&(i) == &(expr->args.back()) && i.name.empty()) - el->mode = EllipsisExpr::PARTIAL; - } - transform(i.value, true); - } -} - -/// Simplify the following special call expressions: -/// `tuple(i for i in tup)` (tuple generators) -/// `std.collections.namedtuple` (sugar for @tuple class) -/// `std.functools.partial` (sugar for partial calls) -/// Check validity of `type()` call. See below for more details. -ExprPtr SimplifyVisitor::transformSpecialCall(const ExprPtr &callee, - const std::vector &args) { - if (callee->isId("tuple") && args.size() == 1 && - CAST(args.front().value, GeneratorExpr)) { - // tuple(i for i in j) - return transformTupleGenerator(args); - } else if (callee->isId("type") && !ctx->allowTypeOf) { - // type(i) - E(Error::CALL_NO_TYPE, getSrcInfo()); - } else if (callee->isId("std.collections.namedtuple")) { - // namedtuple('Foo', ['x', 'y']) - return transformNamedTuple(args); - } else if (callee->isId("std.functools.partial")) { - // partial(foo, a=5) - return transformFunctoolsPartial(args); - } - return nullptr; -} - -/// Transform `tuple(i for i in tup)` into a GeneratorExpr that will be handled during -/// the type checking. -ExprPtr -SimplifyVisitor::transformTupleGenerator(const std::vector &args) { - GeneratorExpr *g = nullptr; - // We currently allow only a simple iterations over tuples - if (args.size() != 1 || !(g = CAST(args[0].value, GeneratorExpr)) || - g->kind != GeneratorExpr::Generator || g->loops.size() != 1 || - !g->loops[0].conds.empty()) - E(Error::CALL_TUPLE_COMPREHENSION, args[0].value); - auto var = clone(g->loops[0].vars); - auto ex = clone(g->expr); - - ctx->enterConditionalBlock(); - ctx->getBase()->loops.push_back({"", ctx->scope.blocks, {}}); - if (auto i = var->getId()) { - ctx->addVar(i->value, ctx->generateCanonicalName(i->value), var->getSrcInfo()); - var = transform(var); - ex = transform(ex); - } else { - std::string varName = ctx->cache->getTemporaryVar("for"); - ctx->addVar(varName, varName, var->getSrcInfo()); - var = N(varName); - auto head = transform(N(clone(g->loops[0].vars), clone(var))); - ex = N(head, transform(ex)); - } - ctx->leaveConditionalBlock(); - // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) - ctx->findDominatingBinding(var); - ctx->getBase()->loops.pop_back(); - return N( - GeneratorExpr::Generator, ex, - std::vector{{var, transform(g->loops[0].gen), {}}}); -} - -/// Transform named tuples. -/// @example -/// `namedtuple("NT", ["a", ("b", int)])` -> ```@tuple -/// class NT[T1]: -/// a: T1 -/// b: int``` -ExprPtr SimplifyVisitor::transformNamedTuple(const std::vector &args) { - // Ensure that namedtuple call is valid - if (args.size() != 2 || !args[0].value->getString() || !args[1].value->getList()) - E(Error::CALL_NAMEDTUPLE, getSrcInfo()); - - // Construct the class statement - std::vector generics, params; - int ti = 1; - for (auto &i : args[1].value->getList()->items) { - if (auto s = i->getString()) { - generics.emplace_back(Param{format("T{}", ti), N("type"), nullptr, true}); - params.emplace_back( - Param{s->getValue(), N(format("T{}", ti++)), nullptr}); - } else if (i->getTuple() && i->getTuple()->items.size() == 2 && - i->getTuple()->items[0]->getString()) { - params.emplace_back(Param{i->getTuple()->items[0]->getString()->getValue(), - transformType(i->getTuple()->items[1]), nullptr}); - } else { - E(Error::CALL_NAMEDTUPLE, i); - } - } - for (auto &g : generics) - params.push_back(g); - auto name = args[0].value->getString()->getValue(); - prependStmts->push_back(transform( - N(name, params, nullptr, std::vector{N("tuple")}))); - return transformType(N(name)); -} - -/// Transform partial calls (Python syntax). -/// @example -/// `partial(foo, 1, a=2)` -> `foo(1, a=2, ...)` -ExprPtr SimplifyVisitor::transformFunctoolsPartial(std::vector args) { - if (args.empty()) - E(Error::CALL_PARTIAL, getSrcInfo()); - auto name = clone(args[0].value); - args.erase(args.begin()); - args.emplace_back("", N(EllipsisExpr::PARTIAL)); - return transform(N(name, args)); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/class.cpp b/codon/parser/visitors/simplify/class.cpp deleted file mode 100644 index 8865e150..00000000 --- a/codon/parser/visitors/simplify/class.cpp +++ /dev/null @@ -1,676 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Transform class and type definitions, as well as extensions. -/// See below for details. -void SimplifyVisitor::visit(ClassStmt *stmt) { - // Get root name - std::string name = stmt->name; - - // Generate/find class' canonical name (unique ID) and AST - std::string canonicalName; - std::vector &argsToParse = stmt->args; - - // classItem will be added later when the scope is different - auto classItem = std::make_shared(SimplifyItem::Type, "", "", - ctx->getModule(), ctx->scope.blocks); - classItem->setSrcInfo(stmt->getSrcInfo()); - if (!stmt->attributes.has(Attr::Extend)) { - classItem->canonicalName = canonicalName = - ctx->generateCanonicalName(name, !stmt->attributes.has(Attr::Internal)); - // Reference types are added to the context here. - // Tuple types are added after class contents are parsed to prevent - // recursive record types (note: these are allowed for reference types) - if (!stmt->attributes.has(Attr::Tuple)) { - ctx->add(name, classItem); - ctx->addAlwaysVisible(classItem); - } - } else { - // Find the canonical name and AST of the class that is to be extended - if (!ctx->isGlobal() || ctx->isConditional()) - E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "class extension"); - auto val = ctx->find(name); - if (!val || !val->isType()) - E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); - canonicalName = val->canonicalName; - const auto &astIter = ctx->cache->classes.find(canonicalName); - if (astIter == ctx->cache->classes.end()) { - E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); - } else { - argsToParse = astIter->second.ast->args; - } - } - - std::vector clsStmts; // Will be filled later! - std::vector varStmts; // Will be filled later! - std::vector fnStmts; // Will be filled later! - std::vector addLater; - try { - // Add the class base - SimplifyContext::BaseGuard br(ctx.get(), canonicalName); - - // Parse and add class generics - std::vector args; - std::pair autoDeducedInit{nullptr, nullptr}; - if (stmt->attributes.has("deduce") && args.empty()) { - // Auto-detect generics and fields - autoDeducedInit = autoDeduceMembers(stmt, args); - } else { - // Add all generics before parent classes, fields and methods - for (auto &a : argsToParse) { - if (a.status != Param::Generic) - continue; - std::string genName, varName; - if (stmt->attributes.has(Attr::Extend)) - varName = a.name, genName = ctx->cache->rev(a.name); - else - varName = ctx->generateCanonicalName(a.name), genName = a.name; - if (auto st = getStaticGeneric(a.type.get())) { - auto val = ctx->addVar(genName, varName, a.type->getSrcInfo()); - val->generic = true; - val->staticType = st; - } else { - ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true; - } - args.emplace_back(varName, transformType(clone(a.type), false), - transformType(clone(a.defaultValue), false), a.status); - if (!stmt->attributes.has(Attr::Extend) && a.status == Param::Normal) - ctx->cache->classes[canonicalName].fields.push_back( - Cache::Class::ClassField{varName, nullptr, canonicalName}); - } - } - - // Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes) - ExprPtr typeAst = N(name), transformedTypeAst = NT(canonicalName); - for (auto &a : args) { - if (a.status == Param::Generic) { - if (!typeAst->getIndex()) { - typeAst = N(N(name), N()); - transformedTypeAst = - NT(NT(canonicalName), std::vector{}); - } - typeAst->getIndex()->index->getTuple()->items.push_back(N(a.name)); - CAST(transformedTypeAst, InstantiateExpr) - ->typeParams.push_back(transform(N(a.name), true)); - } - } - - // Collect classes (and their fields) that are to be statically inherited - std::vector staticBaseASTs, baseASTs; - if (!stmt->attributes.has(Attr::Extend)) { - staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt->attributes, - canonicalName); - if (ctx->cache->isJit && !stmt->baseClasses.empty()) - E(Error::CUSTOM, stmt->baseClasses[0], - "inheritance is not yet supported in JIT mode"); - parseBaseClasses(stmt->baseClasses, args, stmt->attributes, canonicalName, - transformedTypeAst); - } - - // A ClassStmt will be separated into class variable assignments, method-free - // ClassStmts (that include nested classes) and method FunctionStmts - transformNestedClasses(stmt, clsStmts, varStmts, fnStmts); - - // Collect class fields - for (auto &a : argsToParse) { - if (a.status == Param::Normal) { - if (!ClassStmt::isClassVar(a)) { - args.emplace_back(a.name, transformType(clone(a.type), false), - transform(clone(a.defaultValue), true)); - if (!stmt->attributes.has(Attr::Extend)) { - ctx->cache->classes[canonicalName].fields.push_back( - Cache::Class::ClassField{a.name, nullptr, canonicalName}); - } - } else if (!stmt->attributes.has(Attr::Extend)) { - // Handle class variables. Transform them later to allow self-references - auto name = format("{}.{}", canonicalName, a.name); - preamble->push_back(N(N(name), nullptr, nullptr)); - ctx->cache->addGlobal(name); - auto assign = N(N(name), a.defaultValue, - a.type ? a.type->getIndex()->index : nullptr); - assign->setUpdate(); - varStmts.push_back(assign); - ctx->cache->classes[canonicalName].classVars[a.name] = name; - } - } - } - - // ASTs for member arguments to be used for populating magic methods - std::vector memberArgs; - for (auto &a : args) { - if (a.status == Param::Normal) - memberArgs.push_back(a.clone()); - } - - // Parse class members (arguments) and methods - if (!stmt->attributes.has(Attr::Extend)) { - // Now that we are done with arguments, add record type to the context - if (stmt->attributes.has(Attr::Tuple)) { - // Ensure that class binding does not shadow anything. - // Class bindings cannot be dominated either - auto v = ctx->find(name); - if (v && v->noShadow) - E(Error::CLASS_INVALID_BIND, stmt, name); - ctx->add(name, classItem); - ctx->addAlwaysVisible(classItem); - } - // Create a cached AST. - stmt->attributes.module = - format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", - ctx->moduleName.module); - ctx->cache->classes[canonicalName].ast = - N(canonicalName, args, N(), stmt->attributes); - ctx->cache->classes[canonicalName].ast->baseClasses = stmt->baseClasses; - for (auto &b : staticBaseASTs) - ctx->cache->classes[canonicalName].staticParentClasses.emplace_back(b->name); - ctx->cache->classes[canonicalName].ast->validate(); - ctx->cache->classes[canonicalName].module = ctx->getModule(); - - // Codegen default magic methods - for (auto &m : stmt->attributes.magics) { - fnStmts.push_back(transform( - codegenMagic(m, typeAst, memberArgs, stmt->attributes.has(Attr::Tuple)))); - } - // Add inherited methods - for (auto &base : staticBaseASTs) { - for (auto &mm : ctx->cache->classes[base->name].methods) - for (auto &mf : ctx->cache->overloads[mm.second]) { - auto f = ctx->cache->functions[mf.name].ast; - if (!f->attributes.has("autogenerated")) { - std::string rootName; - auto &mts = ctx->cache->classes[ctx->getBase()->name].methods; - auto it = mts.find(ctx->cache->rev(f->name)); - if (it != mts.end()) - rootName = it->second; - else - rootName = ctx->generateCanonicalName(ctx->cache->rev(f->name), true); - auto newCanonicalName = - format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); - ctx->cache->overloads[rootName].push_back( - {newCanonicalName, ctx->cache->age}); - ctx->cache->reverseIdentifierLookup[newCanonicalName] = - ctx->cache->rev(f->name); - auto nf = std::dynamic_pointer_cast(f->clone()); - nf->name = newCanonicalName; - nf->attributes.parentClass = ctx->getBase()->name; - ctx->cache->functions[newCanonicalName].ast = nf; - ctx->cache->classes[ctx->getBase()->name] - .methods[ctx->cache->rev(f->name)] = rootName; - fnStmts.push_back(nf); - } - } - } - // Add auto-deduced __init__ (if available) - if (autoDeducedInit.first) - fnStmts.push_back(autoDeducedInit.first); - } - // Add class methods - for (const auto &sp : getClassMethods(stmt->suite)) - if (sp && sp->getFunction()) { - if (sp.get() != autoDeducedInit.second) { - auto &ds = sp->getFunction()->decorators; - for (auto &dc : ds) { - if (auto d = dc->getDot()) { - if (d->member == "setter" and d->expr->isId(sp->getFunction()->name) && - sp->getFunction()->args.size() == 2) { - sp->getFunction()->name = format(".set_{}", sp->getFunction()->name); - dc = nullptr; - break; - } - } - } - fnStmts.push_back(transform(sp)); - } - } - - // After popping context block, record types and nested classes will disappear. - // Store their references and re-add them to the context after popping - addLater.reserve(clsStmts.size() + 1); - for (auto &c : clsStmts) - addLater.push_back(ctx->find(c->getClass()->name)); - if (stmt->attributes.has(Attr::Tuple)) - addLater.push_back(ctx->forceFind(name)); - - // Mark functions as virtual: - auto banned = - std::set{"__init__", "__new__", "__raw__", "__tuplesize__"}; - for (auto &m : ctx->cache->classes[canonicalName].methods) { - auto method = m.first; - for (size_t mi = 1; mi < ctx->cache->classes[canonicalName].mro.size(); mi++) { - // ... in the current class - auto b = ctx->cache->classes[canonicalName].mro[mi]->getTypeName(); - if (in(ctx->cache->classes[b].methods, method) && !in(banned, method)) { - ctx->cache->classes[canonicalName].virtuals.insert(method); - } - } - for (auto &v : ctx->cache->classes[canonicalName].virtuals) { - for (size_t mi = 1; mi < ctx->cache->classes[canonicalName].mro.size(); mi++) { - // ... and in parent classes - auto b = ctx->cache->classes[canonicalName].mro[mi]->getTypeName(); - ctx->cache->classes[b].virtuals.insert(v); - } - } - } - } catch (const exc::ParserException &) { - if (!stmt->attributes.has(Attr::Tuple)) - ctx->remove(name); - ctx->cache->classes.erase(name); - throw; - } - for (auto &i : addLater) - ctx->add(ctx->cache->rev(i->canonicalName), i); - - // Extensions are not needed as the cache is already populated - if (!stmt->attributes.has(Attr::Extend)) { - auto c = ctx->cache->classes[canonicalName].ast; - seqassert(c, "not a class AST for {}", canonicalName); - clsStmts.push_back(c); - } - - clsStmts.insert(clsStmts.end(), fnStmts.begin(), fnStmts.end()); - for (auto &a : varStmts) { - // Transform class variables here to allow self-references - if (auto assign = a->getAssign()) { - transform(assign->rhs); - transformType(assign->type); - } - clsStmts.push_back(a); - } - resultStmt = N(clsStmts); -} - -/// Parse statically inherited classes. -/// Returns a list of their ASTs. Also updates the class fields. -/// @param args Class fields that are to be updated with base classes' fields. -/// @param typeAst Transformed AST for base class type (e.g., `A[T]`). -/// Only set when dealing with dynamic polymorphism. -std::vector SimplifyVisitor::parseBaseClasses( - std::vector &baseClasses, std::vector &args, const Attr &attr, - const std::string &canonicalName, const ExprPtr &typeAst) { - std::vector asts; - - // MAJOR TODO: fix MRO it to work with generic classes (maybe replacements? IDK...) - std::vector> mro{{typeAst}}; - std::vector parentClasses; - for (auto &cls : baseClasses) { - std::string name; - std::vector subs; - - // Get the base class and generic replacements (e.g., if there is Bar[T], - // Bar in Foo(Bar[int]) will have `T = int`) - transformType(cls); - if (auto i = cls->getId()) { - name = i->value; - } else if (auto e = CAST(cls, InstantiateExpr)) { - if (auto ei = e->typeExpr->getId()) { - name = ei->value; - subs = e->typeParams; - } - } - - auto cachedCls = const_cast(in(ctx->cache->classes, name)); - if (!cachedCls) - E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), ctx->cache->rev(name)); - asts.push_back(cachedCls->ast.get()); - parentClasses.push_back(clone(cls)); - mro.push_back(cachedCls->mro); - - // Sanity checks - if (attr.has(Attr::Tuple) && typeAst) - E(Error::CLASS_NO_INHERIT, getSrcInfo(), "tuple"); - if (!attr.has(Attr::Tuple) && asts.back()->attributes.has(Attr::Tuple)) - E(Error::CLASS_TUPLE_INHERIT, getSrcInfo()); - if (asts.back()->attributes.has(Attr::Internal)) - E(Error::CLASS_NO_INHERIT, getSrcInfo(), "internal"); - - // Mark parent classes as polymorphic as well. - if (typeAst) { - cachedCls->rtti = true; - } - - // Add generics first - int nGenerics = 0; - for (auto &a : asts.back()->args) - nGenerics += a.status == Param::Generic; - int si = 0; - for (auto &a : asts.back()->args) { - if (a.status == Param::Generic) { - if (si == subs.size()) - E(Error::GENERICS_MISMATCH, cls, ctx->cache->rev(asts.back()->name), - nGenerics, subs.size()); - args.emplace_back(a.name, a.type, transformType(subs[si++], false), - Param::HiddenGeneric); - } else if (a.status == Param::HiddenGeneric) { - args.emplace_back(a); - } - if (a.status != Param::Normal) { - if (auto st = getStaticGeneric(a.type.get())) { - auto val = ctx->addVar(a.name, a.name, a.type->getSrcInfo()); - val->generic = true; - val->staticType = st; - } else { - ctx->addType(a.name, a.name, a.type->getSrcInfo())->generic = true; - } - } - } - if (si != subs.size()) - E(Error::GENERICS_MISMATCH, cls, ctx->cache->rev(asts.back()->name), nGenerics, - subs.size()); - } - // Add normal fields - for (auto &ast : asts) { - int ai = 0; - for (auto &a : ast->args) { - if (a.status == Param::Normal && !ClassStmt::isClassVar(a)) { - auto name = a.name; - int i = 0; - for (auto &aa : args) - i += aa.name == a.name || startswith(aa.name, a.name + "#"); - if (i) - name = format("{}#{}", name, i); - seqassert(ctx->cache->classes[ast->name].fields[ai].name == a.name, - "bad class fields: {} vs {}", - ctx->cache->classes[ast->name].fields[ai].name, a.name); - args.emplace_back(name, a.type, a.defaultValue); - ctx->cache->classes[canonicalName].fields.push_back(Cache::Class::ClassField{ - name, nullptr, ctx->cache->classes[ast->name].fields[ai].baseClass}); - ai++; - } - } - } - if (typeAst) { - if (!parentClasses.empty()) { - mro.push_back(parentClasses); - ctx->cache->classes[canonicalName].rtti = true; - } - ctx->cache->classes[canonicalName].mro = Cache::mergeC3(mro); - if (ctx->cache->classes[canonicalName].mro.empty()) { - E(Error::CLASS_BAD_MRO, getSrcInfo()); - } else if (ctx->cache->classes[canonicalName].mro.size() > 1) { - // LOG("[mro] {} -> {}", canonicalName, ctx->cache->classes[canonicalName].mro); - } - } - return asts; -} - -/// Find the first __init__ with self parameter and use it to deduce class members. -/// Each deduced member will be treated as generic. -/// @example -/// ```@deduce -/// class Foo: -/// def __init__(self): -/// self.x, self.y = 1, 2``` -/// will result in -/// ```class Foo[T1, T2]: -/// x: T1 -/// y: T2``` -/// @return the transformed init and the pointer to the original function. -std::pair -SimplifyVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector &args) { - std::pair init{nullptr, nullptr}; - for (const auto &sp : getClassMethods(stmt->suite)) - if (sp && sp->getFunction()) { - auto f = sp->getFunction(); - if (f->name == "__init__" && !f->args.empty() && f->args[0].name == "self") { - // Set up deducedMembers that will be populated during AssignStmt evaluation - ctx->getBase()->deducedMembers = std::make_shared>(); - auto transformed = transform(sp); - transformed->getFunction()->attributes.set(Attr::RealizeWithoutSelf); - ctx->cache->functions[transformed->getFunction()->name].ast->attributes.set( - Attr::RealizeWithoutSelf); - int i = 0; - // Once done, add arguments - for (auto &m : *(ctx->getBase()->deducedMembers)) { - auto varName = ctx->generateCanonicalName(format("T{}", ++i)); - auto memberName = ctx->cache->rev(varName); - ctx->addType(memberName, varName, stmt->getSrcInfo())->generic = true; - args.emplace_back(varName, N("type"), nullptr, Param::Generic); - args.emplace_back(m, N(varName)); - ctx->cache->classes[stmt->name].fields.push_back( - Cache::Class::ClassField{m, nullptr, stmt->name}); - } - ctx->getBase()->deducedMembers = nullptr; - return {transformed, f}; - } - } - return {nullptr, nullptr}; -} - -/// Return a list of all statements within a given class suite. -/// Checks each suite recursively, and assumes that each statement is either -/// a function, a class or a docstring. -std::vector SimplifyVisitor::getClassMethods(const StmtPtr &s) { - std::vector v; - if (!s) - return v; - if (auto sp = s->getSuite()) { - for (const auto &ss : sp->stmts) - for (const auto &u : getClassMethods(ss)) - v.push_back(u); - } else if (s->getExpr() && s->getExpr()->expr->getString()) { - /// Those are doc-strings, ignore them. - } else if (!s->getFunction() && !s->getClass()) { - E(Error::CLASS_BAD_ATTR, s); - } else { - v.push_back(s); - } - return v; -} - -/// Extract nested classes and transform them before the main class. -void SimplifyVisitor::transformNestedClasses(ClassStmt *stmt, - std::vector &clsStmts, - std::vector &varStmts, - std::vector &fnStmts) { - for (const auto &sp : getClassMethods(stmt->suite)) - if (sp && sp->getClass()) { - auto origName = sp->getClass()->name; - // If class B is nested within A, it's name is always A.B, never B itself. - // Ensure that parent class name is appended - auto parentName = stmt->name; - sp->getClass()->name = fmt::format("{}.{}", parentName, origName); - auto tsp = transform(sp); - std::string name; - if (tsp->getSuite()) { - for (auto &s : tsp->getSuite()->stmts) - if (auto c = s->getClass()) { - clsStmts.push_back(s); - name = c->name; - } else if (auto a = s->getAssign()) { - varStmts.push_back(s); - } else { - fnStmts.push_back(s); - } - ctx->add(origName, ctx->forceFind(name)); - } - } -} - -/// Generate a magic method `__op__` for each magic `op` -/// described by @param typExpr and its arguments. -/// Currently generate: -/// @li Constructors: __new__, __init__ -/// @li Utilities: __raw__, __hash__, __repr__, __tuplesize__, __add__, __mul__, __len__ -/// @li Iteration: __iter__, __getitem__, __len__, __contains__ -/// @li Comparisons: __eq__, __ne__, __lt__, __le__, __gt__, __ge__ -/// @li Pickling: __pickle__, __unpickle__ -/// @li Python: __to_py__, __from_py__ -/// @li GPU: __to_gpu__, __from_gpu__, __from_gpu_new__ -/// TODO: move to Codon as much as possible -StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typExpr, - const std::vector &allArgs, - bool isRecord) { -#define I(s) N(s) -#define NS(x) N(N("__magic__"), (x)) - seqassert(typExpr, "typExpr is null"); - ExprPtr ret; - std::vector fargs; - std::vector stmts; - Attr attr; - attr.set("autogenerated"); - - std::vector args; - for (auto &a : allArgs) - args.push_back(a); - - if (op == "new") { - ret = typExpr->clone(); - if (isRecord) { - // Tuples: def __new__() -> T (internal) - for (auto &a : args) - fargs.emplace_back(a.name, clone(a.type), - a.defaultValue ? clone(a.defaultValue) - : N(clone(a.type))); - attr.set(Attr::Internal); - } else { - // Classes: def __new__() -> T - stmts.emplace_back(N(N(NS(op), typExpr->clone()))); - } - } - // else if (startswith(op, "new.")) { - // // special handle for tuple[t1, t2, ...] - // int sz = atoi(op.substr(4).c_str()); - // std::vector ts; - // for (int i = 0; i < sz; i++) { - // fargs.emplace_back(format("a{}", i + 1), I(format("T{}", i + 1))); - // ts.emplace_back(I(format("T{}", i + 1))); - // } - // for (int i = 0; i < sz; i++) { - // fargs.emplace_back(format("T{}", i + 1), I("type")); - // } - // ret = N(I(TYPE_TUPLE), ts); - // ret->markType(); - // attr.set(Attr::Internal); - // } - else if (op == "init") { - // Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> None: - // self.aI = aI ... - ret = I("NoneType"); - fargs.emplace_back("self", typExpr->clone()); - for (auto &a : args) { - stmts.push_back(N(N(I("self"), a.name), I(a.name))); - fargs.emplace_back(a.name, clone(a.type), - a.defaultValue ? clone(a.defaultValue) - : N(clone(a.type))); - } - } else if (op == "raw") { - // Classes: def __raw__(self: T) - fargs.emplace_back("self", typExpr->clone()); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "tuplesize") { - // def __tuplesize__() -> int - ret = I("int"); - stmts.emplace_back(N(N(NS(op)))); - } else if (op == "getitem") { - // Tuples: def __getitem__(self: T, index: int) - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("index", I("int")); - stmts.emplace_back(N(N(NS(op), I("self"), I("index")))); - } else if (op == "iter") { - // Tuples: def __iter__(self: T) - fargs.emplace_back("self", typExpr->clone()); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "contains") { - // Tuples: def __contains__(self: T, what) -> bool - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("what", nullptr); - ret = I("bool"); - stmts.emplace_back(N(N(NS(op), I("self"), I("what")))); - } else if (op == "eq" || op == "ne" || op == "lt" || op == "le" || op == "gt" || - op == "ge") { - // def __op__(self: T, obj: T) -> bool - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("obj", typExpr->clone()); - ret = I("bool"); - stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); - } else if (op == "hash") { - // def __hash__(self: T) -> int - fargs.emplace_back("self", typExpr->clone()); - ret = I("int"); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "pickle") { - // def __pickle__(self: T, dest: Ptr[byte]) - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("dest", N(I("Ptr"), I("byte"))); - stmts.emplace_back(N(N(NS(op), I("self"), I("dest")))); - } else if (op == "unpickle") { - // def __unpickle__(src: Ptr[byte]) -> T - fargs.emplace_back("src", N(I("Ptr"), I("byte"))); - ret = typExpr->clone(); - stmts.emplace_back(N(N(NS(op), I("src"), typExpr->clone()))); - } else if (op == "len") { - // def __len__(self: T) -> int - fargs.emplace_back("self", typExpr->clone()); - ret = I("int"); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "to_py") { - // def __to_py__(self: T) -> Ptr[byte] - fargs.emplace_back("self", typExpr->clone()); - ret = N(I("Ptr"), I("byte")); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "from_py") { - // def __from_py__(src: Ptr[byte]) -> T - fargs.emplace_back("src", N(I("Ptr"), I("byte"))); - ret = typExpr->clone(); - stmts.emplace_back(N(N(NS(op), I("src"), typExpr->clone()))); - } else if (op == "to_gpu") { - // def __to_gpu__(self: T, cache) -> T - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("cache"); - ret = typExpr->clone(); - stmts.emplace_back(N(N(NS(op), I("self"), I("cache")))); - } else if (op == "from_gpu") { - // def __from_gpu__(self: T, other: T) - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("other", typExpr->clone()); - stmts.emplace_back(N(N(NS(op), I("self"), I("other")))); - } else if (op == "from_gpu_new") { - // def __from_gpu_new__(other: T) -> T - fargs.emplace_back("other", typExpr->clone()); - ret = typExpr->clone(); - stmts.emplace_back(N(N(NS(op), I("other")))); - } else if (op == "repr") { - // def __repr__(self: T) -> str - fargs.emplace_back("self", typExpr->clone()); - ret = I("str"); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "dict") { - // def __dict__(self: T) - fargs.emplace_back("self", typExpr->clone()); - stmts.emplace_back(N(N(NS(op), I("self")))); - } else if (op == "add") { - // def __add__(self, obj) - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("obj", nullptr); - stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); - } else if (op == "mul") { - // def __mul__(self, i: Static[int]) - fargs.emplace_back("self", typExpr->clone()); - fargs.emplace_back("i", N(I("Static"), I("int"))); - stmts.emplace_back(N(N(NS(op), I("self"), I("i")))); - } else { - seqassert(false, "invalid magic {}", op); - } -#undef I -#undef NS - auto t = std::make_shared(format("__{}__", op), ret, fargs, - N(stmts), attr); - t->setSrcInfo(ctx->cache->generateSrcInfo()); - return t; -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/collections.cpp b/codon/parser/visitors/simplify/collections.cpp deleted file mode 100644 index 27824640..00000000 --- a/codon/parser/visitors/simplify/collections.cpp +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; - -namespace codon::ast { - -/// Simple transformation. -/// The rest will be handled during the type-checking stage. -void SimplifyVisitor::visit(TupleExpr *expr) { - for (auto &i : expr->items) - transform(i, true); // types needed for some constructs (e.g., isinstance) -} - -/// Simple transformation. -/// The rest will be handled during the type-checking stage. -void SimplifyVisitor::visit(ListExpr *expr) { - for (auto &i : expr->items) - transform(i); -} - -/// Simple transformation. -/// The rest will be handled during the type-checking stage. -void SimplifyVisitor::visit(SetExpr *expr) { - for (auto &i : expr->items) - transform(i); -} - -/// Simple transformation. -/// The rest will be handled during the type-checking stage. -void SimplifyVisitor::visit(DictExpr *expr) { - for (auto &i : expr->items) - transform(i); -} - -/// Transform a collection comprehension to the corresponding statement expression. -/// @example (lists and sets): -/// `[i+a for i in j if a]` -> ```gen = List() -/// for i in j: if a: gen.append(i+a)``` -/// Generators are transformed to lambda calls. -/// @example -/// `(i+a for i in j if a)` -> ```def _lambda(j, a): -/// for i in j: yield i+a -/// _lambda(j, a).__iter__()``` -void SimplifyVisitor::visit(GeneratorExpr *expr) { - std::vector stmts; - - auto loops = clone_nop(expr->loops); // Clone as loops will be modified - - // List comprehension optimization: - // Use `iter.__len__()` when creating list if there is a single for loop - // without any if conditions in the comprehension - bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 && - loops[0].conds.empty(); - if (canOptimize) { - auto iter = transform(clone(loops[0].gen)); - IdExpr *id; - if (iter->getCall() && (id = iter->getCall()->expr->getId())) { - // Turn off this optimization for static items - canOptimize &= !startswith(id->value, "std.internal.types.range.staticrange"); - canOptimize &= !startswith(id->value, "statictuple"); - } - } - - SuiteStmt *prev = nullptr; - auto avoidDomination = true; - std::swap(avoidDomination, ctx->avoidDomination); - auto suite = transformGeneratorBody(loops, prev); - ExprPtr var = N(ctx->cache->getTemporaryVar("gen")); - if (expr->kind == GeneratorExpr::ListGenerator) { - // List comprehensions - std::vector args; - prev->stmts.push_back( - N(N(N(clone(var), "append"), clone(expr->expr)))); - auto noOptStmt = - N(N(clone(var), N(N("List"))), suite); - if (canOptimize) { - seqassert(suite->getSuite() && !suite->getSuite()->stmts.empty() && - CAST(suite->getSuite()->stmts[0], ForStmt), - "bad comprehension transformation"); - auto optimizeVar = ctx->cache->getTemporaryVar("i"); - auto optSuite = clone(suite); - CAST(optSuite->getSuite()->stmts[0], ForStmt)->iter = N(optimizeVar); - - auto optStmt = N( - N(N(optimizeVar), clone(expr->loops[0].gen)), - N( - clone(var), - N(N("List"), - N(N(N(optimizeVar), "__len__")))), - optSuite); - resultExpr = transform( - N(N(N("hasattr"), clone(expr->loops[0].gen), - N("__len__")), - N(optStmt, clone(var)), N(noOptStmt, var))); - } else { - resultExpr = transform(N(noOptStmt, var)); - } - } else if (expr->kind == GeneratorExpr::SetGenerator) { - // Set comprehensions - stmts.push_back( - transform(N(clone(var), N(N("Set"))))); - prev->stmts.push_back( - N(N(N(clone(var), "add"), clone(expr->expr)))); - stmts.push_back(transform(suite)); - resultExpr = N(stmts, transform(var)); - } else { - // Generators: converted to lambda functions that yield the target expression - prev->stmts.push_back(N(clone(expr->expr))); - stmts.push_back(suite); - - auto anon = makeAnonFn(stmts); - if (auto call = anon->getCall()) { - seqassert(!call->args.empty() && call->args.back().value->getEllipsis(), - "bad lambda: {}", *call); - call->args.pop_back(); - } else { - anon = N(anon); - } - resultExpr = anon; - } - std::swap(avoidDomination, ctx->avoidDomination); -} - -/// Transform a dictionary comprehension to the corresponding statement expression. -/// @example -/// `{i+a: j+1 for i in j if a}` -> ```gen = Dict() -/// for i in j: if a: gen.__setitem__(i+a, j+1)``` -void SimplifyVisitor::visit(DictGeneratorExpr *expr) { - SuiteStmt *prev = nullptr; - auto avoidDomination = true; - std::swap(avoidDomination, ctx->avoidDomination); - auto suite = transformGeneratorBody(expr->loops, prev); - - std::vector stmts; - ExprPtr var = N(ctx->cache->getTemporaryVar("gen")); - stmts.push_back(transform(N(clone(var), N(N("Dict"))))); - prev->stmts.push_back(N(N(N(clone(var), "__setitem__"), - clone(expr->key), clone(expr->expr)))); - stmts.push_back(transform(suite)); - resultExpr = N(stmts, transform(var)); - std::swap(avoidDomination, ctx->avoidDomination); -} - -/// Transforms a list of @c GeneratorBody loops to the corresponding set of for loops. -/// @example -/// `for i in j if a for k in i if a if b` -> -/// `for i in j: if a: for k in i: if a: if b: [prev]` -/// @param prev (out-argument): A pointer to the innermost block (suite) where the -/// comprehension (or generator) expression should reside -StmtPtr SimplifyVisitor::transformGeneratorBody(const std::vector &loops, - SuiteStmt *&prev) { - StmtPtr suite = N(), newSuite = nullptr; - prev = dynamic_cast(suite.get()); - for (auto &l : loops) { - newSuite = N(); - auto nextPrev = dynamic_cast(newSuite.get()); - - auto forStmt = N(l.vars->clone(), l.gen->clone(), newSuite); - prev->stmts.push_back(forStmt); - prev = nextPrev; - for (auto &cond : l.conds) { - newSuite = N(); - nextPrev = dynamic_cast(newSuite.get()); - prev->stmts.push_back(N(cond->clone(), newSuite)); - prev = nextPrev; - } - } - return suite; -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/cond.cpp b/codon/parser/visitors/simplify/cond.cpp deleted file mode 100644 index ea0927fd..00000000 --- a/codon/parser/visitors/simplify/cond.cpp +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -void SimplifyVisitor::visit(IfExpr *expr) { - // C++ call order is not defined; make sure to transform the conditional first - transform(expr->cond); - auto tmp = ctx->isConditionalExpr; - // Ensure that ifexpr and elsexpr are set as a potential short-circuit expressions. - // Needed to ensure that variables defined within these expressions are properly - // checked for their existence afterwards - // (e.g., `x` will be created within `a if cond else (x := b)` - // only if `cond` is not true) - ctx->isConditionalExpr = true; - transform(expr->ifexpr); - transform(expr->elsexpr); - ctx->isConditionalExpr = tmp; -} - -void SimplifyVisitor::visit(IfStmt *stmt) { - seqassert(stmt->cond, "invalid if statement"); - transform(stmt->cond); - // Ensure that conditional suites are marked and transformed in their own scope - transformConditionalScope(stmt->ifSuite); - transformConditionalScope(stmt->elseSuite); -} - -/// Simplify match statement by transforming it into a series of conditional statements. -/// @example -/// ```match e: -/// case pattern1: ... -/// case pattern2 if guard: ... -/// ...``` -> -/// ```_match = e -/// while True: # used to simulate goto statement with break -/// [pattern1 transformation]: (...; break) -/// [pattern2 transformation]: if guard: (...; break) -/// ... -/// break # exit the loop no matter what``` -/// The first pattern that matches the given expression will be used; other patterns -/// will not be used (i.e., there is no fall-through). See @c transformPattern for -/// pattern transformations -void SimplifyVisitor::visit(MatchStmt *stmt) { - auto var = ctx->cache->getTemporaryVar("match"); - auto result = N(); - result->stmts.push_back(N(N(var), clone(stmt->what))); - for (auto &c : stmt->cases) { - ctx->enterConditionalBlock(); - StmtPtr suite = N(clone(c.suite), N()); - if (c.guard) - suite = N(clone(c.guard), suite); - result->stmts.push_back(transformPattern(N(var), clone(c.pattern), suite)); - ctx->leaveConditionalBlock(); - } - // Make sure to break even if there is no case _ to prevent infinite loop - result->stmts.push_back(N()); - resultStmt = transform(N(N(true), result)); -} - -/// Transform a match pattern into a series of if statements. -/// @example -/// `case True` -> `if isinstance(var, "bool"): if var == True` -/// `case 1` -> `if isinstance(var, "int"): if var == 1` -/// `case 1...3` -> ```if isinstance(var, "int"): -/// if var >= 1: if var <= 3``` -/// `case (1, pat)` -> ```if isinstance(var, "Tuple"): if staticlen(var) == 2: -/// if match(var[0], 1): if match(var[1], pat)``` -/// `case [1, ..., pat]` -> ```if isinstance(var, "List"): if len(var) >= 2: -/// if match(var[0], 1): if match(var[-1], pat)``` -/// `case 1 or pat` -> `if match(var, 1): if match(var, pat)` -/// (note: pattern suite is cloned for each `or`) -/// `case (x := pat)` -> `(x = var; if match(var, pat))` -/// `case x` -> `(x := var)` -/// (only when `x` is not '_') -/// `case expr` -> `if hasattr(typeof(var), "__match__"): if -/// var.__match__(foo())` -/// (any expression that does not fit above patterns) -StmtPtr SimplifyVisitor::transformPattern(const ExprPtr &var, ExprPtr pattern, - StmtPtr suite) { - // Convenience function to generate `isinstance(e, typ)` calls - auto isinstance = [&](const ExprPtr &e, const std::string &typ) -> ExprPtr { - return N(N("isinstance"), e->clone(), N(typ)); - }; - // Convenience function to find the index of an ellipsis within a list pattern - auto findEllipsis = [&](const std::vector &items) { - size_t i = items.size(); - for (auto it = 0; it < items.size(); it++) - if (items[it]->getEllipsis()) { - if (i != items.size()) - E(Error::MATCH_MULTI_ELLIPSIS, items[it], "multiple ellipses in pattern"); - i = it; - } - return i; - }; - - // See the above examples for transformation details - if (pattern->getInt() || CAST(pattern, BoolExpr)) { - // Bool and int patterns - return N(isinstance(var, CAST(pattern, BoolExpr) ? "bool" : "int"), - N(N(var->clone(), "==", pattern), suite)); - } else if (auto er = CAST(pattern, RangeExpr)) { - // Range pattern - return N( - isinstance(var, "int"), - N( - N(var->clone(), ">=", clone(er->start)), - N(N(var->clone(), "<=", clone(er->stop)), suite))); - } else if (auto et = pattern->getTuple()) { - // Tuple pattern - for (auto it = et->items.size(); it-- > 0;) { - suite = transformPattern(N(var->clone(), N(it)), - clone(et->items[it]), suite); - } - return N( - isinstance(var, "Tuple"), - N(N(N(N("staticlen"), clone(var)), - "==", N(et->items.size())), - suite)); - } else if (auto el = pattern->getList()) { - // List pattern - auto ellipsis = findEllipsis(el->items), sz = el->items.size(); - std::string op; - if (ellipsis == el->items.size()) { - op = "=="; - } else { - op = ">=", sz -= 1; - } - for (auto it = el->items.size(); it-- > ellipsis + 1;) { - suite = transformPattern( - N(var->clone(), N(it - el->items.size())), - clone(el->items[it]), suite); - } - for (auto it = ellipsis; it-- > 0;) { - suite = transformPattern(N(var->clone(), N(it)), - clone(el->items[it]), suite); - } - return N(isinstance(var, "List"), - N(N(N(N("len"), clone(var)), - op, N(sz)), - suite)); - } else if (auto eb = pattern->getBinary()) { - // Or pattern - if (eb->op == "|" || eb->op == "||") { - return N(transformPattern(clone(var), clone(eb->lexpr), clone(suite)), - transformPattern(clone(var), clone(eb->rexpr), suite)); - } - } else if (auto ea = CAST(pattern, AssignExpr)) { - // Bound pattern - seqassert(ea->var->getId(), "only simple assignment expressions are supported"); - return N(N(clone(ea->var), clone(var)), - transformPattern(clone(var), clone(ea->expr), clone(suite))); - } else if (auto ei = pattern->getId()) { - // Wildcard pattern - if (ei->value != "_") { - return N(N(clone(pattern), clone(var)), suite); - } else { - return suite; - } - } - pattern = transform(pattern); // transform to check for pattern errors - if (pattern->getEllipsis()) - pattern = N(N("ellipsis")); - // Fallback (`__match__`) pattern - return N( - N(N("hasattr"), var->clone(), N("__match__"), - N(N("type"), pattern->clone())), - N(N(N(var->clone(), "__match__"), pattern), suite)); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/ctx.cpp b/codon/parser/visitors/simplify/ctx.cpp deleted file mode 100644 index d381697b..00000000 --- a/codon/parser/visitors/simplify/ctx.cpp +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include "ctx.h" - -#include -#include -#include -#include - -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -SimplifyContext::SimplifyContext(std::string filename, Cache *cache) - : Context(std::move(filename)), cache(cache), isStdlibLoading(false), - moduleName{ImportFile::PACKAGE, "", ""}, isConditionalExpr(false), - allowTypeOf(true) { - bases.emplace_back(Base("")); - scope.blocks.push_back(scope.counter = 0); -} - -SimplifyContext::Base::Base(std::string name, Attr *attributes) - : name(std::move(name)), attributes(attributes), deducedMembers(nullptr), - selfName(), captures(nullptr), pyCaptures(nullptr) {} - -void SimplifyContext::add(const std::string &name, const SimplifyContext::Item &var) { - auto v = find(name); - if (v && v->noShadow) - E(Error::ID_INVALID_BIND, getSrcInfo(), name); - Context::add(name, var); -} - -SimplifyContext::Item SimplifyContext::addVar(const std::string &name, - const std::string &canonicalName, - const SrcInfo &srcInfo) { - seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); - auto t = std::make_shared(SimplifyItem::Var, getBaseName(), - canonicalName, getModule(), scope.blocks); - t->setSrcInfo(srcInfo); - Context::add(name, t); - Context::add(canonicalName, t); - return t; -} - -SimplifyContext::Item SimplifyContext::addType(const std::string &name, - const std::string &canonicalName, - const SrcInfo &srcInfo) { - seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); - auto t = std::make_shared(SimplifyItem::Type, getBaseName(), - canonicalName, getModule(), scope.blocks); - t->setSrcInfo(srcInfo); - Context::add(name, t); - Context::add(canonicalName, t); - return t; -} - -SimplifyContext::Item SimplifyContext::addFunc(const std::string &name, - const std::string &canonicalName, - const SrcInfo &srcInfo) { - seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); - auto t = std::make_shared(SimplifyItem::Func, getBaseName(), - canonicalName, getModule(), scope.blocks); - t->setSrcInfo(srcInfo); - Context::add(name, t); - Context::add(canonicalName, t); - return t; -} - -SimplifyContext::Item -SimplifyContext::addAlwaysVisible(const SimplifyContext::Item &item) { - auto i = std::make_shared(item->kind, item->baseName, - item->canonicalName, item->moduleName, - std::vector{0}, item->importPath); - auto stdlib = cache->imports[STDLIB_IMPORT].ctx; - if (!stdlib->find(i->canonicalName)) { - stdlib->add(i->canonicalName, i); - } - return i; -} - -SimplifyContext::Item SimplifyContext::find(const std::string &name) const { - auto t = Context::find(name); - if (t) - return t; - - // Item is not found in the current module. Time to look in the standard library! - // Note: the standard library items cannot be dominated. - auto stdlib = cache->imports[STDLIB_IMPORT].ctx; - if (stdlib.get() != this) - t = stdlib->find(name); - return t; -} - -SimplifyContext::Item SimplifyContext::forceFind(const std::string &name) const { - auto f = find(name); - seqassert(f, "cannot find '{}'", name); - return f; -} - -SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string &name) { - auto it = map.find(name); - if (it == map.end()) - return find(name); - seqassert(!it->second.empty(), "corrupted SimplifyContext ({})", name); - - // The item is found. Let's see is it accessible now. - - std::string canonicalName; - auto lastGood = it->second.begin(); - bool isOutside = (*lastGood)->getBaseName() != getBaseName(); - int prefix = int(scope.blocks.size()); - // Iterate through all bindings with the given name and find the closest binding that - // dominates the current scope. - for (auto i = it->second.begin(); i != it->second.end(); i++) { - // Find the longest block prefix between the binding and the current scope. - int p = std::min(prefix, int((*i)->scope.size())); - while (p >= 0 && (*i)->scope[p - 1] != scope.blocks[p - 1]) - p--; - // We reached the toplevel. Break. - if (p < 0) - break; - // We went outside the function scope. Break. - if (!isOutside && (*i)->getBaseName() != getBaseName()) - break; - bool completeDomination = - (*i)->scope.size() <= scope.blocks.size() && - (*i)->scope.back() == scope.blocks[(*i)->scope.size() - 1]; - if (!completeDomination && prefix < int(scope.blocks.size()) && prefix != p) { - break; - } - prefix = p; - lastGood = i; - // The binding completely dominates the current scope. Break. - if (completeDomination) - break; - } - seqassert(lastGood != it->second.end(), "corrupted scoping ({})", name); - if (lastGood != it->second.begin() && !(*lastGood)->isVar()) - E(Error::CLASS_INVALID_BIND, getSrcInfo(), name); - - bool hasUsed = false; - if ((*lastGood)->scope.size() == prefix) { - // The current scope is dominated by a binding. Use that binding. - canonicalName = (*lastGood)->canonicalName; - } else { - // The current scope is potentially reachable by multiple bindings that are - // not dominated by a common binding. Create such binding in the scope that - // dominates (covers) all of them. - canonicalName = generateCanonicalName(name); - auto item = std::make_shared( - (*lastGood)->kind, (*lastGood)->baseName, canonicalName, - (*lastGood)->moduleName, - std::vector(scope.blocks.begin(), scope.blocks.begin() + prefix), - (*lastGood)->importPath); - item->accessChecked = {(*lastGood)->scope}; - lastGood = it->second.insert(++lastGood, item); - stack.front().push_back(name); - // Make sure to prepend a binding declaration: `var` and `var__used__ = False` - // to the dominating scope. - scope.stmts[scope.blocks[prefix - 1]].push_back(std::make_unique( - std::make_unique(canonicalName), nullptr, nullptr)); - scope.stmts[scope.blocks[prefix - 1]].push_back(std::make_unique( - std::make_unique(fmt::format("{}.__used__", canonicalName)), - std::make_unique(false), nullptr)); - // Reached the toplevel? Register the binding as global. - if (prefix == 1) { - cache->addGlobal(canonicalName); - cache->addGlobal(fmt::format("{}.__used__", canonicalName)); - } - hasUsed = true; - } - // Remove all bindings after the dominant binding. - for (auto i = it->second.begin(); i != it->second.end(); i++) { - if (i == lastGood) - break; - if (!(*i)->canDominate()) - continue; - // These bindings (and their canonical identifiers) will be replaced by the - // dominating binding during the type checking pass. - cache->replacements[(*i)->canonicalName] = {canonicalName, hasUsed}; - cache->replacements[format("{}.__used__", (*i)->canonicalName)] = { - format("{}.__used__", canonicalName), false}; - seqassert((*i)->canonicalName != canonicalName, "invalid replacement at {}: {}", - getSrcInfo(), canonicalName); - auto it = std::find(stack.front().begin(), stack.front().end(), name); - if (it != stack.front().end()) - stack.front().erase(it); - } - it->second.erase(it->second.begin(), lastGood); - return it->second.front(); -} - -std::string SimplifyContext::getBaseName() const { return bases.back().name; } - -std::string SimplifyContext::getModule() const { - std::string base = moduleName.status == ImportFile::STDLIB ? "std." : ""; - base += moduleName.module; - if (auto sz = startswith(base, "__main__")) - base = base.substr(sz); - return base; -} - -void SimplifyContext::dump() { dump(0); } - -std::string SimplifyContext::generateCanonicalName(const std::string &name, - bool includeBase, - bool zeroId) const { - std::string newName = name; - bool alreadyGenerated = name.find('.') != std::string::npos; - if (includeBase && !alreadyGenerated) { - std::string base = getBaseName(); - if (base.empty()) - base = getModule(); - if (base == "std.internal.core") - base = ""; - newName = (base.empty() ? "" : (base + ".")) + newName; - } - auto num = cache->identifierCount[newName]++; - if (num) - newName = format("{}.{}", newName, num); - if (name != newName && !zeroId) - cache->identifierCount[newName]++; - cache->reverseIdentifierLookup[newName] = name; - return newName; -} - -void SimplifyContext::enterConditionalBlock() { - scope.blocks.push_back(++scope.counter); -} - -void SimplifyContext::leaveConditionalBlock(std::vector *stmts) { - if (stmts && in(scope.stmts, scope.blocks.back())) - stmts->insert(stmts->begin(), scope.stmts[scope.blocks.back()].begin(), - scope.stmts[scope.blocks.back()].end()); - scope.blocks.pop_back(); -} - -bool SimplifyContext::isGlobal() const { return bases.size() == 1; } - -bool SimplifyContext::isConditional() const { return scope.blocks.size() > 1; } - -SimplifyContext::Base *SimplifyContext::getBase() { - return bases.empty() ? nullptr : &(bases.back()); -} - -bool SimplifyContext::inFunction() const { - return !isGlobal() && !bases.back().isType(); -} - -bool SimplifyContext::inClass() const { return !isGlobal() && bases.back().isType(); } - -bool SimplifyContext::isOuter(const Item &val) const { - return getBaseName() != val->getBaseName() || getModule() != val->getModule(); -} - -SimplifyContext::Base *SimplifyContext::getClassBase() { - if (bases.size() >= 2 && bases[bases.size() - 2].isType()) - return &(bases[bases.size() - 2]); - return nullptr; -} - -void SimplifyContext::dump(int pad) { - auto ordered = - std::map(map.begin(), map.end()); - LOG("location: {}", getSrcInfo()); - LOG("module: {}", getModule()); - LOG("base: {}", getBaseName()); - LOG("scope: {}", fmt::join(scope.blocks, ",")); - for (auto &s : stack.front()) - LOG("-> {}", s); - for (auto &i : ordered) { - std::string s; - bool f = true; - for (auto &t : i.second) { - LOG("{}{} {} {:40} {:30} {}", std::string(pad * 2, ' '), - !f ? std::string(40, ' ') : format("{:.<40}", i.first), - (t->isFunc() ? "F" : (t->isType() ? "T" : (t->isImport() ? "I" : "V"))), - t->canonicalName, t->getBaseName(), combine2(t->scope, ",")); - f = false; - } - } -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/ctx.h b/codon/parser/visitors/simplify/ctx.h deleted file mode 100644 index f90e64de..00000000 --- a/codon/parser/visitors/simplify/ctx.h +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/ctx.h" - -namespace codon::ast { - -/** - * Simplification context identifier. - * Can be either a function, a class (type), or a variable. - */ -struct SimplifyItem : public SrcObject { - /// Type of the identifier - enum Kind { Func, Type, Var } kind; - /// Base name (e.g., foo.bar.baz) - std::string baseName; - /// Unique identifier (canonical name) - std::string canonicalName; - /// Full module name - std::string moduleName; - /// Full scope information - std::vector scope; - /// Non-empty string if a variable is import variable - std::string importPath; - /// List of scopes where the identifier is accessible - /// without __used__ check - std::vector> accessChecked; - /// Set if an identifier cannot be shadowed - /// (e.g., global-marked variables) - bool noShadow = false; - /// Set if an identifier is a class or a function generic - bool generic = false; - /// Set if an identifier is a static variable. - char staticType = 0; - /// Set if an identifier should not be dominated - /// (e.g., a loop variable in a comprehension). - bool avoidDomination = false; - -public: - SimplifyItem(Kind kind, std::string baseName, std::string canonicalName, - std::string moduleName, std::vector scope, - std::string importPath = "") - : kind(kind), baseName(std::move(baseName)), - canonicalName(std::move(canonicalName)), moduleName(std::move(moduleName)), - scope(std::move(scope)), importPath(std::move(importPath)) {} - - /* Convenience getters */ - std::string getBaseName() const { return baseName; } - std::string getModule() const { return moduleName; } - bool isVar() const { return kind == Var; } - bool isFunc() const { return kind == Func; } - bool isType() const { return kind == Type; } - bool isImport() const { return !importPath.empty(); } - bool isGlobal() const { return scope.size() == 1 && baseName.empty(); } - /// True if an identifier is within a conditional block - /// (i.e., a block that might not be executed during the runtime) - bool isConditional() const { return scope.size() > 1; } - bool isGeneric() const { return generic; } - char isStatic() const { return staticType; } - /// True if an identifier is a loop variable in a comprehension - bool canDominate() const { return !avoidDomination; } -}; - -/** Context class that tracks identifiers during the simplification. **/ -struct SimplifyContext : public Context { - /// A pointer to the shared cache. - Cache *cache; - - /// Holds the information about current scope. - /// A scope is defined as a stack of conditional blocks - /// (i.e., blocks that might not get executed during the runtime). - /// Used mainly to support Python's variable scoping rules. - struct { - /// Scope counter. Each conditional block gets a new scope ID. - int counter; - /// Current hierarchy of conditional blocks. - std::vector blocks; - /// List of statements that are to be prepended to a block - /// after its transformation. - std::map> stmts; - } scope; - - /// Holds the information about current base. - /// A base is defined as a function or a class block. - struct Base { - /// Canonical name of a function or a class that owns this base. - std::string name; - /// Tracks function attributes (e.g. if it has @atomic or @test attributes). - /// Only set for functions. - Attr *attributes; - /// Set if the base is class base and if class is marked with @deduce. - /// Stores the list of class fields in the order of traversal. - std::shared_ptr> deducedMembers; - /// Canonical name of `self` parameter that is used to deduce class fields - /// (e.g., self in self.foo). - std::string selfName; - /// Map of captured identifiers (i.e., identifiers not defined in a function). - /// Captured (canonical) identifiers are mapped to the new canonical names - /// (representing the canonical function argument names that are appended to the - /// function after processing) and their types (indicating if they are a type, a - /// static or a variable). - std::unordered_map> *captures; - - /// Map of identifiers that are to be fetched from Python. - std::unordered_set *pyCaptures; - - /// Scope that defines the base. - std::vector scope; - - /// A stack of nested loops enclosing the current statement used for transforming - /// "break" statement in loop-else constructs. Each loop is defined by a "break" - /// variable created while parsing a loop-else construct. If a loop has no else - /// block, the corresponding loop variable is empty. - struct Loop { - std::string breakVar; - std::vector scope; - /// List of variables "seen" before their assignment within a loop. - /// Used to dominate variables that are updated within a loop. - std::unordered_set seenVars; - /// False if a loop has continue/break statement. Used for flattening static - /// loops. - bool flat = true; - }; - std::vector loops; - - public: - explicit Base(std::string name, Attr *attributes = nullptr); - Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); } - bool isType() const { return attributes == nullptr; } - }; - /// Current base stack (the last enclosing base is the last base in the stack). - std::vector bases; - - struct BaseGuard { - SimplifyContext *holder; - BaseGuard(SimplifyContext *holder, const std::string &name) : holder(holder) { - holder->bases.emplace_back(Base(name)); - holder->bases.back().scope = holder->scope.blocks; - holder->addBlock(); - } - ~BaseGuard() { - holder->bases.pop_back(); - holder->popBlock(); - } - }; - - /// Set of seen global identifiers used to prevent later creation of local variables - /// with the same name. - std::unordered_map> - seenGlobalIdentifiers; - - /// Set if the standard library is currently being loaded. - bool isStdlibLoading; - /// Current module. The default module is named `__main__`. - ImportFile moduleName; - /// Tracks if we are in a dependent part of a short-circuiting expression (e.g. b in a - /// and b) to disallow assignment expressions there. - bool isConditionalExpr; - /// Allow type() expressions. Currently used to disallow type() in class - /// and function definitions. - bool allowTypeOf; - /// Set if all assignments should not be dominated later on. - bool avoidDomination = false; - -public: - SimplifyContext(std::string filename, Cache *cache); - - void add(const std::string &name, const Item &var) override; - /// Convenience method for adding an object to the context. - Item addVar(const std::string &name, const std::string &canonicalName, - const SrcInfo &srcInfo = SrcInfo()); - Item addType(const std::string &name, const std::string &canonicalName, - const SrcInfo &srcInfo = SrcInfo()); - Item addFunc(const std::string &name, const std::string &canonicalName, - const SrcInfo &srcInfo = SrcInfo()); - /// Add the item to the standard library module, thus ensuring its visibility from all - /// modules. - Item addAlwaysVisible(const Item &item); - - /// Get an item from the context. If the item does not exist, nullptr is returned. - Item find(const std::string &name) const override; - /// Get an item that exists in the context. If the item does not exist, assertion is - /// raised. - Item forceFind(const std::string &name) const; - /// Get an item from the context. Perform domination analysis for accessing items - /// defined in the conditional blocks (i.e., Python scoping). - Item findDominatingBinding(const std::string &name); - - /// Return a canonical name of the current base. - /// An empty string represents the toplevel base. - std::string getBaseName() const; - /// Return the current module. - std::string getModule() const; - /// Pretty-print the current context state. - void dump() override; - - /// Generate a unique identifier (name) for a given string. - std::string generateCanonicalName(const std::string &name, bool includeBase = false, - bool zeroId = false) const; - /// Enter a conditional block. - void enterConditionalBlock(); - /// Leave a conditional block. Populate stmts (if set) with the declarations of newly - /// added identifiers that dominate the children blocks. - void leaveConditionalBlock(std::vector *stmts = nullptr); - /// True if we are at the toplevel. - bool isGlobal() const; - /// True if we are within a conditional block. - bool isConditional() const; - /// Get the current base. - Base *getBase(); - /// True if the current base is function. - bool inFunction() const; - /// True if the current base is class. - bool inClass() const; - /// True if an item is defined outside of the current base or a module. - bool isOuter(const Item &val) const; - /// Get the enclosing class base (or nullptr if such does not exist). - Base *getClassBase(); - -private: - /// Pretty-print the current context state. - void dump(int pad); -}; - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/error.cpp b/codon/parser/visitors/simplify/error.cpp deleted file mode 100644 index 8f6c6bd6..00000000 --- a/codon/parser/visitors/simplify/error.cpp +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; - -namespace codon::ast { - -/// Transform asserts. -/// @example -/// `assert foo()` -> -/// `if not foo(): raise __internal__.seq_assert([file], [line], "")` -/// `assert foo(), msg` -> -/// `if not foo(): raise __internal__.seq_assert([file], [line], str(msg))` -/// Use `seq_assert_test` instead of `seq_assert` and do not raise anything during unit -/// testing (i.e., when the enclosing function is marked with `@test`). -void SimplifyVisitor::visit(AssertStmt *stmt) { - ExprPtr msg = N(""); - if (stmt->message) - msg = N(N("str"), clone(stmt->message)); - auto test = ctx->inFunction() && (ctx->getBase()->attributes && - ctx->getBase()->attributes->has(Attr::Test)); - auto ex = N( - N("__internal__", test ? "seq_assert_test" : "seq_assert"), - N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), msg); - auto cond = N("!", clone(stmt->expr)); - if (test) { - resultStmt = transform(N(cond, N(ex))); - } else { - resultStmt = transform(N(cond, N(ex))); - } -} - -void SimplifyVisitor::visit(TryStmt *stmt) { - transformConditionalScope(stmt->suite); - for (auto &c : stmt->catches) { - ctx->enterConditionalBlock(); - if (!c.var.empty()) { - c.var = ctx->generateCanonicalName(c.var); - ctx->addVar(ctx->cache->rev(c.var), c.var, c.suite->getSrcInfo()); - } - transform(c.exc, true); - transformConditionalScope(c.suite); - ctx->leaveConditionalBlock(); - } - transformConditionalScope(stmt->finally); -} - -void SimplifyVisitor::visit(ThrowStmt *stmt) { transform(stmt->expr); } - -/// Transform with statements. -/// @example -/// `with foo(), bar() as a: ...` -> -/// ```tmp = foo() -/// tmp.__enter__() -/// try: -/// a = bar() -/// a.__enter__() -/// try: -/// ... -/// finally: -/// a.__exit__() -/// finally: -/// tmp.__exit__()``` -void SimplifyVisitor::visit(WithStmt *stmt) { - seqassert(!stmt->items.empty(), "stmt->items is empty"); - std::vector content; - for (auto i = stmt->items.size(); i-- > 0;) { - std::string var = - stmt->vars[i].empty() ? ctx->cache->getTemporaryVar("with") : stmt->vars[i]; - content = std::vector{ - N(N(var), clone(stmt->items[i])), - N(N(N(var, "__enter__"))), - N( - !content.empty() ? N(content) : clone(stmt->suite), - std::vector{}, - N(N(N(N(var, "__exit__")))))}; - } - resultStmt = transform(N(content)); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp deleted file mode 100644 index 54d084db..00000000 --- a/codon/parser/visitors/simplify/function.cpp +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Ensure that `(yield)` is in a function. -void SimplifyVisitor::visit(YieldExpr *expr) { - if (!ctx->inFunction()) - E(Error::FN_OUTSIDE_ERROR, expr, "yield"); - ctx->getBase()->attributes->set(Attr::IsGenerator); -} - -/// Transform lambdas. Capture outer expressions. -/// @example -/// `lambda a, b: a+b+c` -> ```def fn(a, b, c): -/// return a+b+c -/// fn(c=c, ...)``` -/// See @c makeAnonFn -void SimplifyVisitor::visit(LambdaExpr *expr) { - resultExpr = - makeAnonFn(std::vector{N(clone(expr->expr))}, expr->vars); -} - -/// Ensure that `return` is in a function. -void SimplifyVisitor::visit(ReturnStmt *stmt) { - if (!ctx->inFunction()) - E(Error::FN_OUTSIDE_ERROR, stmt, "return"); - transform(stmt->expr); -} - -/// Ensure that `yield` is in a function. -void SimplifyVisitor::visit(YieldStmt *stmt) { - if (!ctx->inFunction()) - E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); - transform(stmt->expr); - ctx->getBase()->attributes->set(Attr::IsGenerator); -} - -/// Transform `yield from` statements. -/// @example -/// `yield from a` -> `for var in a: yield var` -void SimplifyVisitor::visit(YieldFromStmt *stmt) { - auto var = ctx->cache->getTemporaryVar("yield"); - resultStmt = - transform(N(N(var), stmt->expr, N(N(var)))); -} - -/// Process `global` statements. Remove them upon completion. -void SimplifyVisitor::visit(GlobalStmt *stmt) { - if (!ctx->inFunction()) - E(Error::FN_OUTSIDE_ERROR, stmt, stmt->nonLocal ? "nonlocal" : "global"); - - // Dominate the binding - auto val = ctx->findDominatingBinding(stmt->var); - if (!val || !val->isVar()) - E(Error::ID_NOT_FOUND, stmt, stmt->var); - if (val->getBaseName() == ctx->getBaseName()) - E(Error::FN_GLOBAL_ASSIGNED, stmt, stmt->var); - - // Check global/nonlocal distinction - if (!stmt->nonLocal && !val->getBaseName().empty()) - E(Error::FN_GLOBAL_NOT_FOUND, stmt, "global", stmt->var); - else if (stmt->nonLocal && val->getBaseName().empty()) - E(Error::FN_GLOBAL_NOT_FOUND, stmt, "nonlocal", stmt->var); - seqassert(!val->canonicalName.empty(), "'{}' does not have a canonical name", - stmt->var); - - // Register as global if needed - ctx->cache->addGlobal(val->canonicalName); - - val = ctx->addVar(stmt->var, val->canonicalName, stmt->getSrcInfo()); - val->baseName = ctx->getBaseName(); - // Globals/nonlocals cannot be shadowed in children scopes (as in Python) - val->noShadow = true; - // Erase the statement - resultStmt = N(); -} - -/// Validate and transform function definitions. -/// Handle overloads, class methods, default arguments etc. -/// Also capture variables if necessary and apply decorators. -/// @example -/// ```a = 5 -/// @dec -/// def foo(b): -/// return a+b -/// ``` -> ``` -/// a = 5 -/// def foo(b, a_cap): -/// return a_cap+b -/// foo = dec(foo(a_cap=a, ...)) -/// ``` -/// For Python and LLVM definition transformations, see -/// @c transformPythonDefinition and @c transformLLVMDefinition -void SimplifyVisitor::visit(FunctionStmt *stmt) { - if (stmt->attributes.has(Attr::Python)) { - // Handle Python block - resultStmt = transformPythonDefinition(stmt->name, stmt->args, stmt->ret.get(), - stmt->suite->firstInBlock()); - return; - } - - // Parse attributes - for (auto i = stmt->decorators.size(); i-- > 0;) { - if (!stmt->decorators[i]) - continue; - auto [isAttr, attrName] = getDecorator(stmt->decorators[i]); - if (!attrName.empty()) { - stmt->attributes.set(attrName); - if (isAttr) - stmt->decorators[i] = nullptr; // remove it from further consideration - } - } - - bool isClassMember = ctx->inClass(), isEnclosedFunc = ctx->inFunction(); - if (stmt->attributes.has(Attr::ForceRealize) && (!ctx->isGlobal() || isClassMember)) - E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "builtin function"); - - // All overloads share the same canonical name except for the number at the - // end (e.g., `foo.1:0`, `foo.1:1` etc.) - std::string rootName; - if (isClassMember) { - // Case 1: method overload - if (auto n = in(ctx->cache->classes[ctx->getBase()->name].methods, stmt->name)) - rootName = *n; - } else if (stmt->attributes.has(Attr::Overload)) { - // Case 2: function overload - if (auto c = ctx->find(stmt->name)) { - if (c->isFunc() && c->getModule() == ctx->getModule() && - c->getBaseName() == ctx->getBaseName()) { - rootName = c->canonicalName; - } - } - } - if (rootName.empty()) - rootName = ctx->generateCanonicalName(stmt->name, true); - // Append overload number to the name - auto canonicalName = - format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); - ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->name; - - // Ensure that function binding does not shadow anything. - // Function bindings cannot be dominated either - if (!isClassMember) { - auto funcVal = ctx->find(stmt->name); - if (funcVal && funcVal->moduleName == ctx->getModule() && funcVal->noShadow) - E(Error::CLASS_INVALID_BIND, stmt, stmt->name); - funcVal = ctx->addFunc(stmt->name, rootName, stmt->getSrcInfo()); - ctx->addAlwaysVisible(funcVal); - } - - std::vector args; - StmtPtr suite = nullptr; - ExprPtr ret = nullptr; - std::unordered_map> captures; - std::unordered_set pyCaptures; - { - // Set up the base - SimplifyContext::BaseGuard br(ctx.get(), canonicalName); - ctx->getBase()->attributes = &(stmt->attributes); - - // Parse arguments and add them to the context - for (auto &a : stmt->args) { - std::string varName = a.name; - int stars = trimStars(varName); - auto name = ctx->generateCanonicalName(varName); - - // Mark as method if the first argument is self - if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") { - ctx->getBase()->selfName = name; - stmt->attributes.set(Attr::Method); - } - - // Handle default values - auto defaultValue = a.defaultValue; - if (a.type && defaultValue && defaultValue->getNone()) { - // Special case: `arg: Callable = None` -> `arg: Callable = NoneType()` - if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE)) - defaultValue = N(N("NoneType")); - // Special case: `arg: type = None` -> `arg: type = NoneType` - if (a.type->isId("type") || a.type->isId(TYPE_TYPEVAR)) - defaultValue = N("NoneType"); - } - /// TODO: Uncomment for Python-style defaults - // if (defaultValue) { - // auto defaultValueCanonicalName = - // ctx->generateCanonicalName(format("{}.{}", canonicalName, name)); - // prependStmts->push_back(N(N(defaultValueCanonicalName), - // defaultValue)); - // defaultValue = N(defaultValueCanonicalName); - // } - args.emplace_back( - Param{std::string(stars, '*') + name, a.type, defaultValue, a.status}); - - // Add generics to the context - if (a.status != Param::Normal) { - if (auto st = getStaticGeneric(a.type.get())) { - auto val = ctx->addVar(varName, name, stmt->getSrcInfo()); - val->generic = true; - val->staticType = st; - } else { - ctx->addType(varName, name, stmt->getSrcInfo())->generic = true; - } - } - } - - // Parse arguments to the context. Needs to be done after adding generics - // to support cases like `foo(a: T, T: type)` - for (auto &a : args) { - a.type = transformType(a.type, false); - a.defaultValue = transform(a.defaultValue, true); - } - // Add non-generic arguments to the context. Delayed to prevent cases like - // `def foo(a, b=a)` - for (auto &a : args) { - if (a.status == Param::Normal) { - std::string canName = a.name; - trimStars(canName); - ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo()); - } - } - - // Parse the return type - ret = transformType(stmt->ret, false); - - // Parse function body - if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) { - if (stmt->attributes.has(Attr::LLVM)) { - suite = transformLLVMDefinition(stmt->suite->firstInBlock()); - } else if (stmt->attributes.has(Attr::C)) { - // Do nothing - } else { - if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember) - ctx->getBase()->captures = &captures; - if (stmt->attributes.has("std.internal.attributes.pycapture")) - ctx->getBase()->pyCaptures = &pyCaptures; - suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite); - } - } - } - stmt->attributes.module = - format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", - ctx->moduleName.module); - ctx->cache->overloads[rootName].push_back({canonicalName, ctx->cache->age}); - - // Special method handling - if (isClassMember) { - // Set the enclosing class name - stmt->attributes.parentClass = ctx->getBase()->name; - // Add the method to the class' method list - ctx->cache->classes[ctx->getBase()->name].methods[stmt->name] = rootName; - } else { - // Hack so that we can later use same helpers for class overloads - ctx->cache->classes[".toplevel"].methods[stmt->name] = rootName; - } - - // Handle captures. Add additional argument to the function for every capture. - // Make sure to account for **kwargs if present - std::vector partialArgs; - if (!captures.empty()) { - Param kw; - if (!args.empty() && startswith(args.back().name, "**")) { - kw = args.back(); - args.pop_back(); - } - for (auto &c : captures) { - args.emplace_back(Param{c.second.first, c.second.second, nullptr}); - partialArgs.push_back({c.second.first, N(ctx->cache->rev(c.first))}); - } - if (!kw.name.empty()) - args.push_back(kw); - partialArgs.emplace_back("", N(EllipsisExpr::PARTIAL)); - } - // Make function AST and cache it for later realization - auto f = N(canonicalName, ret, args, suite, stmt->attributes); - ctx->cache->functions[canonicalName].ast = f; - ctx->cache->functions[canonicalName].origAst = - std::static_pointer_cast(stmt->clone()); - ctx->cache->functions[canonicalName].isToplevel = - ctx->getModule().empty() && ctx->isGlobal(); - ctx->cache->functions[canonicalName].rootName = rootName; - - // Expression to be used if function binding is modified by captures or decorators - ExprPtr finalExpr = nullptr; - // If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)` - if (!captures.empty()) { - finalExpr = N(N(stmt->name), partialArgs); - // Add updated self reference in case function is recursive! - auto pa = partialArgs; - for (auto &a : pa) { - if (!a.name.empty()) - a.value = N(a.name); - else - a.value = clone(a.value); - } - f->suite = N( - N(N(rootName), N(N(rootName), pa)), - suite); - } - - // Parse remaining decorators - for (auto i = stmt->decorators.size(); i-- > 0;) { - if (stmt->decorators[i]) { - if (isClassMember) - E(Error::FN_NO_DECORATORS, stmt->decorators[i]); - // Replace each decorator with `decorator(finalExpr)` in the reverse order - finalExpr = N(stmt->decorators[i], - finalExpr ? finalExpr : N(stmt->name)); - } - } - - if (finalExpr) { - resultStmt = - N(f, transform(N(N(stmt->name), finalExpr))); - } else { - resultStmt = f; - } -} - -/// Make a capturing anonymous function with the provided suite and argument names. -/// The resulting function will be added before the current statement. -/// Return an expression that can call this function (an @c IdExpr or a partial call). -ExprPtr SimplifyVisitor::makeAnonFn(std::vector suite, - const std::vector &argNames) { - std::vector params; - std::string name = ctx->cache->getTemporaryVar("lambda"); - params.reserve(argNames.size()); - for (auto &s : argNames) - params.emplace_back(Param(s)); - auto f = transform(N( - name, nullptr, params, N(std::move(suite)), Attr({Attr::Capture}))); - if (auto fs = f->getSuite()) { - seqassert(fs->stmts.size() == 2 && fs->stmts[0]->getFunction(), - "invalid function transform"); - prependStmts->push_back(fs->stmts[0]); - for (StmtPtr s = fs->stmts[1]; s;) { - if (auto suite = s->getSuite()) { - // Suites can only occur when captures are inserted for a partial call - // argument. - seqassert(suite->stmts.size() == 2, "invalid function transform"); - prependStmts->push_back(suite->stmts[0]); - s = suite->stmts[1]; - } else if (auto assign = s->getAssign()) { - return assign->rhs; - } else { - seqassert(false, "invalid function transform"); - } - } - return nullptr; // should fail an assert before - } else { - prependStmts->push_back(f); - return transform(N(name)); - } -} - -/// Transform Python code blocks. -/// @example -/// ```@python -/// def foo(x: int, y) -> int: -/// [code] -/// ``` -> ``` -/// pyobj._exec("def foo(x, y): [code]") -/// from python import __main__.foo(int, _) -> int -/// ``` -StmtPtr SimplifyVisitor::transformPythonDefinition(const std::string &name, - const std::vector &args, - const Expr *ret, Stmt *codeStmt) { - seqassert(codeStmt && codeStmt->getExpr() && codeStmt->getExpr()->expr->getString(), - "invalid Python definition"); - - auto code = codeStmt->getExpr()->expr->getString()->getValue(); - std::vector pyargs; - pyargs.reserve(args.size()); - for (const auto &a : args) - pyargs.emplace_back(a.name); - code = format("def {}({}):\n{}\n", name, join(pyargs, ", "), code); - return transform(N( - N(N(N("pyobj", "_exec"), N(code))), - N(N("python"), N("__main__", name), clone_nop(args), - ret ? ret->clone() : N("pyobj")))); -} - -/// Transform LLVM functions. -/// @example -/// ```@llvm -/// def foo(x: int) -> float: -/// [code] -/// ``` -> ``` -/// def foo(x: int) -> float: -/// StringExpr("[code]") -/// SuiteStmt(referenced_types) -/// ``` -/// As LLVM code can reference types and static expressions in `{=expr}` blocks, -/// all block expression will be stored in the `referenced_types` suite. -/// "[code]" is transformed accordingly: each `{=expr}` block will -/// be replaced with `{}` so that @c fmt::format can fill the gaps. -/// Note that any brace (`{` or `}`) that is not part of a block is -/// escaped (e.g. `{` -> `{{` and `}` -> `}}`) so that @c fmt::format can process them. -StmtPtr SimplifyVisitor::transformLLVMDefinition(Stmt *codeStmt) { - seqassert(codeStmt && codeStmt->getExpr() && codeStmt->getExpr()->expr->getString(), - "invalid LLVM definition"); - - auto code = codeStmt->getExpr()->expr->getString()->getValue(); - std::vector items; - auto se = N(""); - std::string finalCode = se->getValue(); - items.push_back(N(se)); - - // Parse LLVM code and look for expression blocks that start with `{=` - int braceCount = 0, braceStart = 0; - for (int i = 0; i < code.size(); i++) { - if (i < code.size() - 1 && code[i] == '{' && code[i + 1] == '=') { - if (braceStart < i) - finalCode += escapeFStringBraces(code, braceStart, i - braceStart) + '{'; - if (!braceCount) { - braceStart = i + 2; - braceCount++; - } else { - E(Error::FN_BAD_LLVM, getSrcInfo()); - } - } else if (braceCount && code[i] == '}') { - braceCount--; - std::string exprCode = code.substr(braceStart, i - braceStart); - auto offset = getSrcInfo(); - offset.col += i; - auto expr = transform(parseExpr(ctx->cache, exprCode, offset).first, true); - items.push_back(N(expr)); - braceStart = i + 1; - finalCode += '}'; - } - } - if (braceCount) - E(Error::FN_BAD_LLVM, getSrcInfo()); - if (braceStart != code.size()) - finalCode += escapeFStringBraces(code, braceStart, int(code.size()) - braceStart); - se->strings[0].first = finalCode; - return N(items); -} - -/// Fetch a decorator canonical name. The first pair member indicates if a decorator is -/// actually an attribute (a function with `@__attribute__`). -std::pair SimplifyVisitor::getDecorator(const ExprPtr &e) { - auto dt = transform(clone(e)); - auto id = dt->getCall() ? dt->getCall()->expr : dt; - if (id && id->getId()) { - auto ci = ctx->find(id->getId()->value); - if (ci && ci->isFunc()) { - if (ctx->cache->overloads[ci->canonicalName].size() == 1) { - return {ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name] - .ast->attributes.isAttribute, - ci->canonicalName}; - } - } - } - return {false, ""}; -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/import.cpp b/codon/parser/visitors/simplify/import.cpp deleted file mode 100644 index 5fc3312f..00000000 --- a/codon/parser/visitors/simplify/import.cpp +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Import and parse a new module into its own context. -/// Also handle special imports ( see @c transformSpecialImport ). -/// To simulate Python's dynamic import logic and import stuff only once, -/// each import statement is guarded as follows: -/// if not _import_N_done: -/// _import_N() -/// _import_N_done = True -/// See @c transformNewImport and below for more details. -void SimplifyVisitor::visit(ImportStmt *stmt) { - seqassert(!ctx->inClass(), "imports within a class"); - if ((resultStmt = transformSpecialImport(stmt))) - return; - - // Fetch the import - auto components = getImportPath(stmt->from.get(), stmt->dots); - auto path = combine2(components, "/"); - auto file = getImportFile(ctx->cache->argv0, path, ctx->getFilename(), false, - ctx->cache->module0, ctx->cache->pluginImportPaths); - if (!file) { - std::string s(stmt->dots, '.'); - for (size_t i = 0; i < components.size(); i++) - if (components[i] == "..") { - continue; - } else if (!s.empty() && s.back() != '.') { - s += "." + components[i]; - } else { - s += components[i]; - } - E(Error::IMPORT_NO_MODULE, stmt->from, s); - } - - // If the file has not been seen before, load it into cache - bool handled = true; - if (ctx->cache->imports.find(file->path) == ctx->cache->imports.end()) { - resultStmt = transformNewImport(*file); - if (!resultStmt) - handled = false; // we need an import - } - - const auto &import = ctx->cache->imports[file->path]; - std::string importVar = import.importVar; - if (!import.loadedAtToplevel) - handled = false; - - // Construct `if _import_done.__invert__(): (_import(); _import_done = True)`. - // Do not do this during the standard library loading (we assume that standard library - // imports are "clean" and do not need guards). Note that the importVar is empty if - // the import has been loaded during the standard library loading. - if (!handled) { - resultStmt = N(N(N(importVar + ":0"))); - LOG_TYPECHECK("[import] loading {}", importVar); - } - - // Import requested identifiers from the import's scope to the current scope - if (!stmt->what) { - // Case: import foo - auto name = stmt->as.empty() ? path : stmt->as; - ctx->addVar(name, importVar, stmt->getSrcInfo())->importPath = file->path; - } else if (stmt->what->isId("*")) { - // Case: from foo import * - seqassert(stmt->as.empty(), "renamed star-import"); - // Just copy all symbols from import's context here. - for (auto &i : *(import.ctx)) { - if ((!startswith(i.first, "_") || - (ctx->isStdlibLoading && startswith(i.first, "__")))) { - // Ignore all identifiers that start with `_` but not those that start with - // `__` while the standard library is being loaded - auto c = i.second.front(); - if (c->isConditional() && i.first.find('.') == std::string::npos) { - c = import.ctx->findDominatingBinding(i.first); - } - // Imports should ignore noShadow property - ctx->Context::add(i.first, c); - } - } - } else { - // Case 3: from foo import bar - auto i = stmt->what->getId(); - seqassert(i, "not a valid import what expression"); - auto c = import.ctx->find(i->value); - // Make sure that we are importing an existing global symbol - if (!c) - E(Error::IMPORT_NO_NAME, i, i->value, file->module); - if (c->isConditional()) - c = import.ctx->findDominatingBinding(i->value); - // Imports should ignore noShadow property - ctx->Context::add(stmt->as.empty() ? i->value : stmt->as, c); - } - - if (!resultStmt) { - resultStmt = N(); // erase it - } -} - -/// Transform special `from C` and `from python` imports. -/// See @c transformCImport, @c transformCDLLImport and @c transformPythonImport -StmtPtr SimplifyVisitor::transformSpecialImport(ImportStmt *stmt) { - if (stmt->from && stmt->from->isId("C") && stmt->what->getId() && stmt->isFunction) { - // C function imports - return transformCImport(stmt->what->getId()->value, stmt->args, stmt->ret.get(), - stmt->as); - } - if (stmt->from && stmt->from->isId("C") && stmt->what->getId()) { - // C variable imports - return transformCVarImport(stmt->what->getId()->value, stmt->ret.get(), stmt->as); - } else if (stmt->from && stmt->from->isId("C") && stmt->what->getDot()) { - // dylib C imports - return transformCDLLImport(stmt->what->getDot()->expr.get(), - stmt->what->getDot()->member, stmt->args, - stmt->ret.get(), stmt->as, stmt->isFunction); - } else if (stmt->from && stmt->from->isId("python") && stmt->what) { - // Python imports - return transformPythonImport(stmt->what.get(), stmt->args, stmt->ret.get(), - stmt->as); - } - return nullptr; -} - -/// Transform Dot(Dot(a, b), c...) into "{a, b, c, ...}". -/// Useful for getting import paths. -std::vector SimplifyVisitor::getImportPath(Expr *from, size_t dots) { - std::vector components; // Path components - if (from) { - for (; from->getDot(); from = from->getDot()->expr.get()) - components.push_back(from->getDot()->member); - seqassert(from->getId(), "invalid import statement"); - components.push_back(from->getId()->value); - } - - // Handle dots (i.e., `..` in `from ..m import x`) - for (size_t i = 1; i < dots; i++) - components.emplace_back(".."); - std::reverse(components.begin(), components.end()); - return components; -} - -/// Transform a C function import. -/// @example -/// `from C import foo(int) -> float as f` -> -/// ```@.c -/// def foo(a1: int) -> float: -/// pass -/// f = foo # if altName is provided``` -/// No return type implies void return type. *args is treated as C VAR_ARGS. -StmtPtr SimplifyVisitor::transformCImport(const std::string &name, - const std::vector &args, - const Expr *ret, const std::string &altName) { - std::vector fnArgs; - auto attr = Attr({Attr::C}); - for (size_t ai = 0; ai < args.size(); ai++) { - seqassert(args[ai].name.empty(), "unexpected argument name"); - seqassert(!args[ai].defaultValue, "unexpected default argument"); - seqassert(args[ai].type, "missing type"); - if (args[ai].type->getEllipsis() && ai + 1 == args.size()) { - // C VAR_ARGS support - attr.set(Attr::CVarArg); - fnArgs.emplace_back(Param{"*args", nullptr, nullptr}); - } else { - fnArgs.emplace_back( - Param{args[ai].name.empty() ? format("a{}", ai) : args[ai].name, - args[ai].type->clone(), nullptr}); - } - } - ctx->generateCanonicalName(name); // avoid canonicalName == name - StmtPtr f = N(name, ret ? ret->clone() : N("NoneType"), fnArgs, - nullptr, attr); - f = transform(f); // Already in the preamble - if (!altName.empty()) { - auto val = ctx->forceFind(name); - ctx->add(altName, val); - ctx->remove(name); - } - return f; -} - -/// Transform a C variable import. -/// @example -/// `from C import foo: int as f` -> -/// ```f: int = "foo"``` -StmtPtr SimplifyVisitor::transformCVarImport(const std::string &name, const Expr *type, - const std::string &altName) { - auto canonical = ctx->generateCanonicalName(name); - auto val = ctx->addVar(altName.empty() ? name : altName, canonical); - val->noShadow = true; - auto s = N(N(canonical), nullptr, transformType(type->clone())); - s->lhs->setAttr(ExprAttr::ExternVar); - return s; -} - -/// Transform a dynamic C import. -/// @example -/// `from C import lib.foo(int) -> float as f` -> -/// `f = _dlsym(lib, "foo", Fn=Function[[int], float]); f` -/// No return type implies void return type. -StmtPtr SimplifyVisitor::transformCDLLImport(const Expr *dylib, const std::string &name, - const std::vector &args, - const Expr *ret, - const std::string &altName, - bool isFunction) { - ExprPtr type = nullptr; - if (isFunction) { - std::vector fnArgs{N(std::vector{}), - ret ? ret->clone() : N("NoneType")}; - for (const auto &a : args) { - seqassert(a.name.empty(), "unexpected argument name"); - seqassert(!a.defaultValue, "unexpected default argument"); - seqassert(a.type, "missing type"); - fnArgs[0]->getList()->items.emplace_back(clone(a.type)); - } - - type = N(N("Function"), N(fnArgs)); - } else { - type = ret->clone(); - } - - return transform(N( - N(altName.empty() ? name : altName), - N(N("_dlsym"), - std::vector{CallExpr::Arg(dylib->clone()), - CallExpr::Arg(N(name)), - {"Fn", type}}))); -} - -/// Transform a Python module and function imports. -/// @example -/// `from python import module as f` -> `f = pyobj._import("module")` -/// `from python import lib.foo(int) -> float as f` -> -/// ```def f(a0: int) -> float: -/// f = pyobj._import("lib")._getattr("foo") -/// return float.__from_py__(f(a0))``` -/// If a return type is nullptr, the function just returns f (raw pyobj). -StmtPtr SimplifyVisitor::transformPythonImport(Expr *what, - const std::vector &args, - Expr *ret, const std::string &altName) { - // Get a module name (e.g., os.path) - auto components = getImportPath(what); - - if (!ret && args.empty()) { - // Simple import: `from python import foo.bar` -> `bar = pyobj._import("foo.bar")` - return transform( - N(N(altName.empty() ? components.back() : altName), - N(N("pyobj", "_import"), - N(combine2(components, "."))))); - } - - // Python function import: - // `from python import foo.bar(int) -> float` -> - // ```def bar(a1: int) -> float: - // f = pyobj._import("foo")._getattr("bar") - // return float.__from_py__(f(a1))``` - - // f = pyobj._import("foo")._getattr("bar") - auto call = N( - N("f"), - N( - N(N(N("pyobj", "_import"), - N(combine2(components, ".", 0, - int(components.size()) - 1))), - "_getattr"), - N(components.back()))); - // f(a1, ...) - std::vector params; - std::vector callArgs; - for (int i = 0; i < args.size(); i++) { - params.emplace_back(Param{format("a{}", i), clone(args[i].type), nullptr}); - callArgs.emplace_back(N(format("a{}", i))); - } - // `return ret.__from_py__(f(a1, ...))` - auto retType = (ret && !ret->getNone()) ? ret->clone() : N("NoneType"); - auto retExpr = N(N(retType->clone(), "__from_py__"), - N(N(N("f"), callArgs), "p")); - auto retStmt = N(retExpr); - // Create a function - return transform(N(altName.empty() ? components.back() : altName, - retType, params, N(call, retStmt))); -} - -/// Import a new file into its own context and wrap its top-level statements into a -/// function to support Python-like runtime import loading. -/// @example -/// ```_import_[I]_done = False -/// def _import_[I](): -/// global [imported global variables]... -/// __name__ = [I] -/// [imported top-level statements]``` -StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) { - // Use a clean context to parse a new file - if (ctx->cache->age) - ctx->cache->age++; - auto ictx = std::make_shared(file.path, ctx->cache); - ictx->isStdlibLoading = ctx->isStdlibLoading; - ictx->moduleName = file; - auto import = ctx->cache->imports.insert({file.path, {file.path, ictx}}).first; - import->second.loadedAtToplevel = - ctx->cache->imports[ctx->moduleName.path].loadedAtToplevel && - (ctx->isStdlibLoading || (ctx->isGlobal() && ctx->scope.blocks.size() == 1)); - auto importVar = import->second.importVar = - ctx->cache->getTemporaryVar(format("import_{}", file.module)); - import->second.moduleName = file.module; - LOG_TYPECHECK("[import] initializing {} ({})", importVar, - import->second.loadedAtToplevel); - - // __name__ = [import name] - StmtPtr n = nullptr; - if (ictx->moduleName.module != "internal.core") { - // str is not defined when loading internal.core; __name__ is not needed anyway - n = N(N("__name__"), N(ictx->moduleName.module)); - preamble->push_back(N( - N(importVar), - N(N("Import.__new__"), N(file.module), - N(file.path), N(false)), - N("Import"))); - auto var = ctx->addAlwaysVisible( - std::make_shared(SimplifyItem::Var, ctx->getBaseName(), importVar, - ctx->getModule(), std::vector{0})); - ctx->cache->addGlobal(importVar); - } - n = N(n, parseFile(ctx->cache, file.path)); - n = SimplifyVisitor(ictx, preamble).transform(n); - if (!ctx->cache->errors.empty()) - throw exc::ParserException(); - // Add comment to the top of import for easier dump inspection - auto comment = N(format("import: {} at {}", file.module, file.path)); - if (ctx->isStdlibLoading) { - // When loading the standard library, imports are not wrapped. - // We assume that the standard library has no recursive imports and that all - // statements are executed before the user-provided code. - return N(comment, n); - } else { - // Wrap all imported top-level statements into a function. - // Make sure to register the global variables and set their assignments as - // updates. Note: signatures/classes/functions are not wrapped - std::vector stmts; - stmts.push_back( - N(N(N(importVar), "loaded"), N())); - stmts.push_back(N( - N(N("Import._set_loaded"), - N(N("__ptr__"), N(importVar))))); - auto processToplevelStmt = [&](const StmtPtr &s) { - // Process toplevel statement - if (auto a = s->getAssign()) { - if (!a->isUpdate() && a->lhs->getId()) { - // Global `a = ...` - auto val = ictx->forceFind(a->lhs->getId()->value); - if (val->isVar() && val->isGlobal()) - ctx->cache->addGlobal(val->canonicalName); - } - } - stmts.push_back(s); - }; - processToplevelStmt(comment); - if (auto st = n->getSuite()) { - for (auto &ss : st->stmts) - if (ss) - processToplevelStmt(ss); - } else { - processToplevelStmt(n); - } - - // Create import function manually with ForceRealize - ctx->cache->functions[importVar + ":0"].ast = - N(importVar + ":0", nullptr, std::vector{}, - N(stmts), Attr({Attr::ForceRealize})); - preamble->push_back(ctx->cache->functions[importVar + ":0"].ast->clone()); - ctx->cache->overloads[importVar].push_back({importVar + ":0", ctx->cache->age}); - } - return nullptr; -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/loops.cpp b/codon/parser/visitors/simplify/loops.cpp deleted file mode 100644 index e4ebea44..00000000 --- a/codon/parser/visitors/simplify/loops.cpp +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -/// Ensure that `continue` is in a loop -void SimplifyVisitor::visit(ContinueStmt *stmt) { - if (!ctx->getBase()->getLoop()) - E(Error::EXPECTED_LOOP, stmt, "continue"); - ctx->getBase()->getLoop()->flat = false; -} - -/// Ensure that `break` is in a loop. -/// Transform if a loop break variable is available -/// (e.g., a break within loop-else block). -/// @example -/// `break` -> `no_break = False; break` -void SimplifyVisitor::visit(BreakStmt *stmt) { - if (!ctx->getBase()->getLoop()) - E(Error::EXPECTED_LOOP, stmt, "break"); - ctx->getBase()->getLoop()->flat = false; - if (!ctx->getBase()->getLoop()->breakVar.empty()) { - resultStmt = N( - transform(N(N(ctx->getBase()->getLoop()->breakVar), - N(false))), - N()); - } -} - -/// Transform a while loop. -/// @example -/// `while cond: ...` -> `while cond.__bool__(): ...` -/// `while cond: ... else: ...` -> ```no_break = True -/// while cond.__bool__(): -/// ... -/// if no_break: ...``` -void SimplifyVisitor::visit(WhileStmt *stmt) { - // Check for while-else clause - std::string breakVar; - if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { - // no_break = True - breakVar = ctx->cache->getTemporaryVar("no_break"); - prependStmts->push_back( - transform(N(N(breakVar), N(true)))); - } - - ctx->enterConditionalBlock(); - ctx->getBase()->loops.push_back({breakVar, ctx->scope.blocks, {}}); - stmt->cond = transform(N(N(stmt->cond, "__bool__"))); - transformConditionalScope(stmt->suite); - - // Complete while-else clause - if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { - resultStmt = N(N(*stmt), - N(transform(N(breakVar)), - transformConditionalScope(stmt->elseSuite))); - } - - ctx->leaveConditionalBlock(); - // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) { - ctx->findDominatingBinding(var); - } - ctx->getBase()->loops.pop_back(); -} - -/// Transform for loop. -/// @example -/// `for i, j in it: ...` -> ```for tmp in it: -/// i, j = tmp -/// ...``` -/// `for i in it: ... else: ...` -> ```no_break = True -/// for i in it: ... -/// if no_break: ...``` -void SimplifyVisitor::visit(ForStmt *stmt) { - stmt->decorator = transformForDecorator(stmt->decorator); - - std::string breakVar; - // Needs in-advance transformation to prevent name clashes with the iterator variable - stmt->iter = transform(stmt->iter); - - // Check for for-else clause - StmtPtr assign = nullptr; - if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { - breakVar = ctx->cache->getTemporaryVar("no_break"); - assign = transform(N(N(breakVar), N(true))); - } - - ctx->enterConditionalBlock(); - ctx->getBase()->loops.push_back({breakVar, ctx->scope.blocks, {}}); - std::string varName; - if (auto i = stmt->var->getId()) { - auto val = ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value), - stmt->var->getSrcInfo()); - val->avoidDomination = ctx->avoidDomination; - transform(stmt->var); - stmt->suite = transform(N(stmt->suite)); - } else { - varName = ctx->cache->getTemporaryVar("for"); - auto val = ctx->addVar(varName, varName, stmt->var->getSrcInfo()); - auto var = N(varName); - std::vector stmts; - // Add for_var = [for variables] - stmts.push_back(N(stmt->var, clone(var))); - stmt->var = var; - stmts.push_back(stmt->suite); - stmt->suite = transform(N(stmts)); - } - - if (ctx->getBase()->getLoop()->flat) - stmt->flat = true; - // Complete while-else clause - if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { - resultStmt = N(assign, N(*stmt), - N(transform(N(breakVar)), - transformConditionalScope(stmt->elseSuite))); - } - - ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); - // Dominate loop variables - for (auto &var : ctx->getBase()->getLoop()->seenVars) { - ctx->findDominatingBinding(var); - } - ctx->getBase()->loops.pop_back(); -} - -/// Transform and check for OpenMP decorator. -/// @example -/// `@par(num_threads=2, openmp="schedule(static)")` -> -/// `for_par(num_threads=2, schedule="static")` -ExprPtr SimplifyVisitor::transformForDecorator(const ExprPtr &decorator) { - if (!decorator) - return nullptr; - ExprPtr callee = decorator; - if (auto c = callee->getCall()) - callee = c->expr; - if (!callee || !callee->isId("par")) - E(Error::LOOP_DECORATOR, decorator); - std::vector args; - std::string openmp; - std::vector omp; - if (auto c = decorator->getCall()) - for (auto &a : c->args) { - if (a.name == "openmp" || - (a.name.empty() && openmp.empty() && a.value->getString())) { - omp = parseOpenMP(ctx->cache, a.value->getString()->getValue(), - a.value->getSrcInfo()); - } else { - args.push_back({a.name, transform(a.value)}); - } - } - for (auto &a : omp) - args.push_back({a.name, transform(a.value)}); - return N(transform(N("for_par")), args); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/op.cpp b/codon/parser/visitors/simplify/op.cpp deleted file mode 100644 index a7e70b40..00000000 --- a/codon/parser/visitors/simplify/op.cpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/cache.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" - -using fmt::format; -using namespace codon::error; - -namespace codon::ast { - -void SimplifyVisitor::visit(UnaryExpr *expr) { transform(expr->expr); } - -/// Transform binary expressions with a few special considerations. -/// The real stuff happens during the type checking. -void SimplifyVisitor::visit(BinaryExpr *expr) { - // Special case: `is` can take type as well - transform(expr->lexpr, startswith(expr->op, "is")); - auto tmp = ctx->isConditionalExpr; - // The second operand of the and/or expression is conditional - ctx->isConditionalExpr = expr->op == "&&" || expr->op == "||"; - transform(expr->rexpr, startswith(expr->op, "is")); - ctx->isConditionalExpr = tmp; -} - -/// Transform chain binary expression. -/// @example -/// `a <= b <= c` -> `(a <= (chain := b)) and (chain <= c)` -/// The assignment above ensures that all expressions are executed only once. -void SimplifyVisitor::visit(ChainBinaryExpr *expr) { - seqassert(expr->exprs.size() >= 2, "not enough expressions in ChainBinaryExpr"); - std::vector items; - std::string prev; - for (int i = 1; i < expr->exprs.size(); i++) { - auto l = prev.empty() ? clone(expr->exprs[i - 1].second) : N(prev); - prev = ctx->generateCanonicalName("chain"); - auto r = - (i + 1 == expr->exprs.size()) - ? clone(expr->exprs[i].second) - : N(N(N(prev), clone(expr->exprs[i].second)), - N(prev)); - items.emplace_back(N(l, expr->exprs[i].first, r)); - } - - ExprPtr final = items.back(); - for (auto i = items.size() - 1; i-- > 0;) - final = N(items[i], "&&", final); - resultExpr = transform(final); -} - -/// Transform index into an instantiation @c InstantiateExpr if possible. -/// Generate tuple class `Tuple` for `Tuple[T1, ... TN]` (and `tuple[...]`). -/// The rest is handled during the type checking. -void SimplifyVisitor::visit(IndexExpr *expr) { - if (expr->expr->isId("tuple") || expr->expr->isId(TYPE_TUPLE)) { - auto t = expr->index->getTuple(); - expr->expr = NT(TYPE_TUPLE); - } else if (expr->expr->isId("Static")) { - // Special case: static types. Ensure that static is supported - if (!expr->index->isId("int") && !expr->index->isId("str")) - E(Error::BAD_STATIC_TYPE, expr->index); - expr->markType(); - return; - } else { - transform(expr->expr, true); - } - - // IndexExpr[i1, ..., iN] is internally represented as - // IndexExpr[TupleExpr[i1, ..., iN]] for N > 1 - std::vector items; - bool isTuple = expr->index->getTuple(); - if (auto t = expr->index->getTuple()) { - items = t->items; - } else { - items.push_back(expr->index); - } - for (auto &i : items) { - if (i->getList() && expr->expr->isType()) { - // Special case: `A[[A, B], C]` -> `A[Tuple[A, B], C]` (e.g., in - // `Function[...]`) - i = N(N(TYPE_TUPLE), N(i->getList()->items)); - } - transform(i, true); - } - if (expr->expr->isType()) { - resultExpr = N(expr->expr, items); - resultExpr->markType(); - } else { - expr->index = (!isTuple && items.size() == 1) ? items[0] : N(items); - } -} - -/// Already transformed. Sometimes needed again -/// for identifier analysis. -void SimplifyVisitor::visit(InstantiateExpr *expr) { - transformType(expr->typeExpr); - for (auto &tp : expr->typeParams) - transform(tp, true); -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp deleted file mode 100644 index e292f779..00000000 --- a/codon/parser/visitors/simplify/simplify.cpp +++ /dev/null @@ -1,285 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#include "simplify.h" - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/peg/peg.h" -#include "codon/parser/visitors/simplify/ctx.h" - -using fmt::format; -using namespace codon::error; -namespace codon::ast { - -using namespace types; - -/// Simplify an AST node. Load standard library if needed. -/// @param cache Pointer to the shared cache ( @c Cache ) -/// @param file Filename to be used for error reporting -/// @param barebones Use the bare-bones standard library for faster testing -/// @param defines User-defined static values (typically passed as `codon run -DX=Y`). -/// Each value is passed as a string. -StmtPtr -SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &file, - const std::unordered_map &defines, - const std::unordered_map &earlyDefines, - bool barebones) { - auto preamble = std::make_shared>(); - seqassertn(cache->module, "cache's module is not set"); - -#define N std::make_shared - // Load standard library if it has not been loaded - if (!in(cache->imports, STDLIB_IMPORT)) { - // Load the internal.__init__ - auto stdlib = std::make_shared(STDLIB_IMPORT, cache); - auto stdlibPath = - getImportFile(cache->argv0, STDLIB_INTERNAL_MODULE, "", true, cache->module0); - const std::string initFile = "__init__.codon"; - if (!stdlibPath || !endswith(stdlibPath->path, initFile)) - E(Error::COMPILER_NO_STDLIB); - - /// Use __init_test__ for faster testing (e.g., #%% name,barebones) - /// TODO: get rid of it one day... - if (barebones) { - stdlibPath->path = - stdlibPath->path.substr(0, stdlibPath->path.size() - initFile.size()) + - "__init_test__.codon"; - } - stdlib->setFilename(stdlibPath->path); - cache->imports[STDLIB_IMPORT] = {stdlibPath->path, stdlib}; - stdlib->isStdlibLoading = true; - stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"}; - // Load the standard library - stdlib->setFilename(stdlibPath->path); - // Core definitions - preamble->push_back(SimplifyVisitor(stdlib, preamble) - .transform(parseCode(stdlib->cache, stdlibPath->path, - "from internal.core import *"))); - for (auto &d : earlyDefines) { - // Load early compile-time defines (for standard library) - preamble->push_back( - SimplifyVisitor(stdlib, preamble) - .transform( - N(N(d.first), N(d.second), - N(N("Static"), N("int"))))); - } - preamble->push_back(SimplifyVisitor(stdlib, preamble) - .transform(parseFile(stdlib->cache, stdlibPath->path))); - stdlib->isStdlibLoading = false; - - // The whole standard library has the age of zero to allow back-references - cache->age++; - } - - // Set up the context and the cache - auto ctx = std::make_shared(file, cache); - cache->imports[file].filename = file; - cache->imports[file].ctx = ctx; - cache->imports[MAIN_IMPORT] = {file, ctx}; - ctx->setFilename(file); - ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN}; - - // Prepare the code - auto suite = N(); - suite->stmts.push_back(N(".toplevel", std::vector{}, nullptr, - std::vector{N(Attr::Internal)})); - for (auto &d : defines) { - // Load compile-time defines (e.g., codon run -DFOO=1 ...) - suite->stmts.push_back( - N(N(d.first), N(d.second), - N(N("Static"), N("int")))); - } - // Set up __name__ - suite->stmts.push_back( - N(N("__name__"), N(MODULE_MAIN))); - suite->stmts.push_back(node); - auto n = SimplifyVisitor(ctx, preamble).transform(suite); - - suite = N(); - suite->stmts.push_back(N(*preamble)); - // Add dominated assignment declarations - if (in(ctx->scope.stmts, ctx->scope.blocks.back())) - suite->stmts.insert(suite->stmts.end(), - ctx->scope.stmts[ctx->scope.blocks.back()].begin(), - ctx->scope.stmts[ctx->scope.blocks.back()].end()); - suite->stmts.push_back(n); -#undef N - - if (!ctx->cache->errors.empty()) - throw exc::ParserException(); - - return suite; -} - -/// Simplify an AST node. Assumes that the standard library is loaded. -StmtPtr SimplifyVisitor::apply(const std::shared_ptr &ctx, - const StmtPtr &node, const std::string &file, - int atAge) { - std::vector stmts; - int oldAge = ctx->cache->age; - if (atAge != -1) - ctx->cache->age = atAge; - auto preamble = std::make_shared>(); - stmts.emplace_back(SimplifyVisitor(ctx, preamble).transform(node)); - if (!ctx->cache->errors.empty()) - throw exc::ParserException(); - - if (atAge != -1) - ctx->cache->age = oldAge; - auto suite = std::make_shared(); - for (auto &s : *preamble) - suite->stmts.push_back(s); - for (auto &s : stmts) - suite->stmts.push_back(s); - return suite; -} - -/**************************************************************************************/ - -SimplifyVisitor::SimplifyVisitor(std::shared_ptr ctx, - std::shared_ptr> preamble, - const std::shared_ptr> &stmts) - : ctx(std::move(ctx)), preamble(std::move(preamble)) { - prependStmts = stmts ? stmts : std::make_shared>(); -} - -/**************************************************************************************/ - -ExprPtr SimplifyVisitor::transform(ExprPtr &expr) { return transform(expr, false); } - -/// Transform an expression node. -/// @throw @c ParserException if a node is a type and @param allowTypes is not set -/// (use @c transformType instead). -ExprPtr SimplifyVisitor::transform(ExprPtr &expr, bool allowTypes) { - if (!expr) - return nullptr; - SimplifyVisitor v(ctx, preamble); - v.prependStmts = prependStmts; - v.setSrcInfo(expr->getSrcInfo()); - ctx->pushSrcInfo(expr->getSrcInfo()); - expr->accept(v); - ctx->popSrcInfo(); - if (v.resultExpr) { - v.resultExpr->attributes |= expr->attributes; - expr = v.resultExpr; - } - if (!allowTypes && expr && expr->isType()) - E(Error::UNEXPECTED_TYPE, expr, "type"); - return expr; -} - -/// Transform a type expression node. -/// @param allowTypeOf Set if `type()` expressions are allowed. Usually disallowed in -/// class/function definitions. -/// @throw @c ParserException if a node is not a type (use @c transform instead). -ExprPtr SimplifyVisitor::transformType(ExprPtr &expr, bool allowTypeOf) { - auto oldTypeOf = ctx->allowTypeOf; - ctx->allowTypeOf = allowTypeOf; - transform(expr, true); - if (expr && expr->getNone()) - expr->markType(); - ctx->allowTypeOf = oldTypeOf; - if (expr && !expr->isType()) - E(Error::EXPECTED_TYPE, expr, "type"); - return expr; -} - -/// Transform a statement node. -StmtPtr SimplifyVisitor::transform(StmtPtr &stmt) { - if (!stmt) - return nullptr; - - SimplifyVisitor v(ctx, preamble); - v.setSrcInfo(stmt->getSrcInfo()); - ctx->pushSrcInfo(stmt->getSrcInfo()); - try { - stmt->accept(v); - } catch (const exc::ParserException &e) { - ctx->cache->errors.push_back(e); - // throw; - } - ctx->popSrcInfo(); - if (v.resultStmt) - stmt = v.resultStmt; - stmt->age = ctx->cache->age; - if (!v.prependStmts->empty()) { - // Handle prepends - if (stmt) - v.prependStmts->push_back(stmt); - stmt = N(*v.prependStmts); - stmt->age = ctx->cache->age; - } - return stmt; -} - -/// Transform a statement in conditional scope. -/// Because variables and forward declarations within conditional scopes can be -/// added later after the domination analysis, ensure that all such declarations -/// are prepended. -StmtPtr SimplifyVisitor::transformConditionalScope(StmtPtr &stmt) { - if (stmt) { - ctx->enterConditionalBlock(); - transform(stmt); - SuiteStmt *suite = stmt->getSuite(); - if (!suite) { - stmt = N(stmt); - suite = stmt->getSuite(); - } - ctx->leaveConditionalBlock(&suite->stmts); - return stmt; - } - return stmt = nullptr; -} - -/**************************************************************************************/ - -void SimplifyVisitor::visit(StmtExpr *expr) { - for (auto &s : expr->stmts) - transform(s); - transform(expr->expr); -} - -void SimplifyVisitor::visit(StarExpr *expr) { transform(expr->what); } - -void SimplifyVisitor::visit(KeywordStarExpr *expr) { transform(expr->what); } - -/// Only allowed in @c MatchStmt -void SimplifyVisitor::visit(RangeExpr *expr) { - E(Error::UNEXPECTED_TYPE, expr, "range"); -} - -/// Handled during the type checking -void SimplifyVisitor::visit(SliceExpr *expr) { - transform(expr->start); - transform(expr->stop); - transform(expr->step); -} - -void SimplifyVisitor::visit(SuiteStmt *stmt) { - for (auto &s : stmt->stmts) - transform(s); - resultStmt = N(stmt->stmts); // needed for flattening -} - -void SimplifyVisitor::visit(ExprStmt *stmt) { transform(stmt->expr, true); } - -void SimplifyVisitor::visit(CustomStmt *stmt) { - if (stmt->suite) { - auto fn = ctx->cache->customBlockStmts.find(stmt->keyword); - seqassert(fn != ctx->cache->customBlockStmts.end(), "unknown keyword {}", - stmt->keyword); - resultStmt = fn->second.second(this, stmt); - } else { - auto fn = ctx->cache->customExprStmts.find(stmt->keyword); - seqassert(fn != ctx->cache->customExprStmts.end(), "unknown keyword {}", - stmt->keyword); - resultStmt = fn->second(this, stmt); - } -} - -} // namespace codon::ast diff --git a/codon/parser/visitors/simplify/simplify.h b/codon/parser/visitors/simplify/simplify.h deleted file mode 100644 index c9c9e815..00000000 --- a/codon/parser/visitors/simplify/simplify.h +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (C) 2022-2025 Exaloop Inc. - -#pragma once - -#include -#include -#include -#include - -#include "codon/parser/ast.h" -#include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/ctx.h" -#include "codon/parser/visitors/visitor.h" - -namespace codon::ast { - -/** - * Visitor that implements the initial AST simplification transformation. - * In this stage. the following steps are done: - * - All imports are flattened resulting in a single self-containing - * (and fairly large) AST - * - All identifiers are normalized (no two distinct objects share the same name) - * - Variadic classes (e.g., Tuple) are generated - * - Any AST node that can be trivially expressed as a set of "simpler" nodes - * type is simplified. If a transformation requires a type information, - * it is done during the type checking. - * - * -> Note: this stage *modifies* the provided AST. Clone it before simplification - * if you need it intact. - */ -class SimplifyVisitor : public CallbackASTVisitor { - /// Shared simplification context. - std::shared_ptr ctx; - /// Preamble contains definition statements shared across all visitors - /// in all modules. It is executed before simplified statements. - std::shared_ptr> preamble; - /// Statements to prepend before the current statement. - std::shared_ptr> prependStmts; - - /// Each new expression is stored here (as @c visit does not return anything) and - /// later returned by a @c transform call. - ExprPtr resultExpr; - /// Each new statement is stored here (as @c visit does not return anything) and - /// later returned by a @c transform call. - StmtPtr resultStmt; - -public: - static StmtPtr - apply(Cache *cache, const StmtPtr &node, const std::string &file, - const std::unordered_map &defines = {}, - const std::unordered_map &earlyDefines = {}, - bool barebones = false); - static StmtPtr apply(const std::shared_ptr &cache, - const StmtPtr &node, const std::string &file, int atAge = -1); - -public: - explicit SimplifyVisitor( - std::shared_ptr ctx, - std::shared_ptr> preamble, - const std::shared_ptr> &stmts = nullptr); - -public: // Convenience transformators - ExprPtr transform(ExprPtr &expr) override; - ExprPtr transform(const ExprPtr &expr) override { - auto e = expr; - return transform(e); - } - ExprPtr transform(ExprPtr &expr, bool allowTypes); - ExprPtr transform(ExprPtr &&expr, bool allowTypes) { - return transform(expr, allowTypes); - } - ExprPtr transformType(ExprPtr &expr, bool allowTypeOf = true); - ExprPtr transformType(ExprPtr &&expr, bool allowTypeOf = true) { - return transformType(expr, allowTypeOf); - } - StmtPtr transform(StmtPtr &stmt) override; - StmtPtr transform(const StmtPtr &stmt) override { - auto s = stmt; - return transform(s); - } - StmtPtr transformConditionalScope(StmtPtr &stmt); - -private: // Node simplification rules - /* Basic type expressions (basic.cpp) */ - void visit(IntExpr *) override; - ExprPtr transformInt(IntExpr *); - void visit(FloatExpr *) override; - ExprPtr transformFloat(FloatExpr *); - void visit(StringExpr *) override; - ExprPtr transformFString(const std::string &); - - /* Identifier access expressions (access.cpp) */ - void visit(IdExpr *) override; - bool checkCapture(const SimplifyContext::Item &); - void visit(DotExpr *) override; - std::pair getImport(const std::vector &); - - /* Collection and comprehension expressions (collections.cpp) */ - void visit(TupleExpr *) override; - void visit(ListExpr *) override; - void visit(SetExpr *) override; - void visit(DictExpr *) override; - void visit(GeneratorExpr *) override; - void visit(DictGeneratorExpr *) override; - StmtPtr transformGeneratorBody(const std::vector &, SuiteStmt *&); - - /* Conditional expression and statements (cond.cpp) */ - void visit(IfExpr *) override; - void visit(IfStmt *) override; - void visit(MatchStmt *) override; - StmtPtr transformPattern(const ExprPtr &, ExprPtr, StmtPtr); - - /* Operators (op.cpp) */ - void visit(UnaryExpr *) override; - void visit(BinaryExpr *) override; - void visit(ChainBinaryExpr *) override; - void visit(IndexExpr *) override; - void visit(InstantiateExpr *) override; - - /* Calls (call.cpp) */ - void visit(PrintStmt *) override; - void visit(CallExpr *) override; - ExprPtr transformSpecialCall(const ExprPtr &, const std::vector &); - ExprPtr transformTupleGenerator(const std::vector &); - ExprPtr transformNamedTuple(const std::vector &); - ExprPtr transformFunctoolsPartial(std::vector); - - /* Assignments (assign.cpp) */ - void visit(AssignExpr *) override; - void visit(AssignStmt *) override; - StmtPtr transformAssignment(ExprPtr, ExprPtr, ExprPtr = nullptr, bool = false); - void unpackAssignments(const ExprPtr &, ExprPtr, std::vector &); - void visit(DelStmt *) override; - - /* Imports (import.cpp) */ - void visit(ImportStmt *) override; - StmtPtr transformSpecialImport(ImportStmt *); - std::vector getImportPath(Expr *, size_t = 0); - StmtPtr transformCImport(const std::string &, const std::vector &, - const Expr *, const std::string &); - StmtPtr transformCVarImport(const std::string &, const Expr *, const std::string &); - StmtPtr transformCDLLImport(const Expr *, const std::string &, - const std::vector &, const Expr *, - const std::string &, bool); - StmtPtr transformPythonImport(Expr *, const std::vector &, Expr *, - const std::string &); - StmtPtr transformNewImport(const ImportFile &); - - /* Loops (loops.cpp) */ - void visit(ContinueStmt *) override; - void visit(BreakStmt *) override; - void visit(WhileStmt *) override; - void visit(ForStmt *) override; - ExprPtr transformForDecorator(const ExprPtr &); - - /* Errors and exceptions (error.cpp) */ - void visit(AssertStmt *) override; - void visit(TryStmt *) override; - void visit(ThrowStmt *) override; - void visit(WithStmt *) override; - - /* Functions (function.cpp) */ - void visit(YieldExpr *) override; - void visit(LambdaExpr *) override; - void visit(GlobalStmt *) override; - void visit(ReturnStmt *) override; - void visit(YieldStmt *) override; - void visit(YieldFromStmt *) override; - void visit(FunctionStmt *) override; - ExprPtr makeAnonFn(std::vector, const std::vector & = {}); - StmtPtr transformPythonDefinition(const std::string &, const std::vector &, - const Expr *, Stmt *); - StmtPtr transformLLVMDefinition(Stmt *); - std::pair getDecorator(const ExprPtr &); - - /* Classes (class.cpp) */ - void visit(ClassStmt *) override; - std::vector parseBaseClasses(std::vector &, - std::vector &, const Attr &, - const std::string &, - const ExprPtr & = nullptr); - std::pair autoDeduceMembers(ClassStmt *, - std::vector &); - std::vector getClassMethods(const StmtPtr &s); - void transformNestedClasses(ClassStmt *, std::vector &, - std::vector &, std::vector &); - StmtPtr codegenMagic(const std::string &, const ExprPtr &, const std::vector &, - bool); - - /* The rest (simplify.cpp) */ - void visit(StmtExpr *) override; - void visit(StarExpr *) override; - void visit(KeywordStarExpr *expr) override; - void visit(RangeExpr *) override; - void visit(SliceExpr *) override; - void visit(SuiteStmt *) override; - void visit(ExprStmt *) override; - void visit(CustomStmt *) override; -}; - -} // namespace codon::ast diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 9db399e5..c1eac139 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -12,6 +12,7 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/translate/translate_ctx.h" +#include "codon/parser/visitors/typecheck/typecheck.h" using codon::ir::cast; using codon::ir::transform::parallel::OMPSched; @@ -22,7 +23,7 @@ namespace codon::ast { TranslateVisitor::TranslateVisitor(std::shared_ptr ctx) : ctx(std::move(ctx)), result(nullptr) {} -ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) { +ir::Func *TranslateVisitor::apply(Cache *cache, Stmt *stmts) { ir::BodiedFunc *main = nullptr; if (cache->isJit) { auto fnName = format("_jit_{}", cache->jitCell); @@ -47,99 +48,102 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) { cache->codegenCtx->bases = {main}; cache->codegenCtx->series = {block}; - for (auto &[name, p] : cache->globals) - if (p.first && !p.second) { - p.second = name == VAR_ARGV ? cache->codegenCtx->getModule()->getArgVar() - : cache->codegenCtx->getModule()->N( - SrcInfo(), nullptr, true, false, name); - cache->codegenCtx->add(TranslateItem::Var, name, p.second); - } - - auto tv = TranslateVisitor(cache->codegenCtx); - tv.transform(stmts); - for (auto &[fn, f] : cache->functions) - if (startswith(fn, TYPE_TUPLE)) { - tv.transformFunctionRealizations(fn, f.ast->attributes.has(Attr::LLVM)); - } + TranslateVisitor(cache->codegenCtx).translateStmts(stmts); cache->populatePythonModule(); return main; } +void TranslateVisitor::translateStmts(Stmt *stmts) { + for (auto &[name, g] : ctx->cache->globals) + if (/*g.first &&*/ !g.second) { + ir::types::Type *vt = nullptr; + if (auto t = ctx->cache->typeCtx->forceFind(name)->getType()) + vt = getType(t); + g.second = name == VAR_ARGV ? ctx->cache->codegenCtx->getModule()->getArgVar() + : ctx->cache->codegenCtx->getModule()->N( + SrcInfo(), vt, true, false, name); + ctx->cache->codegenCtx->add(TranslateItem::Var, name, g.second); + } + TranslateVisitor(ctx->cache->codegenCtx).transform(stmts); + for (auto &[_, f] : ctx->cache->functions) + TranslateVisitor(ctx->cache->codegenCtx).transform(f.ast); +} + /************************************************************************************/ -ir::Value *TranslateVisitor::transform(const ExprPtr &expr) { +ir::Value *TranslateVisitor::transform(Expr *expr) { TranslateVisitor v(ctx); v.setSrcInfo(expr->getSrcInfo()); - types::PartialType *p = nullptr; - if (expr->attributes) { - if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set) || - expr->hasAttr(ExprAttr::Dict) || expr->hasAttr(ExprAttr::Partial)) { - ctx->seqItems.emplace_back(); - } - if (expr->hasAttr(ExprAttr::Partial)) - p = expr->type->getPartial().get(); + types::ClassType *p = nullptr; + if (expr->hasAttribute(Attr::ExprList) || expr->hasAttribute(Attr::ExprSet) || + expr->hasAttribute(Attr::ExprDict) || expr->hasAttribute(Attr::ExprPartial)) { + ctx->seqItems.emplace_back(); + } + if (expr->hasAttribute(Attr::ExprPartial)) { + p = expr->getType()->getPartial(); } expr->accept(v); ir::Value *ir = v.result; - if (expr->attributes) { - if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set)) { - std::vector v; - for (auto &p : ctx->seqItems.back()) { - seqassert(p.first <= ExprAttr::StarSequenceItem, "invalid list/set element"); - v.push_back( - ir::LiteralElement{p.second, p.first == ExprAttr::StarSequenceItem}); - } - if (expr->hasAttr(ExprAttr::List)) - ir->setAttribute(std::make_unique(v)); - else - ir->setAttribute(std::make_unique(v)); - ctx->seqItems.pop_back(); + if (expr->hasAttribute(Attr::ExprList) || expr->hasAttribute(Attr::ExprSet)) { + std::vector v; + for (auto &p : ctx->seqItems.back()) { + seqassert(p.first == Attr::ExprSequenceItem || + p.first == Attr::ExprStarSequenceItem, + "invalid list/set element"); + v.push_back(ir::LiteralElement{p.second, p.first == Attr::ExprStarSequenceItem}); } - if (expr->hasAttr(ExprAttr::Dict)) { - std::vector v; - for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) { - auto &p = ctx->seqItems.back()[pi]; - if (p.first == ExprAttr::StarSequenceItem) { - v.push_back({p.second, nullptr}); - } else { - seqassert(p.first == ExprAttr::SequenceItem && - pi + 1 < ctx->seqItems.back().size() && - ctx->seqItems.back()[pi + 1].first == ExprAttr::SequenceItem, - "invalid dict element"); - v.push_back({p.second, ctx->seqItems.back()[pi + 1].second}); - pi++; - } + if (expr->hasAttribute(Attr::ExprList)) + ir->setAttribute(std::make_unique(v)); + else + ir->setAttribute(std::make_unique(v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttribute(Attr::ExprDict)) { + std::vector v; + for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) { + auto &p = ctx->seqItems.back()[pi]; + if (p.first == Attr::ExprStarSequenceItem) { + v.push_back({p.second, nullptr}); + } else { + seqassert(p.first == Attr::ExprSequenceItem && + pi + 1 < ctx->seqItems.back().size() && + ctx->seqItems.back()[pi + 1].first == Attr::ExprSequenceItem, + "invalid dict element"); + v.push_back({p.second, ctx->seqItems.back()[pi + 1].second}); + pi++; } - ir->setAttribute(std::make_unique(v)); - ctx->seqItems.pop_back(); } - if (expr->hasAttr(ExprAttr::Partial)) { - std::vector v; - seqassert(p, "invalid partial element"); - int j = 0; - for (int i = 0; i < p->known.size(); i++) { - if (p->known[i] && p->func->ast->args[i].status == Param::Normal) { - seqassert(j < ctx->seqItems.back().size() && - ctx->seqItems.back()[j].first == ExprAttr::SequenceItem, - "invalid partial element"); - v.push_back(ctx->seqItems.back()[j++].second); - } else if (p->func->ast->args[i].status == Param::Normal) { - v.push_back({nullptr}); - } + ir->setAttribute(std::make_unique(v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttribute(Attr::ExprPartial)) { + std::vector v; + seqassert(p, "invalid partial element"); + int j = 0; + auto known = p->getPartialMask(); + auto func = p->getPartialFunc(); + for (int i = 0; i < known.size(); i++) { + if (known[i] && (*func->ast)[i].isValue()) { + seqassert(j < ctx->seqItems.back().size() && + ctx->seqItems.back()[j].first == Attr::ExprSequenceItem, + "invalid partial element: {}"); + v.push_back(ctx->seqItems.back()[j++].second); + } else if ((*func->ast)[i].isValue()) { + v.push_back({nullptr}); } - ir->setAttribute( - std::make_unique(p->func->ast->name, v)); - ctx->seqItems.pop_back(); - } - if (expr->hasAttr(ExprAttr::SequenceItem)) { - ctx->seqItems.back().push_back({ExprAttr::SequenceItem, ir}); - } - if (expr->hasAttr(ExprAttr::StarSequenceItem)) { - ctx->seqItems.back().push_back({ExprAttr::StarSequenceItem, ir}); } + ir->setAttribute( + std::make_unique(func->ast->getName(), v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttribute(Attr::ExprSequenceItem)) { + ctx->seqItems.back().emplace_back(Attr::ExprSequenceItem, ir); + } + if (expr->hasAttribute(Attr::ExprStarSequenceItem)) { + ctx->seqItems.back().emplace_back(Attr::ExprStarSequenceItem, ir); } return ir; @@ -150,7 +154,7 @@ void TranslateVisitor::defaultVisit(Expr *n) { } void TranslateVisitor::visit(NoneExpr *expr) { - auto f = expr->type->realizedName() + ":Optional.__new__:0"; + auto f = expr->getType()->realizedName() + ":Optional.__new__:0"; auto val = ctx->find(f); seqassert(val, "cannot find '{}'", f); result = make(expr, make(expr, val->getFunc()), @@ -158,15 +162,15 @@ void TranslateVisitor::visit(NoneExpr *expr) { } void TranslateVisitor::visit(BoolExpr *expr) { - result = make(expr, expr->value, getType(expr->getType())); + result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(IntExpr *expr) { - result = make(expr, *(expr->intValue), getType(expr->getType())); + result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(FloatExpr *expr) { - result = make(expr, *(expr->floatValue), getType(expr->getType())); + result = make(expr, expr->getValue(), getType(expr->getType())); } void TranslateVisitor::visit(StringExpr *expr) { @@ -174,9 +178,9 @@ void TranslateVisitor::visit(StringExpr *expr) { } void TranslateVisitor::visit(IdExpr *expr) { - auto val = ctx->find(expr->value); - seqassert(val, "cannot find '{}'", expr->value); - if (expr->value == "__vtable_size__") { + auto val = ctx->find(expr->getValue()); + seqassert(val, "cannot find '{}'", expr->getValue()); + if (expr->getValue() == "__vtable_size__.0") { // LOG("[] __vtable_size__={}", ctx->cache->classRealizationCnt + 2); result = make(expr, ctx->cache->classRealizationCnt + 2, getType(expr->getType())); @@ -188,74 +192,142 @@ void TranslateVisitor::visit(IdExpr *expr) { } void TranslateVisitor::visit(IfExpr *expr) { - auto cond = transform(expr->cond); - auto ifexpr = transform(expr->ifexpr); - auto elsexpr = transform(expr->elsexpr); + auto cond = transform(expr->getCond()); + auto ifexpr = transform(expr->getIf()); + auto elsexpr = transform(expr->getElse()); result = make(expr, cond, ifexpr, elsexpr); } +// Search expression tree for a identifier +class IdVisitor : public CallbackASTVisitor { +public: + std::unordered_set ids; + + bool transform(Expr *expr) override { + IdVisitor v; + if (expr) + expr->accept(v); + ids.insert(v.ids.begin(), v.ids.end()); + return true; + } + bool transform(Stmt *stmt) override { + IdVisitor v; + if (stmt) + stmt->accept(v); + ids.insert(v.ids.begin(), v.ids.end()); + return true; + } + void visit(IdExpr *expr) override { ids.insert(expr->getValue()); } +}; + +void TranslateVisitor::visit(GeneratorExpr *expr) { + auto name = ctx->cache->imports[MAIN_IMPORT].ctx->generateCanonicalName("_generator"); + ir::Func *fn = ctx->cache->module->Nr(name); + fn->setGlobal(); + fn->setGenerator(); + std::vector names; + std::vector types; + std::vector items; + + IdVisitor v; + expr->accept(v); + for (auto &i : v.ids) { + auto val = ctx->find(i); + if (val && !val->getFunc() && !val->getType() && !val->getVar()->isGlobal()) { + types.push_back(val->getVar()->getType()); + names.push_back(i); + items.emplace_back(make(expr, val->getVar())); + } + } + auto irType = ctx->cache->module->unsafeGetFuncType( + name, ctx->forceFind(expr->getType()->realizedName())->getType(), types, false); + fn->realize(irType, names); + + ctx->addBlock(); + for (auto &n : names) + ctx->add(TranslateItem::Var, n, fn->getArgVar(n)); + auto body = make(expr, "body"); + ctx->bases.push_back(cast(fn)); + ctx->addSeries(body); + + expr->setFinalStmt(ctx->cache->N(expr->getFinalExpr())); + auto e = expr->getFinalSuite(); + transform(e); + ctx->popSeries(); + ctx->bases.pop_back(); + cast(fn)->setBody(body); + ctx->popBlock(); + result = make(expr, make(expr, fn), std::move(items)); +} + void TranslateVisitor::visit(CallExpr *expr) { - if (expr->expr->isId("__ptr__")) { - seqassert(expr->args[0].value->getId(), "expected IdExpr, got {}", - expr->args[0].value); - auto val = ctx->find(expr->args[0].value->getId()->value); - seqassert(val && val->getVar(), "{} is not a variable", - expr->args[0].value->getId()->value); + auto ei = cast(expr->getExpr()); + if (ei && ei->getValue() == "__ptr__:0") { + auto id = cast(expr->begin()->getExpr()); + if (!id) { + // Case where id is guarded by a check + if (auto sexp = cast(expr->begin()->getExpr())) + id = cast(sexp->getExpr()); + } + seqassert(id, "expected IdExpr, got {}", *((*expr)[0].value)); + auto key = id->getValue(); + auto val = ctx->find(key); + seqassert(val && val->getVar(), "{} is not a variable", key); result = make(expr, val->getVar()); return; - } else if (expr->expr->isId("__array__.__new__:0")) { - auto fnt = expr->expr->type->getFunc(); - auto szt = fnt->funcGenerics[0].type->getStatic(); - auto sz = szt->evaluate().getInt(); - auto typ = fnt->funcParent->getClass()->generics[0].type; + } else if (ei && ei->getValue() == "__array__.__new__:0") { + auto fnt = expr->getExpr()->getType()->getFunc(); + auto sz = fnt->funcGenerics[0].type->getIntStatic()->value; + auto typ = fnt->funcParent->getClass()->generics[0].getType(); auto *arrayType = ctx->getModule()->unsafeGetArrayType(getType(typ)); - arrayType->setAstType(expr->getType()); + arrayType->setAstType(expr->getType()->shared_from_this()); result = make(expr, arrayType, sz); return; - } else if (expr->expr->getId() && startswith(expr->expr->getId()->value, - "__internal__.yield_in_no_suspend:0")) { + } else if (ei && startswith(ei->getValue(), "__internal__.yield_in_no_suspend")) { result = make(expr, getType(expr->getType()), false); return; } - auto ft = expr->expr->type->getFunc(); - seqassert(ft, "not calling function: {}", ft); - auto callee = transform(expr->expr); - bool isVariadic = ft->ast->hasAttr(Attr::CVarArg); + auto ft = expr->getExpr()->getType()->getFunc(); + seqassert(ft, "not calling function"); + auto callee = transform(expr->getExpr()); + bool isVariadic = ft->ast->hasAttribute(Attr::CVarArg); std::vector items; - for (int i = 0; i < expr->args.size(); i++) { - seqassert(!expr->args[i].value->getEllipsis(), "ellipsis not elided"); - if (i + 1 == expr->args.size() && isVariadic) { - auto call = expr->args[i].value->getCall(); - seqassert( - call && call->expr->getId() && - startswith(call->expr->getId()->value, std::string(TYPE_TUPLE) + "["), - "expected *args tuple: '{}'", call->toString()); - for (auto &arg : call->args) + size_t i = 0; + for (auto &a : *expr) { + seqassert(!cast(a.value), "ellipsis not elided"); + if (i + 1 == expr->size() && isVariadic) { + auto call = cast(a.value); + seqassert(call, "expected *args tuple: '{}'", call->toString(0)); + for (auto &arg : *call) items.emplace_back(transform(arg.value)); } else { - items.emplace_back(transform(expr->args[i].value)); + items.emplace_back(transform(a.value)); } + i++; } result = make(expr, callee, std::move(items)); } void TranslateVisitor::visit(DotExpr *expr) { - if (expr->member == "__atomic__" || expr->member == "__elemsize__" || - expr->member == "__contents_atomic__") { - seqassert(expr->expr->getId(), "expected IdExpr, got {}", expr->expr); - auto type = ctx->find(expr->expr->getId()->value)->getType(); - seqassert(type, "{} is not a type", expr->expr->getId()->value); + if (expr->getMember() == "__atomic__" || expr->getMember() == "__elemsize__" || + expr->getMember() == "__contents_atomic__") { + auto ei = cast(expr->getExpr()); + seqassert(ei, "expected IdExpr, got {}", *(expr->getExpr())); + auto t = TypecheckVisitor(ctx->cache->typeCtx).extractType(ei->getType()); + auto type = ctx->find(t->realizedName())->getType(); + seqassert(type, "{} is not a type", ei->getValue()); result = make( expr, type, - expr->member == "__atomic__" + expr->getMember() == "__atomic__" ? ir::TypePropertyInstr::Property::IS_ATOMIC - : (expr->member == "__contents_atomic__" + : (expr->getMember() == "__contents_atomic__" ? ir::TypePropertyInstr::Property::IS_CONTENT_ATOMIC : ir::TypePropertyInstr::Property::SIZEOF)); } else { - result = make(expr, transform(expr->expr), expr->member); + result = + make(expr, transform(expr->getExpr()), expr->getMember()); } } @@ -275,24 +347,24 @@ void TranslateVisitor::visit(PipeExpr *expr) { }; std::vector stages; - auto *firstStage = transform(expr->items[0].expr); + auto *firstStage = transform((*expr)[0].expr); auto firstIsGen = isGen(firstStage); stages.emplace_back(firstStage, std::vector(), firstIsGen, false); // Pipeline without generators (just function call sugar) auto simplePipeline = !firstIsGen; - for (auto i = 1; i < expr->items.size(); i++) { - auto call = expr->items[i].expr->getCall(); - seqassert(call, "{} is not a call", expr->items[i].expr); + for (auto i = 1; i < expr->size(); i++) { + auto call = cast((*expr)[i].expr); + seqassert(call, "{} is not a call", *((*expr)[i].expr)); - auto fn = transform(call->expr); - if (i + 1 != expr->items.size()) + auto fn = transform(call->getExpr()); + if (i + 1 != expr->size()) simplePipeline &= !isGen(fn); std::vector args; - args.reserve(call->args.size()); - for (auto &a : call->args) - args.emplace_back(a.value->getEllipsis() ? nullptr : transform(a.value)); + args.reserve(call->size()); + for (auto &a : *call) + args.emplace_back(cast(a.value) ? nullptr : transform(a.value)); stages.emplace_back(fn, args, isGen(fn), false); } @@ -307,8 +379,8 @@ void TranslateVisitor::visit(PipeExpr *expr) { result = make(expr, cv.clone(stages[i].getCallee()), newArgs); } } else { - for (int i = 0; i < expr->items.size(); i++) - if (expr->items[i].op == "||>") + for (int i = 0; i < expr->size(); i++) + if ((*expr)[i].op == "||>") stages[i].setParallel(); // This is a statement in IR. ctx->getSeries()->push_back(make(expr, stages)); @@ -318,15 +390,15 @@ void TranslateVisitor::visit(PipeExpr *expr) { void TranslateVisitor::visit(StmtExpr *expr) { auto *bodySeries = make(expr, "body"); ctx->addSeries(bodySeries); - for (auto &s : expr->stmts) + for (auto &s : *expr) transform(s); ctx->popSeries(); - result = make(expr, bodySeries, transform(expr->expr)); + result = make(expr, bodySeries, transform(expr->getExpr())); } /************************************************************************************/ -ir::Value *TranslateVisitor::transform(const StmtPtr &stmt) { +ir::Value *TranslateVisitor::transform(Stmt *stmt) { TranslateVisitor v(ctx); v.setSrcInfo(stmt->getSrcInfo()); stmt->accept(v); @@ -340,7 +412,7 @@ void TranslateVisitor::defaultVisit(Stmt *n) { } void TranslateVisitor::visit(SuiteStmt *stmt) { - for (auto &s : stmt->stmts) + for (auto *s : *stmt) transform(s); } @@ -351,139 +423,151 @@ void TranslateVisitor::visit(ContinueStmt *stmt) { } void TranslateVisitor::visit(ExprStmt *stmt) { - if (stmt->expr->getCall() && - stmt->expr->getCall()->expr->isId("__internal__.yield_final:0")) { - result = make(stmt, transform(stmt->expr->getCall()->args[0].value), - true); + IdExpr *ei = nullptr; + auto ce = cast(stmt->getExpr()); + if (ce && (ei = cast(ce->getExpr())) && + ei->getValue() == "__internal__.yield_final:0") { + result = make(stmt, transform((*ce)[0].value), true); ctx->getBase()->setGenerator(); } else { - result = transform(stmt->expr); + result = transform(stmt->getExpr()); } } void TranslateVisitor::visit(AssignStmt *stmt) { - if (stmt->lhs && stmt->lhs->isId(VAR_ARGV)) + if (stmt->getLhs() && cast(stmt->getLhs()) && + cast(stmt->getLhs())->getValue() == VAR_ARGV) return; - if (stmt->isUpdate()) { - seqassert(stmt->lhs->getId(), "expected IdExpr, got {}", stmt->lhs); - auto val = ctx->find(stmt->lhs->getId()->value); - seqassert(val && val->getVar(), "{} is not a variable", stmt->lhs->getId()->value); - result = make(stmt, val->getVar(), transform(stmt->rhs)); - return; - } + auto lei = cast(stmt->getLhs()); + seqassert(lei, "expected IdExpr, got {}", *(stmt->getLhs())); + auto var = lei->getValue(); - seqassert(stmt->lhs->getId(), "expected IdExpr, got {}", stmt->lhs); - auto var = stmt->lhs->getId()->value; - if (!stmt->rhs || (!stmt->rhs->isType() && stmt->rhs->type)) { - auto isGlobal = in(ctx->cache->globals, var); - ir::Var *v = nullptr; + auto isGlobal = in(ctx->cache->globals, var); + ir::Var *v = nullptr; - // dead declaration due to static compilation - if (!stmt->rhs && !stmt->type && !stmt->lhs->type->getClass()) - return; + if (stmt->isUpdate()) { + auto val = ctx->find(lei->getValue()); + seqassert(val && val->getVar(), "{} is not a variable", lei->getValue()); + v = val->getVar(); - if (isGlobal) { - seqassert(ctx->find(var) && ctx->find(var)->getVar(), "cannot find global '{}'", - var); - v = ctx->find(var)->getVar(); + if (!v->getType()) { v->setSrcInfo(stmt->getSrcInfo()); - v->setType(getType((stmt->rhs ? stmt->rhs : stmt->lhs)->getType())); - } else { - v = make(stmt, getType((stmt->rhs ? stmt->rhs : stmt->lhs)->getType()), - false, false, var); - ctx->getBase()->push_back(v); - ctx->add(TranslateItem::Var, var, v); - } - // Check if it is a C variable - if (stmt->lhs->hasAttr(ExprAttr::ExternVar)) { - v->setExternal(); - v->setName(ctx->cache->rev(var)); - v->setGlobal(); - return; + v->setType(getType(stmt->getRhs()->getType())); } + result = make(stmt, v, transform(stmt->getRhs())); + return; + } - if (stmt->rhs) - result = make(stmt, v, transform(stmt->rhs)); + if (!stmt->getLhs()->getType()->isInstantiated() || + (stmt->getLhs()->getType()->is(TYPE_TYPE))) { + // LOG("{} {}", getSrcInfo(), stmt->toString(0)); + return; // type aliases/fn aliases etc + } + + if (isGlobal) { + seqassert(ctx->find(var) && ctx->find(var)->getVar(), "cannot find global '{}'", + var); + v = ctx->find(var)->getVar(); + v->setSrcInfo(stmt->getSrcInfo()); + v->setType(getType((stmt->getRhs() ? stmt->getRhs() : stmt->getLhs())->getType())); + } else { + v = make( + stmt, getType((stmt->getRhs() ? stmt->getRhs() : stmt->getLhs())->getType()), + false, false, var); + ctx->getBase()->push_back(v); + ctx->add(TranslateItem::Var, var, v); + } + // Check if it is a C variable + if (stmt->getLhs()->hasAttribute(Attr::ExprExternVar)) { + v->setExternal(); + v->setName(ctx->cache->rev(var)); + v->setGlobal(); + return; + } + + if (stmt->getRhs()) { + result = make(stmt, v, transform(stmt->getRhs())); } } void TranslateVisitor::visit(AssignMemberStmt *stmt) { - result = make(stmt, transform(stmt->lhs), stmt->member, - transform(stmt->rhs)); + result = make(stmt, transform(stmt->getLhs()), stmt->getMember(), + transform(stmt->getRhs())); } void TranslateVisitor::visit(ReturnStmt *stmt) { - result = make(stmt, stmt->expr ? transform(stmt->expr) : nullptr); + result = make(stmt, stmt->getExpr() ? transform(stmt->getExpr()) + : nullptr); } void TranslateVisitor::visit(YieldStmt *stmt) { - result = make(stmt, stmt->expr ? transform(stmt->expr) : nullptr); + result = make(stmt, + stmt->getExpr() ? transform(stmt->getExpr()) : nullptr); ctx->getBase()->setGenerator(); } void TranslateVisitor::visit(WhileStmt *stmt) { - auto loop = make(stmt, transform(stmt->cond), + auto loop = make(stmt, transform(stmt->getCond()), make(stmt, "body")); ctx->addSeries(cast(loop->getBody())); - transform(stmt->suite); + transform(stmt->getSuite()); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(ForStmt *stmt) { std::unique_ptr os = nullptr; - if (stmt->decorator) { + if (stmt->getDecorator()) { os = std::make_unique(); - auto c = stmt->decorator->getCall(); - seqassert(c, "for par is not a call: {}", stmt->decorator); - auto fc = c->expr->getType()->getFunc(); - seqassert(fc && fc->ast->name == "std.openmp.for_par:0", + auto c = cast(stmt->getDecorator()); + seqassert(c, "for par is not a call: {}", *(stmt->getDecorator())); + auto fc = c->getExpr()->getType()->getFunc(); + seqassert(fc && fc->ast->getName() == "std.openmp.for_par.0:0", "for par is not a function"); - auto schedule = - fc->funcGenerics[0].type->getStatic()->expr->staticValue.getString(); - bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt(); - auto threads = transform(c->args[0].value); - auto chunk = transform(c->args[1].value); - int64_t collapse = - fc->funcGenerics[2].type->getStatic()->expr->staticValue.getInt(); - bool gpu = fc->funcGenerics[3].type->getStatic()->expr->staticValue.getInt(); + auto schedule = fc->funcGenerics[0].type->getStrStatic()->value; + bool ordered = fc->funcGenerics[1].type->getBoolStatic()->value; + auto threads = transform((*c)[0].value); + auto chunk = transform((*c)[1].value); + auto collapse = fc->funcGenerics[2].type->getIntStatic()->value; + bool gpu = fc->funcGenerics[3].type->getBoolStatic()->value; os = std::make_unique(schedule, threads, chunk, ordered, collapse, gpu); } - seqassert(stmt->var->getId(), "expected IdExpr, got {}", stmt->var); - auto varName = stmt->var->getId()->value; + seqassert(cast(stmt->getVar()), "expected IdExpr, got {}", *(stmt->getVar())); + auto varName = cast(stmt->getVar())->getValue(); ir::Var *var = nullptr; - if (!ctx->find(varName) || !stmt->var->hasAttr(ExprAttr::Dominated)) { - var = make(stmt, getType(stmt->var->getType()), false, false, varName); + if (!ctx->find(varName) || !stmt->getVar()->hasAttribute(Attr::ExprDominated)) { + var = + make(stmt, getType(stmt->getVar()->getType()), false, false, varName); } else { var = ctx->find(varName)->getVar(); } ctx->getBase()->push_back(var); auto bodySeries = make(stmt, "body"); - auto loop = make(stmt, transform(stmt->iter), bodySeries, var); + auto loop = make(stmt, transform(stmt->getIter()), bodySeries, var); if (os) loop->setSchedule(std::move(os)); ctx->add(TranslateItem::Var, varName, var); ctx->addSeries(cast(loop->getBody())); - transform(stmt->suite); + transform(stmt->getSuite()); ctx->popSeries(); result = loop; } void TranslateVisitor::visit(IfStmt *stmt) { - auto cond = transform(stmt->cond); + auto cond = transform(stmt->getCond()); auto trueSeries = make(stmt, "ifstmt_true"); ctx->addSeries(trueSeries); - transform(stmt->ifSuite); + transform(stmt->getIf()); ctx->popSeries(); ir::SeriesFlow *falseSeries = nullptr; - if (stmt->elseSuite) { + if (stmt->getElse()) { falseSeries = make(stmt, "ifstmt_false"); ctx->addSeries(falseSeries); - transform(stmt->elseSuite); + transform(stmt->getElse()); ctx->popSeries(); } result = make(stmt, cond, trueSeries, falseSeries); @@ -492,32 +576,35 @@ void TranslateVisitor::visit(IfStmt *stmt) { void TranslateVisitor::visit(TryStmt *stmt) { auto *bodySeries = make(stmt, "body"); ctx->addSeries(bodySeries); - transform(stmt->suite); + transform(stmt->getSuite()); ctx->popSeries(); auto finallySeries = make(stmt, "finally"); - if (stmt->finally) { + if (stmt->getFinally()) { ctx->addSeries(finallySeries); - transform(stmt->finally); + transform(stmt->getFinally()); ctx->popSeries(); } auto *tc = make(stmt, bodySeries, finallySeries); - for (auto &c : stmt->catches) { + for (auto *c : *stmt) { auto *catchBody = make(stmt, "catch"); - auto *excType = c.exc ? getType(c.exc->getType()) : nullptr; + auto *excType = c->getException() + ? getType(TypecheckVisitor(ctx->cache->typeCtx) + .extractType(c->getException()->getType())) + : nullptr; ir::Var *catchVar = nullptr; - if (!c.var.empty()) { - if (!ctx->find(c.var) || !c.exc->hasAttr(ExprAttr::Dominated)) { - catchVar = make(stmt, excType, false, false, c.var); + if (!c->getVar().empty()) { + if (!ctx->find(c->getVar()) || !c->hasAttribute(Attr::ExprDominated)) { + catchVar = make(stmt, excType, false, false, c->getVar()); } else { - catchVar = ctx->find(c.var)->getVar(); + catchVar = ctx->find(c->getVar())->getVar(); } - ctx->add(TranslateItem::Var, c.var, catchVar); + ctx->add(TranslateItem::Var, c->getVar(), catchVar); ctx->getBase()->push_back(catchVar); } ctx->addSeries(catchBody); - transform(c.suite); + transform(c->getSuite()); ctx->popSeries(); tc->push_back(ir::TryCatchFlow::Catch(catchBody, excType, catchVar)); } @@ -525,12 +612,13 @@ void TranslateVisitor::visit(TryStmt *stmt) { } void TranslateVisitor::visit(ThrowStmt *stmt) { - result = make(stmt, stmt->expr ? transform(stmt->expr) : nullptr); + result = make(stmt, + stmt->getExpr() ? transform(stmt->getExpr()) : nullptr); } void TranslateVisitor::visit(FunctionStmt *stmt) { // Process all realizations. - transformFunctionRealizations(stmt->name, stmt->attributes.has(Attr::LLVM)); + transformFunctionRealizations(stmt->getName(), stmt->hasAttribute(Attr::LLVM)); } void TranslateVisitor::visit(ClassStmt *stmt) { @@ -540,11 +628,11 @@ void TranslateVisitor::visit(ClassStmt *stmt) { /************************************************************************************/ -codon::ir::types::Type *TranslateVisitor::getType(const types::TypePtr &t) { - seqassert(t && t->getClass(), "{} is not a class", t); - std::string name = t->getClass()->realizedTypeName(); +codon::ir::types::Type *TranslateVisitor::getType(types::Type *t) { + seqassert(t && t->getClass(), "not a class"); + std::string name = t->getClass()->ClassType::realizedName(); auto i = ctx->find(name); - seqassert(i, "type {} not realized", t); + seqassert(i, "type {} not realized: {}", t->debugString(2), name); return i->getType(); } @@ -560,9 +648,9 @@ void TranslateVisitor::transformFunctionRealizations(const std::string &name, const auto &ast = real.second->ast; seqassert(ast, "AST not set for {}", real.first); if (!isLLVM) - transformFunction(real.second->type.get(), ast.get(), real.second->ir); + transformFunction(real.second->type.get(), ast, real.second->ir); else - transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir); + transformLLVMFunction(real.second->type.get(), ast, real.second->ir); } } @@ -570,37 +658,35 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as ir::Func *func) { std::vector names; std::vector indices; - for (int i = 0, j = 0; i < ast->args.size(); i++) - if (ast->args[i].status == Param::Normal) { - if (!type->getArgTypes()[j]->getFunc()) { - names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); + for (int i = 0, j = 0; i < ast->size(); i++) + if ((*ast)[i].isValue()) { + if (!(*type)[j]->getFunc()) { + names.push_back(ctx->cache->rev((*ast)[i].name)); indices.push_back(i); } j++; } - if (ast->hasAttr(Attr::CVarArg)) { + if (ast->hasAttribute(Attr::CVarArg)) { names.pop_back(); indices.pop_back(); } // TODO: refactor IR attribute API std::map attr; - attr[".module"] = ast->attributes.module; - for (auto &a : ast->attributes.customAttr) { - attr[a] = ""; - } + attr[".module"] = ast->getAttribute(Attr::Module)->value; + for (auto it = ast->attributes_begin(); it != ast->attributes_end(); ++it) + attr[*it] = ""; func->setAttribute(std::make_unique(attr)); for (int i = 0; i < names.size(); i++) - func->getArgVar(names[i])->setSrcInfo(ast->args[indices[i]].getSrcInfo()); + func->getArgVar(names[i])->setSrcInfo((*ast)[indices[i]].getSrcInfo()); // func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); - if (!ast->attributes.has(Attr::C) && !ast->attributes.has(Attr::Internal)) { + if (!ast->hasAttribute(Attr::C) && !ast->hasAttribute(Attr::Internal)) { ctx->addBlock(); for (auto i = 0; i < names.size(); i++) - ctx->add(TranslateItem::Var, ast->args[indices[i]].name, - func->getArgVar(names[i])); + ctx->add(TranslateItem::Var, (*ast)[indices[i]].name, func->getArgVar(names[i])); auto body = make(ast, "body"); ctx->bases.push_back(cast(func)); ctx->addSeries(body); - transform(ast->suite); + transform(ast->getSuite()); ctx->popSeries(); ctx->bases.pop_back(); cast(func)->setBody(body); @@ -612,38 +698,44 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt ir::Func *func) { std::vector names; std::vector indices; - for (int i = 0, j = 1; i < ast->args.size(); i++) - if (ast->args[i].status == Param::Normal) { - names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); + for (int i = 0, j = 1; i < ast->size(); i++) + if ((*ast)[i].isValue()) { + names.push_back(ctx->cache->reverseIdentifierLookup[(*ast)[i].name]); indices.push_back(i); j++; } auto f = cast(func); // TODO: refactor IR attribute API std::map attr; - attr[".module"] = ast->attributes.module; - for (auto &a : ast->attributes.customAttr) - attr[a] = ""; + attr[".module"] = ast->getAttribute(Attr::Module)->value; + for (auto it = ast->attributes_begin(); it != ast->attributes_end(); ++it) + attr[*it] = ""; func->setAttribute(std::make_unique(attr)); for (int i = 0; i < names.size(); i++) - func->getArgVar(names[i])->setSrcInfo(ast->args[indices[i]].getSrcInfo()); + func->getArgVar(names[i])->setSrcInfo((*ast)[indices[i]].getSrcInfo()); - seqassert(ast->suite->firstInBlock() && ast->suite->firstInBlock()->getExpr() && - ast->suite->firstInBlock()->getExpr()->expr->getString(), - "LLVM function does not begin with a string"); + seqassert( + ast->getSuite()->firstInBlock() && + cast(ast->getSuite()->firstInBlock()) && + cast(cast(ast->getSuite()->firstInBlock())->getExpr()), + "LLVM function does not begin with a string"); std::istringstream sin( - ast->suite->firstInBlock()->getExpr()->expr->getString()->getValue()); + cast(cast(ast->getSuite()->firstInBlock())->getExpr()) + ->getValue()); std::vector literals; - auto &ss = ast->suite->getSuite()->stmts; - for (int i = 1; i < ss.size(); i++) { - if (auto *ei = ss[i]->getExpr()->expr->getInt()) { // static integer expression - literals.emplace_back(*(ei->intValue)); - } else if (auto *es = ss[i]->getExpr()->expr->getString()) { // static string - literals.emplace_back(es->getValue()); + auto ss = cast(ast->getSuite()); + for (int i = 1; i < ss->size(); i++) { + if (auto sti = cast((*ss)[i])->getExpr()->getType()->getIntStatic()) { + literals.emplace_back(sti->value); + } else if (auto sts = + cast((*ss)[i])->getExpr()->getType()->getStrStatic()) { + literals.emplace_back(sts->value); } else { - seqassert(ss[i]->getExpr()->expr->getType(), "invalid LLVM type argument: {}", - ss[i]->getExpr()->toString()); - literals.emplace_back(getType(ss[i]->getExpr()->expr->getType())); + seqassert(cast((*ss)[i])->getExpr()->getType(), + "invalid LLVM type argument: {}", (*ss)[i]->toString(0)); + literals.emplace_back( + getType(TypecheckVisitor(ctx->cache->typeCtx) + .extractType(cast((*ss)[i])->getExpr()->getType()))); } } bool isDeclare = true; diff --git a/codon/parser/visitors/translate/translate.h b/codon/parser/visitors/translate/translate.h index 01d20861..4e3b4e35 100644 --- a/codon/parser/visitors/translate/translate.h +++ b/codon/parser/visitors/translate/translate.h @@ -23,10 +23,11 @@ class TranslateVisitor : public CallbackASTVisitor { public: explicit TranslateVisitor(std::shared_ptr ctx); - static codon::ir::Func *apply(Cache *cache, const StmtPtr &stmts); + static codon::ir::Func *apply(Cache *cache, Stmt *stmts); + void translateStmts(Stmt *stmts); - ir::Value *transform(const ExprPtr &expr) override; - ir::Value *transform(const StmtPtr &stmt) override; + ir::Value *transform(Expr *expr) override; + ir::Value *transform(Stmt *stmt) override; private: void defaultVisit(Expr *expr) override; @@ -40,6 +41,7 @@ class TranslateVisitor : public CallbackASTVisitor { void visit(StringExpr *) override; void visit(IdExpr *) override; void visit(IfExpr *) override; + void visit(GeneratorExpr *) override; void visit(CallExpr *) override; void visit(DotExpr *) override; void visit(YieldExpr *) override; @@ -64,7 +66,7 @@ class TranslateVisitor : public CallbackASTVisitor { void visit(CommentStmt *) override {} private: - ir::types::Type *getType(const types::TypePtr &t); + ir::types::Type *getType(types::Type *t); void transformFunctionRealizations(const std::string &name, bool isLLVM); void transformFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func); diff --git a/codon/parser/visitors/translate/translate_ctx.cpp b/codon/parser/visitors/translate/translate_ctx.cpp index 27b3ddfc..4b5273be 100644 --- a/codon/parser/visitors/translate/translate_ctx.cpp +++ b/codon/parser/visitors/translate/translate_ctx.cpp @@ -9,6 +9,7 @@ #include "codon/parser/ctx.h" #include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/ctx.h" +#include "codon/parser/visitors/typecheck/typecheck.h" namespace codon::ast { @@ -21,20 +22,23 @@ std::shared_ptr TranslateContext::find(const std::string &name) c std::shared_ptr ret = nullptr; auto tt = cache->typeCtx->find(name); if (tt && tt->isType() && tt->type->canRealize()) { + auto t = tt->getType(); + if (name != t->realizedName()) // type prefix + t = TypecheckVisitor(cache->typeCtx).extractType(t); + auto n = t->getClass()->name; + if (!in(cache->classes, n) || !in(cache->classes[n].realizations, name)) + return nullptr; ret = std::make_shared(TranslateItem::Type, bases[0]); - seqassertn(in(cache->classes, tt->type->getClass()->name) && - in(cache->classes[tt->type->getClass()->name].realizations, name), - "cannot find type realization {}", name); - ret->handle.type = - cache->classes[tt->type->getClass()->name].realizations[name]->ir; + ret->handle.type = cache->classes[n].realizations[name]->ir; } else if (tt && tt->type->getFunc() && tt->type->canRealize()) { ret = std::make_shared(TranslateItem::Func, bases[0]); seqassertn( - in(cache->functions, tt->type->getFunc()->ast->name) && - in(cache->functions[tt->type->getFunc()->ast->name].realizations, name), - "cannot find type realization {}", name); + in(cache->functions, tt->type->getFunc()->ast->getName()) && + in(cache->functions[tt->type->getFunc()->ast->getName()].realizations, + name), + "cannot find function realization {}", name); ret->handle.func = - cache->functions[tt->type->getFunc()->ast->name].realizations[name]->ir; + cache->functions[tt->type->getFunc()->ast->getName()].realizations[name]->ir; } return ret; } diff --git a/codon/parser/visitors/translate/translate_ctx.h b/codon/parser/visitors/translate/translate_ctx.h index 6b7db7cf..507d7349 100644 --- a/codon/parser/visitors/translate/translate_ctx.h +++ b/codon/parser/visitors/translate/translate_ctx.h @@ -52,7 +52,7 @@ struct TranslateContext : public Context { /// Stack of IR series (blocks). std::vector series; /// Stack of sequence items for attribute initialization. - std::vector>> seqItems; + std::vector>> seqItems; public: TranslateContext(Cache *cache); diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index 94a51e2a..78b56933 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -6,465 +6,541 @@ #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; using namespace codon::error; +using namespace codon::matcher; namespace codon::ast { using namespace types; -/// Typecheck identifiers. If an identifier is a static variable, evaluate it and -/// replace it with its value (e.g., a @c IntExpr ). Also ensure that the identifier of -/// a generic function or a type is fully qualified (e.g., replace `Ptr` with -/// `Ptr[byte]`). +/// Typecheck identifiers. +/// If an identifier is a static variable, evaluate it and replace it with its value +/// (e.g., change `N` to `IntExpr(16)`). +/// If the identifier of a generic is fully qualified, use its qualified name +/// (e.g., replace `Ptr` with `Ptr[byte]`). void TypecheckVisitor::visit(IdExpr *expr) { - // Replace identifiers that have been superseded by domination analysis during the - // simplification - while (auto s = in(ctx->cache->replacements, expr->value)) - expr->value = s->first; - - auto val = ctx->find(expr->value); + auto val = ctx->find(expr->getValue(), getTime()); if (!val) { - // Handle overloads - if (in(ctx->cache->overloads, expr->value)) { - val = ctx->forceFind(getDispatch(expr->value)->ast->name); - } - seqassert(val, "cannot find '{}'", expr->value); - } - unify(expr->type, ctx->instantiate(val->type)); - - if (val->type->isStaticType()) { - // Evaluate static expression if possible - expr->staticValue.type = StaticValue::Type(val->type->isStaticType()); - auto s = val->type->getStatic(); - seqassert(!expr->staticValue.evaluated, "expected unevaluated expression: {}", - expr->toString()); - if (s && s->expr->staticValue.evaluated) { - // Replace the identifier with static expression - if (s->expr->staticValue.type == StaticValue::STRING) - resultExpr = transform(N(s->expr->staticValue.getString())); - else - resultExpr = transform(N(s->expr->staticValue.getInt())); - } - return; + E(Error::ID_NOT_FOUND, expr, expr->getValue()); } - if (val->isType()) - expr->markType(); - - // Realize a type or a function if possible and replace the identifier with the fully - // typed identifier (e.g., `foo` -> `foo[int]`) - if (realize(expr->type)) { - if (!val->isVar()) - expr->value = expr->type->realizedName(); - expr->setDone(); + // If this is an overload, use the dispatch function + if (isUnbound(expr) && hasOverloads(val->getName())) { + val = ctx->forceFind(getDispatch(val->getName())->getFuncName()); } -} - -/// See @c transformDot for details. -void TypecheckVisitor::visit(DotExpr *expr) { - // Make sure to unify the current type with the transformed type - if ((resultExpr = transformDot(expr))) - unify(expr->type, resultExpr->type); - if (!expr->type) - unify(expr->type, ctx->getUnbound()); -} -/// Find an overload dispatch function for a given overload. If it does not exist and -/// there is more than one overload, generate it. Dispatch functions ensure that a -/// function call is being routed to the correct overload even when dealing with partial -/// functions and decorators. -/// @example This is how dispatch looks like: -/// ```def foo:dispatch(*args, **kwargs): -/// return foo(*args, **kwargs)``` -types::FuncTypePtr TypecheckVisitor::getDispatch(const std::string &fn) { - auto &overloads = ctx->cache->overloads[fn]; - - // Single overload: just return it - if (overloads.size() == 1) - return ctx->forceFind(overloads.front().name)->type->getFunc(); + // If we are accessing an outside variable, capture it or raise an error + auto captured = checkCapture(val); + if (captured) + val = ctx->forceFind(expr->getValue()); - // Check if dispatch exists - for (auto &m : overloads) - if (endswith(ctx->cache->functions[m.name].ast->name, ":dispatch")) - return ctx->cache->functions[m.name].type; + // Replace the variable with its canonical name + expr->value = val->getName(); - // Dispatch does not exist. Generate it - auto name = fn + ":dispatch"; - ExprPtr root; // Root function name used for calling - auto a = ctx->cache->functions[overloads[0].name].ast; - if (!a->attributes.parentClass.empty()) - root = N(N(a->attributes.parentClass), - ctx->cache->reverseIdentifierLookup[fn]); - else - root = N(fn); - root = N(root, N(N("args")), - N(N("kwargs"))); - auto ast = N( - name, nullptr, std::vector{Param("*args"), Param("**kwargs")}, - N(N(root)), Attr({"autogenerated"})); - ctx->cache->reverseIdentifierLookup[name] = ctx->cache->reverseIdentifierLookup[fn]; + // Set up type + unify(expr->getType(), instantiateType(val->getType())); + if (auto f = expr->getType()->getFunc()) { + expr->value = f->getFuncName(); // resolve overloads + } - auto baseType = getFuncTypeBase(2); - auto typ = std::make_shared(baseType, ast.get()); - typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel - 1)); - ctx->add(TypecheckItem::Func, name, typ); + // Realize a type or a function if possible and replace the identifier with + // a qualified identifier or a static expression (e.g., `foo` -> `foo[int]`) + if (expr->getType()->canRealize()) { + if (auto s = expr->getType()->getStatic()) { + resultExpr = transform(s->getStaticExpr()); + return; + } + if (!val->isVar()) { + if (!(expr->hasAttribute(Attr::ExprDoNotRealize) && expr->getType()->getFunc())) { + if (auto r = realize(expr->getType())) { + expr->value = r->realizedName(); + expr->setDone(); + } + } + } else { + realize(expr->getType()); + expr->setDone(); + } + } - overloads.insert(overloads.begin(), {name, 0}); - ctx->cache->functions[name].ast = ast; - ctx->cache->functions[name].type = typ; - prependStmts->push_back(ast); - return typ; + // If this identifier needs __used__ checks (see @c ScopeVisitor), add them + if (expr->hasAttribute(Attr::ExprDominatedUndefCheck)) { + auto controlVar = + fmt::format("{}{}", getUnmangledName(val->canonicalName), VAR_USED_SUFFIX); + if (ctx->find(controlVar, getTime())) { + auto checkStmt = N(N( + N(N("__internal__"), "undef"), N(controlVar), + N(getUnmangledName(val->canonicalName)))); + expr->eraseAttribute(Attr::ExprDominatedUndefCheck); + resultExpr = transform(N(checkStmt, expr)); + } + } } /// Transform a dot expression. Select the best method overload if possible. -/// @param args (optional) list of class method arguments used to select the best -/// overload. nullptr if not available. /// @example /// `obj.__class__` -> `type(obj)` /// `cls.__name__` -> `"class"` (same for functions) /// `obj.method` -> `cls.method(obj, ...)` or /// `cls.method(obj)` if method has `@property` attribute -/// @c getClassMember examples: -/// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value) -/// `optional.member` -> `unwrap(optional).member` -/// `pyobj.member` -> `pyobj._getattr("member")` +/// `obj.member` -> see @c getClassMember /// @return nullptr if no transformation was made -/// See @c getClassMember and @c getBestOverload -ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, - std::vector *args) { +void TypecheckVisitor::visit(DotExpr *expr) { + // Check if this is being called from CallExpr (e.g., foo.bar(...)) + // and mark it as such (to inline useless partial expression) + CallExpr *parentCall = cast(ctx->getParentNode()); + if (parentCall && !parentCall->hasAttribute(Attr::ParentCallExpr)) + parentCall = nullptr; + + // Flatten imports: + // `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import + // `a.B.c` -> canonical name of `c` in class `a.B` + // `python.foo` -> internal.python._get_identifier("foo") + std::vector chain; + Expr *head = expr; + for (; cast(head); head = cast(head)->getExpr()) + chain.push_back(cast(head)->getMember()); + Expr *final = expr; + if (auto id = cast(head)) { + // Case: a.bar.baz + chain.push_back(id->getValue()); + std::reverse(chain.begin(), chain.end()); + auto [pos, val] = getImport(chain); + if (!val) { + // Python capture + seqassert(ctx->getBase()->pyCaptures, "unexpected py capture"); + ctx->getBase()->pyCaptures->insert(chain[0]); + final = N(N("__pyenv__"), N(chain[0])); + } else if (val->getModule() == "std.python" && + ctx->getModule() != val->getModule()) { + // Import from python (e.g., pyobj.foo) + final = transform(N( + N(N(N("internal"), "python"), "_get_identifier"), + N(chain[pos++]))); + } else if (val->getModule() == ctx->getModule() && pos == 1) { + final = transform(N(chain[0]), true); + } else { + final = N(val->canonicalName); + } + while (pos < chain.size()) + final = N(final, chain[pos++]); + } + if (auto dot = cast(final)) { + expr->expr = dot->getExpr(); + expr->member = dot->getMember(); + } else { + resultExpr = transform(final); + return; + } + // Special case: obj.__class__ - if (expr->member == "__class__") { + if (expr->getMember() == "__class__") { /// TODO: prevent cls.__class__ and type(cls) - return transformType(NT(NT("type"), expr->expr)); + resultExpr = transform(N(N(TYPE_TYPE), expr->getExpr())); + return; } - - transform(expr->expr); + expr->expr = transform(expr->getExpr()); // Special case: fn.__name__ // Should go before cls.__name__ to allow printing generic functions - if (expr->expr->type->getFunc() && expr->member == "__name__") { - return transform(N(expr->expr->type->prettyString())); + if (extractType(expr->getExpr())->getFunc() && expr->getMember() == "__name__") { + resultExpr = transform(N(extractType(expr->getExpr())->prettyString())); + return; } - if (expr->expr->type->getPartial() && expr->member == "__name__") { - return transform(N(expr->expr->type->getPartial()->prettyString())); + if (expr->getExpr()->getType()->getPartial() && expr->getMember() == "__name__") { + resultExpr = transform( + N(expr->getExpr()->getType()->getPartial()->prettyString())); + return; } // Special case: fn.__llvm_name__ or obj.__llvm_name__ - if (expr->member == "__llvm_name__") { - if (realize(expr->expr->type)) - return transform(N(expr->expr->type->realizedName())); - return nullptr; + if (expr->getMember() == "__llvm_name__") { + if (realize(expr->getExpr()->getType())) + resultExpr = transform(N(expr->getExpr()->getType()->realizedName())); + return; } // Special case: cls.__name__ - if (expr->expr->isType() && expr->member == "__name__") { - if (realize(expr->expr->type)) - return transform(N(expr->expr->type->prettyString())); - return nullptr; + if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__name__") { + if (realize(expr->getExpr()->getType())) + resultExpr = + transform(N(extractType(expr->getExpr())->prettyString())); + return; + } + if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__repr__") { + resultExpr = + transform(N(N("std.internal.internal.__type_repr__.0"), + expr->getExpr(), N())); + return; } // Special case: expr.__is_static__ - if (expr->member == "__is_static__") { - if (expr->expr->isDone()) - return transform(N(expr->expr->isStatic())); - return nullptr; + if (expr->getMember() == "__is_static__") { + if (expr->getExpr()->isDone()) + resultExpr = + transform(N(bool(expr->getExpr()->getType()->getStatic()))); + return; } // Special case: cls.__id__ - if (expr->expr->isType() && expr->member == "__id__") { - if (auto c = realize(expr->expr->type)) - return transform(N(ctx->cache->classes[c->getClass()->name] - .realizations[c->getClass()->realizedTypeName()] - ->id)); - return nullptr; + if (isTypeExpr(expr->getExpr()) && expr->getMember() == "__id__") { + if (auto c = realize(extractType(expr->getExpr()))) + resultExpr = transform(N(getClassRealization(c)->id)); + return; } // Ensure that the type is known (otherwise wait until it becomes known) - auto typ = expr->expr->getType()->getClass(); + auto typ = extractClassType(expr->getExpr()); if (!typ) - return nullptr; + return; // Check if this is a method or member access - if (ctx->findMethod(typ.get(), expr->member).empty()) - return getClassMember(expr, args); - auto bestMethod = getBestOverload(expr, args); - - if (args) { - unify(expr->type, ctx->instantiate(bestMethod, typ)); - - // A function is deemed virtual if it is marked as such and - // if a base class has a RTTI - bool isVirtual = in(ctx->cache->classes[typ->name].virtuals, expr->member); - isVirtual &= ctx->cache->classes[typ->name].rtti; - isVirtual &= !expr->expr->isType(); - if (isVirtual && !bestMethod->ast->attributes.has(Attr::StaticMethod) && - !bestMethod->ast->attributes.has(Attr::Property)) { - // Special case: route the call through a vtable - if (realize(expr->type)) { - auto fn = expr->type->getFunc(); - auto vid = getRealizationID(typ.get(), fn.get()); - - // Function[Tuple[TArg1, TArg2, ...], TRet] - std::vector ids; - for (auto &t : fn->getArgTypes()) - ids.push_back(NT(t->realizedName())); - auto fnType = NT( - NT("Function"), - std::vector{NT(NT(TYPE_TUPLE), ids), - NT(fn->getRetType()->realizedName())}); - // Function[Tuple[TArg1, TArg2, ...],TRet]( - // __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]] - // ) - auto e = N( - fnType, - N(N(N("__internal__.class_get_rtti_vtable:0"), - expr->expr), - N(vid))); - return transform(e); + while (true) { + auto methods = findMethod(typ, expr->getMember()); + if (methods.empty()) + resultExpr = getClassMember(expr); + + // If the expression changed during the @c getClassMember (e.g., optional unwrap), + // keep going further to be able to select the appropriate method or member + auto oldExpr = expr->getExpr(); + if (!resultExpr && expr->getExpr() != oldExpr) { + typ = extractClassType(expr->getExpr()); + if (!typ) + return; // delay typechecking + continue; + } + + if (!methods.empty()) { + // If a method is ambiguous use dispatch + auto bestMethod = methods.size() > 1 ? getDispatch(getRootName(methods.front())) + : methods.front(); + Expr *e = N(bestMethod->getFuncName()); + e->setType(instantiateType(bestMethod, typ)); + if (isTypeExpr(expr->getExpr())) { + // Static access: `cls.method` + unify(expr->getType(), e->getType()); + } else if (parentCall && !bestMethod->ast->hasAttribute(Attr::StaticMethod) && + !bestMethod->ast->hasAttribute(Attr::Property)) { + // Instance access: `obj.method` from the call + // Modify the call to push `self` to the front of the argument list. + // Avoids creating partial functions. + parentCall->items.insert(parentCall->items.begin(), expr->getExpr()); + unify(expr->getType(), e->getType()); + } else { + // Instance access: `obj.method` + // Transform y.method to a partial call `type(y).method(y, ...)` + std::vector methodArgs; + // Do not add self if a method is marked with @staticmethod + if (!bestMethod->ast->hasAttribute(Attr::StaticMethod)) + methodArgs.emplace_back(expr->getExpr()); + // If a method is marked with @property, just call it directly + if (!bestMethod->ast->hasAttribute(Attr::Property)) + methodArgs.emplace_back(N(EllipsisExpr::PARTIAL)); + e = N(e, methodArgs); } + resultExpr = transform(e); } + break; } +} - // Check if a method is a static or an instance method and transform accordingly - if (expr->expr->isType() || args) { - // Static access: `cls.method` - ExprPtr e = N(bestMethod->ast->name); - unify(e->type, unify(expr->type, ctx->instantiate(bestMethod, typ))); - return transform(e); // Realize if needed - } else { - // Instance access: `obj.method` - // Transform y.method to a partial call `type(obj).method(args, ...)` - std::vector methodArgs; - // Do not add self if a method is marked with @staticmethod - if (!bestMethod->ast->attributes.has(Attr::StaticMethod)) - methodArgs.push_back(expr->expr); - // If a method is marked with @property, just call it directly - if (!bestMethod->ast->attributes.has(Attr::Property)) - methodArgs.push_back(N(EllipsisExpr::PARTIAL)); - auto e = transform(N(N(bestMethod->ast->name), methodArgs)); - unify(expr->type, e->type); - return e; +/// Access identifiers from outside of the current function/class scope. +/// Either use them as-is (globals), capture them if allowed (nonlocals), +/// or raise an error. +bool TypecheckVisitor::checkCapture(const TypeContext::Item &val) { + if (!ctx->isOuter(val)) + return false; + if ((val->isType() && !val->isGeneric()) || val->isFunc()) + return false; + + // Ensure that outer variables can be captured (i.e., do not cross no-capture + // boundary). Example: + // def foo(): + // x = 1 + // class T: # <- boundary (classes cannot capture locals) + // t: int = x # x cannot be accessed + // def bar(): # <- another boundary + // # (class methods cannot capture locals except class generics) + // print(x) # x cannot be accessed + bool crossCaptureBoundary = false; + bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName(); + bool parentClassGeneric = + val->isGeneric() && !ctx->getBase()->isType() && + (ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() && + ctx->bases[ctx->bases.size() - 2].name == val->getBaseName()); + auto i = ctx->bases.size(); + for (; i-- > 0;) { + if (ctx->bases[i].name == val->getBaseName()) + break; + if (!localGeneric && !parentClassGeneric) + crossCaptureBoundary = true; + } + + // Mark methods (class functions that access class generics) + if (parentClassGeneric) + ctx->getBase()->func->setAttribute(Attr::Method); + + // Ignore generics + if (parentClassGeneric || localGeneric) + return false; + + // Case: a global variable that has not been marked with `global` statement + if (val->isVar() && val->getBaseName().empty() && val->scope.size() == 1) { + registerGlobal(val->getName(), true); + return false; } + + // Check if a real variable (not a static) is defined outside the current scope + if (crossCaptureBoundary) + E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), getUnmangledName(val->getName())); + + // Case: a nonlocal variable that has not been marked with `nonlocal` statement + // and capturing is *not* enabled + E(Error::ID_NONLOCAL, getSrcInfo(), getUnmangledName(val->getName())); + return false; } -/// Select the requested class member. -/// @param args (optional) list of class method arguments used to select the best -/// overload if the member is optional. nullptr if not available. +/// Check if a chain (a.b.c.d...) contains an import or a class prefix. +std::pair +TypecheckVisitor::getImport(const std::vector &chain) { + size_t importEnd = 0; + std::string importName; + + // 1. Find the longest prefix that corresponds to the existing import + // (e.g., `a.b.c.d` -> `a.b.c` if there is `import a.b.c`) + TypeContext::Item val = nullptr, importVal = nullptr; + for (auto i = chain.size(); i-- > 0;) { + auto name = join(chain, "/", 0, i + 1); + val = ctx->find(name, getTime()); + if (val && val->type->is("Import") && startswith(val->getName(), "%_import_")) { + importName = getStrLiteral(val->type.get()); + importEnd = i + 1; + importVal = val; + break; + } + } + + // Case: the whole chain is import itself + if (importEnd == chain.size()) + return {importEnd, val}; + + // Find the longest prefix that corresponds to the existing class + // (e.g., `a.b.c` -> `a.b` if there is `class a: class b:`) + std::string itemName; + size_t itemEnd = 0; + auto ictx = importName.empty() ? ctx : getImport(importName)->ctx; + for (auto i = chain.size(); i-- > importEnd;) { + if (ictx->getModule() == "std.python" && importEnd < chain.size()) { + // Special case: importing from Python. + // Fake TypecheckItem that indicates std.python access + val = std::make_shared( + "", "", ictx->getModule(), TypecheckVisitor(ictx).instantiateUnbound()); + return {importEnd, val}; + } else { + auto key = join(chain, ".", importEnd, i + 1); + val = ictx->find(key); + if (val && i + 1 != chain.size() && val->getType()->is("Import") && + startswith(val->getName(), "%_import_")) { + importName = getStrLiteral(val->getType()); + importEnd = i + 1; + importVal = val; + ictx = getImport(importName)->ctx; + i = chain.size(); + continue; + } + bool isOverload = val && val->isFunc() && hasOverloads(val->canonicalName); + if (isOverload && importEnd == i) { // top-level overload + itemName = val->canonicalName, itemEnd = i + 1; + break; + } + // Class member + if (val && !isOverload && + (importName.empty() || val->isType() || !val->isConditional())) { + itemName = val->canonicalName, itemEnd = i + 1; + break; + } + // Resolve the identifier from the import + if (auto i = ctx->find("Import")) { + auto t = extractClassType(i->getType()); + if (findMember(t, key)) + return {importEnd, importVal}; + if (!findMethod(t, key).empty()) + return {importEnd, importVal}; + } + } + } + if (itemName.empty() && importName.empty()) { + if (ctx->getBase()->pyCaptures) + return {1, nullptr}; + E(Error::IMPORT_NO_MODULE, getSrcInfo(), chain[importEnd]); + } else if (itemName.empty()) { + auto import = getImport(importName); + if (!ctx->isStdlibLoading && endswith(importName, "__init__.codon")) { + // Special case: subimport is not yet loaded + // (e.g., import a; a.b.x where a.b is a module as well) + auto file = getImportFile(getArgv(), chain[importEnd], importName, false, + getRootModulePath(), getPluginImportPaths()); + if (file) { // auto-load support + Stmt *s = N(N(N(chain[importEnd]), nullptr)); + if (auto err = ScopingVisitor::apply(ctx->cache, s)) + throw exc::ParserException(std::move(err)); + s = TypecheckVisitor(import->ctx, preamble).transform(s); + prependStmts->push_back(s); + return getImport(chain); + } + } + E(Error::IMPORT_NO_NAME, getSrcInfo(), chain[importEnd], import->name); + } + importEnd = itemEnd; + return {importEnd, val}; +} + +/// Find or generate an overload dispatch function for a given overload. +/// Dispatch functions ensure that a function call is being routed to the correct +/// overload +/// even when dealing with partial functions and decorators. +/// @example This is how dispatch looks like: +/// ```def foo:dispatch(*args, **kwargs): +/// return foo(*args, **kwargs)``` +types::FuncType *TypecheckVisitor::getDispatch(const std::string &fn) { + auto &overloads = ctx->cache->overloads[fn]; + + // Single overload: just return it + if (overloads.size() == 1) + return ctx->forceFind(overloads.front())->type->getFunc(); + + // Check if dispatch exists + for (auto &m : overloads) + if (isDispatch(getFunction(m)->ast)) + return getFunction(m)->getType(); + + // Dispatch does not exist. Generate it + auto name = fmt::format("{}{}", fn, FN_DISPATCH_SUFFIX); + Expr *root; // Root function name used for calling + auto ofn = getFunction(overloads[0]); + if (auto aa = ofn->ast->getAttribute(Attr::ParentClass)) + root = N(N(aa->value), getUnmangledName(fn)); + else + root = N(fn); + root = N(root, N(N("args")), + N(N("kwargs"))); + auto nar = ctx->generateCanonicalName("args"); + auto nkw = ctx->generateCanonicalName("kwargs"); + auto ast = N(name, nullptr, + std::vector{Param("*" + nar), Param("**" + nkw)}, + N(N(root))); + ast->setAttribute(Attr::AutoGenerated); + ast->setAttribute(Attr::Module, ctx->moduleName.path); + ctx->cache->reverseIdentifierLookup[name] = getUnmangledName(fn); + + auto baseType = getFuncTypeBase(2); + auto typ = std::make_shared(baseType.get(), ast, 0); + /// Make sure that parent is set so that the parent type can be passed to the inner + /// call + /// (e.g., A[B].foo -> A.foo.dispatch() { A[B].foo() }) + typ->funcParent = ofn->type->funcParent; + typ = std::static_pointer_cast(typ->generalize(ctx->typecheckLevel - 1)); + ctx->addFunc(name, name, typ); + + overloads.insert(overloads.begin(), name); + ctx->cache->functions[name] = Cache::Function{"", fn, ast, typ}; + ast->setDone(); + return typ.get(); // stored in Cache::Function, hence not destroyed +} + +/// Find a class member. /// @example /// `obj.GENERIC` -> `GENERIC` (IdExpr with generic/static value) /// `optional.member` -> `unwrap(optional).member` /// `pyobj.member` -> `pyobj._getattr("member")` -ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr, - std::vector *args) { - auto typ = expr->expr->getType()->getClass(); +Expr *TypecheckVisitor::getClassMember(DotExpr *expr) { + auto typ = extractClassType(expr->getExpr()); seqassert(typ, "not a class"); // Case: object member access (`obj.member`) - if (!expr->expr->isType()) { - if (auto member = ctx->findMember(typ, expr->member)) { - unify(expr->type, ctx->instantiate(member, typ)); - if (expr->expr->isDone() && realize(expr->type)) + if (!isTypeExpr(expr->getExpr())) { + if (auto member = findMember(typ, expr->getMember())) { + unify(expr->getType(), instantiateType(member->getType(), typ)); + if (!expr->getType()->canRealize() && member->typeExpr) { + unify(expr->getType(), extractType(withClassGenerics(typ, [&]() { + return transform(clean_clone(member->typeExpr)); + }))); + } + if (expr->getExpr()->isDone() && realize(expr->getType())) expr->setDone(); return nullptr; } } // Case: class variable (`Cls.var`) - if (auto cls = in(ctx->cache->classes, typ->name)) - if (auto var = in(cls->classVars, expr->member)) { + if (auto cls = getClass(typ)) + if (auto var = in(cls->classVars, expr->getMember())) { return transform(N(*var)); } // Case: special members - if (auto mtyp = findSpecialMember(expr->member)) { - unify(expr->type, mtyp); - if (expr->expr->isDone() && realize(expr->type)) + std::unordered_map specialMembers{ + {"__elemsize__", "int"}, {"__atomic__", "bool"}, {"__contents_atomic__", "bool"}}; + if (auto mtyp = in(specialMembers, expr->getMember())) { + unify(expr->getType(), getStdLibType(*mtyp)); + if (expr->getExpr()->isDone() && realize(expr->getType())) + expr->setDone(); + return nullptr; + } + if (expr->getMember() == "__name__" && isTypeExpr(expr->getExpr())) { + unify(expr->getType(), getStdLibType("str")); + if (expr->getExpr()->isDone() && realize(expr->getType())) expr->setDone(); return nullptr; } // Case: object generic access (`obj.T`) - TypePtr generic = nullptr; + ClassType::Generic *generic = nullptr; for (auto &g : typ->generics) - if (ctx->cache->reverseIdentifierLookup[g.name] == expr->member) { - generic = g.type; + if (expr->getMember() == getUnmangledName(g.name)) { + generic = &g; break; } if (generic) { - unify(expr->type, generic); - if (!generic->isStaticType()) { - expr->markType(); - } else { - expr->staticValue.type = StaticValue::Type(generic->isStaticType()); - } - if (realize(expr->type)) { - if (!generic->isStaticType()) { - return transform(N(generic->realizedName())); - } else if (generic->getStatic()->expr->staticValue.type == StaticValue::STRING) { - expr->type = nullptr; // to prevent unify(T, Static[T]) error - return transform( - N(generic->getStatic()->expr->staticValue.getString())); - } else { - expr->type = nullptr; // to prevent unify(T, Static[T]) error - return transform(N(generic->getStatic()->expr->staticValue.getInt())); + if (generic->isStatic) { + unify(expr->getType(), generic->getType()); + if (realize(expr->getType())) { + return transform(generic->type->getStatic()->getStaticExpr()); } + } else { + unify(expr->getType(), instantiateTypeVar(generic->getType())); + if (realize(expr->getType())) + return transform(N(generic->getType()->realizedName())); } return nullptr; } // Case: transform `optional.member` to `unwrap(optional).member` if (typ->is(TYPE_OPTIONAL)) { - auto dot = N(transform(N(N(FN_UNWRAP), expr->expr)), - expr->member); - dot->setType(ctx->getUnbound()); // as dot is not transformed - if (auto d = transformDot(dot.get(), args)) - return d; - return dot; + expr->expr = transform(N(N(FN_UNWRAP), expr->getExpr())); + return nullptr; } // Case: transform `pyobj.member` to `pyobj._getattr("member")` if (typ->is("pyobj")) { - return transform( - N(N(expr->expr, "_getattr"), N(expr->member))); + return transform(N(N(expr->getExpr(), "_getattr"), + N(expr->getMember()))); } // Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)` if (typ->getUnion()) { if (!typ->canRealize()) return nullptr; // delay! - // bool isMember = false; - // for (auto &t: typ->getUnion()->getRealizationTypes()) - // if (ctx->findMethod(t.get(), expr->member).empty()) return transform(N( - N("__internal__.union_member:0"), - std::vector{{"union", expr->expr}, - {"member", N(expr->member)}})); + N(N("__internal__"), "union_member"), + std::vector{{"union", expr->getExpr()}, + {"member", N(expr->getMember())}})); } // For debugging purposes: - // ctx->findMethod(typ.get(), expr->member); - E(Error::DOT_NO_ATTR, expr, typ->prettyString(), expr->member); - return nullptr; -} - -TypePtr TypecheckVisitor::findSpecialMember(const std::string &member) { - if (member == "__elemsize__") - return ctx->getType("int"); - if (member == "__atomic__") - return ctx->getType("bool"); - if (member == "__contents_atomic__") - return ctx->getType("bool"); - if (member == "__name__") - return ctx->getType("str"); - return nullptr; -} - -/// Select the best overloaded function or method. -/// @param expr a DotExpr (for methods) or an IdExpr (for overloaded functions) -/// @param methods List of available methods. -/// @param args (optional) list of class method arguments used to select the best -/// overload if the member is optional. nullptr if not available. -FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr, - std::vector *args) { - // Prepare the list of method arguments if possible - std::unique_ptr> methodArgs; - - if (args) { - // Case: method overloads (DotExpr) - bool addSelf = true; - if (auto dot = expr->getDot()) { - auto methods = - ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false); - if (!methods.empty() && methods.front()->ast->attributes.has(Attr::StaticMethod)) - addSelf = false; - } - - // Case: arguments explicitly provided (by CallExpr) - if (addSelf && expr->getDot() && !expr->getDot()->expr->isType()) { - // Add `self` as the first argument - args->insert(args->begin(), {"", expr->getDot()->expr}); - } - methodArgs = std::make_unique>(); - for (auto &a : *args) - methodArgs->push_back(a); - } else { - // Partially deduced type thus far - auto typeSoFar = expr->getType() ? expr->getType()->getClass() : nullptr; - if (typeSoFar && typeSoFar->getFunc()) { - // Case: arguments available from the previous type checking round - methodArgs = std::make_unique>(); - if (expr->getDot() && !expr->getDot()->expr->isType()) { // Add `self` - auto n = N(); - n->setType(expr->getDot()->expr->type); - methodArgs->push_back({"", n}); - } - for (auto &a : typeSoFar->getFunc()->getArgTypes()) { - auto n = N(); - n->setType(a); - methodArgs->push_back({"", n}); - } - } - } - - bool goDispatch = methodArgs == nullptr; - if (!goDispatch) { - std::vector m; - // Use the provided arguments to select the best method - if (auto dot = expr->getDot()) { - // Case: method overloads (DotExpr) - auto methods = - ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false); - m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs); - } else if (auto id = expr->getId()) { - // Case: function overloads (IdExpr) - std::vector methods; - for (auto &m : ctx->cache->overloads[id->value]) - if (!endswith(m.name, ":dispatch")) - methods.push_back(ctx->cache->functions[m.name].type); - std::reverse(methods.begin(), methods.end()); - m = findMatchingMethods(nullptr, methods, *methodArgs); - } - - if (m.size() == 1) { - return m[0]; - } else if (m.size() > 1) { - for (auto &a : *methodArgs) { - if (auto u = a.value->type->getUnbound()) { - goDispatch = true; - } - } - if (!goDispatch) - return m[0]; - } - } - - if (goDispatch) { - // If overload is ambiguous, route through a dispatch function - std::string name; - if (auto dot = expr->getDot()) { - name = ctx->cache->getMethod(dot->expr->type->getClass(), dot->member); - } else { - name = expr->getId()->value; - } - return getDispatch(name); - } - - // Print a nice error message - std::string argsNice; - if (methodArgs) { - std::vector a; - for (auto &t : *methodArgs) - a.emplace_back(fmt::format("{}", t.value->type->prettyString())); - argsNice = fmt::format("({})", fmt::join(a, ", ")); - } - - if (auto dot = expr->getDot()) { - E(Error::DOT_NO_ATTR_ARGS, expr, dot->expr->type->prettyString(), dot->member, - argsNice); - } else { - E(Error::FN_NO_ATTR_ARGS, expr, ctx->cache->rev(expr->getId()->value), argsNice); - } - + findMethod(typ, expr->getMember()); + E(Error::DOT_NO_ATTR, expr, typ->prettyString(), expr->getMember()); return nullptr; } diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index 0cfbb3ad..f51e60fe 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -3,10 +3,10 @@ #include #include +#include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -16,152 +16,307 @@ namespace codon::ast { using namespace types; +/// Transform walrus (assignment) expression. +/// @example +/// `(expr := var)` -> `var = expr; var` +void TypecheckVisitor::visit(AssignExpr *expr) { + auto a = N(clone(expr->getVar()), expr->getExpr()); + a->cloneAttributesFrom(expr); + resultExpr = transform(N(a, expr->getVar())); +} + /// Transform assignments. Handle dominated assignments, forward declarations, static /// assignments and type/function aliases. +/// See @c transformAssignment and @c unpackAssignments for more details. /// See @c wrapExpr for more examples. void TypecheckVisitor::visit(AssignStmt *stmt) { - // Update statements are handled by @c visitUpdate - if (stmt->isUpdate()) { - transformUpdate(stmt); + if (cast(stmt->lhs) || cast(stmt->lhs)) { + resultStmt = transform(unpackAssignment(stmt->lhs, stmt->rhs)); return; } - seqassert(stmt->lhs->getId(), "invalid AssignStmt {}", stmt->lhs); - std::string lhs = stmt->lhs->getId()->value; - - // Special case: this assignment has been dominated and is not a true assignment but - // an update of the dominating binding. - if (auto changed = in(ctx->cache->replacements, lhs)) { - while (auto s = in(ctx->cache->replacements, lhs)) - lhs = changed->first, changed = s; - if (changed->second) { // has __used__ binding - if (stmt->rhs) { - // Mark the dominating binding as used: `var.__used__ = True` - auto u = N(N(fmt::format("{}.__used__", lhs)), - N(true)); - u->setUpdate(); - prependStmts->push_back(transform(u)); - } else { - // This assignment was a declaration only. Just mark the dominating binding as - // used: `var.__used__ = True` - stmt->lhs = N(fmt::format("{}.__used__", lhs)); - stmt->rhs = N(true); - } - } + bool mustUpdate = stmt->isUpdate() || stmt->isAtomicUpdate(); + mustUpdate |= stmt->getLhs()->hasAttribute(Attr::ExprDominated); + mustUpdate |= stmt->getLhs()->hasAttribute(Attr::ExprDominatedUsed); + if (cast(stmt->getRhs()) && + cast(stmt->getRhs())->isInPlace()) { + // Update case: a += b + seqassert(!stmt->getTypeExpr(), "invalid AssignStmt {}", stmt->toString(0)); + mustUpdate = true; + } - if (endswith(lhs, ".__used__") || !stmt->rhs) { - // unneeded declaration (unnecessary used or binding) - resultStmt = transform(N()); - return; - } + resultStmt = transformAssignment(stmt, mustUpdate); + if (stmt->getLhs()->hasAttribute(Attr::ExprDominatedUsed)) { + // If this is dominated, set __used__ if needed + stmt->getLhs()->eraseAttribute(Attr::ExprDominatedUsed); + auto e = cast(stmt->getLhs()); + seqassert(e, "dominated bad assignment"); + resultStmt = transform(N( + resultStmt, + N( + N(format("{}{}", getUnmangledName(e->getValue()), VAR_USED_SUFFIX)), + N(true), nullptr, AssignStmt::UpdateMode::Update))); + } +} - // Change this to the update and follow the update logic - stmt->setUpdate(); - transformUpdate(stmt); - return; +/// Transform deletions. +/// @example +/// `del a` -> `a = type(a)()` and remove `a` from the context +/// `del a[x]` -> `a.__delitem__(x)` +void TypecheckVisitor::visit(DelStmt *stmt) { + if (auto idx = cast(stmt->getExpr())) { + resultStmt = N(transform( + N(N(idx->getExpr(), "__delitem__"), idx->getIndex()))); + } else if (auto ei = cast(stmt->getExpr())) { + // Assign `a` to `type(a)()` to mark it for deletion + resultStmt = transform(N( + stmt->getExpr(), + N(N(N(TYPE_TYPE), clone(stmt->getExpr()))), nullptr, + AssignStmt::Update)); + + // Allow deletion *only* if the binding is dominated + auto val = ctx->find(ei->getValue()); + if (!val) + E(Error::ID_NOT_FOUND, ei, ei->getValue()); + if (ctx->getScope() != val->scope) + E(Error::DEL_NOT_ALLOWED, ei, ei->getValue()); + ctx->remove(ei->getValue()); + ctx->remove(getUnmangledName(ei->getValue())); + } else { + E(Error::DEL_INVALID, stmt); } +} - transform(stmt->rhs); - transformType(stmt->type); - if (!stmt->rhs) { - // Forward declarations (e.g., dominating bindings, C imports etc.). - // The type is unknown and will be deduced later - unify(stmt->lhs->type, ctx->getUnbound(stmt->lhs->getSrcInfo())); - if (stmt->type) { - unify(stmt->lhs->type, - ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType())); - } - ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); - if (in(ctx->cache->globals, lhs)) - ctx->cache->globals[lhs].first = true; - if (realize(stmt->lhs->type) || !stmt->type) - stmt->setDone(); - } else if (stmt->type && stmt->type->getType()->isStaticType()) { - // Static assignments (e.g., `x: Static[int] = 5`) - if (!stmt->rhs->isStatic()) - E(Error::EXPECTED_STATIC, stmt->rhs); - seqassert(stmt->rhs->staticValue.evaluated, "static not evaluated"); - unify(stmt->lhs->type, - unify(stmt->type->type, Type::makeStatic(ctx->cache, stmt->rhs))); - auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); - if (in(ctx->cache->globals, lhs)) { - // Make globals always visible! - ctx->cache->globals[lhs].first = true; - ctx->addToplevel(lhs, val); - } - if (realize(stmt->lhs->type)) - stmt->setDone(); +/// Unpack an assignment expression `lhs = rhs` into a list of simple assignment +/// expressions (e.g., `a = b`, `a.x = b`, or `a[x] = b`). +/// Handle Python unpacking rules. +/// @example +/// `(a, b) = c` -> `a = c[0]; b = c[1]` +/// `a, b = c` -> `a = c[0]; b = c[1]` +/// `[a, *x, b] = c` -> `a = c[0]; x = c[1:-1]; b = c[-1]`. +/// Non-trivial right-hand expressions are first stored in a temporary variable. +/// @example +/// `a, b = c, d + foo()` -> `assign = (c, d + foo); a = assign[0]; b = assign[1]`. +/// Each assignment is unpacked recursively to allow cases like `a, (b, c) = d`. +Stmt *TypecheckVisitor::unpackAssignment(Expr *lhs, Expr *rhs) { + std::vector leftSide; + if (auto et = cast(lhs)) { + // Case: (a, b) = ... + for (auto *i : *et) + leftSide.push_back(i); + } else if (auto el = cast(lhs)) { + // Case: [a, b] = ... + for (auto *i : *el) + leftSide.push_back(i); } else { - // Normal assignments - unify(stmt->lhs->type, ctx->getUnbound()); - if (stmt->type) { - unify(stmt->lhs->type, - ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType())); + return N(lhs, rhs); + } + + // Prepare the right-side expression + auto oldSrcInfo = getSrcInfo(); + setSrcInfo(rhs->getSrcInfo()); + auto srcPos = rhs; + SuiteStmt *block = N(); + if (!cast(rhs)) { + // Store any non-trivial right-side expression into a variable + auto var = getTemporaryVar("assign"); + auto newRhs = N(var); + block->addStmt(N(newRhs, ast::clone(rhs))); + rhs = newRhs; + } + + // Process assignments until the fist StarExpr (if any) + size_t st = 0; + for (; st < leftSide.size(); st++) { + if (cast(leftSide[st])) + break; + // Transformation: `leftSide_st = rhs[st]` where `st` is static integer + auto rightSide = N(ast::clone(rhs), N(st)); + // Recursively process the assignment because of cases like `(a, (b, c)) = d)` + auto ns = unpackAssignment(leftSide[st], rightSide); + block->addStmt(ns); + } + // Process StarExpr (if any) and the assignments that follow it + if (st < leftSide.size() && cast(leftSide[st])) { + // StarExpr becomes SliceExpr (e.g., `b` in `(a, *b, c) = d` becomes `d[1:-2]`) + auto rightSide = N( + ast::clone(rhs), + N(N(st), + // this slice is either [st:] or [st:-lhs_len + st + 1] + leftSide.size() == st + 1 ? nullptr + : N(-leftSide.size() + st + 1), + nullptr)); + auto ns = unpackAssignment(cast(leftSide[st])->getExpr(), rightSide); + block->addStmt(ns); + st += 1; + // Process remaining assignments. They will use negative indices (-1, -2 etc.) + // because we do not know how big is StarExpr + for (; st < leftSide.size(); st++) { + if (cast(leftSide[st])) + E(Error::ASSIGN_MULTI_STAR, leftSide[st]->getSrcInfo()); + rightSide = N(ast::clone(rhs), N(-int(leftSide.size() - st))); + auto ns = unpackAssignment(leftSide[st], rightSide); + block->addStmt(ns); } + } + setSrcInfo(oldSrcInfo); + return block; +} + +/// Transform simple assignments. +/// @example +/// `a[x] = b` -> `a.__setitem__(x, b)` +/// `a.x = b` -> @c AssignMemberStmt +/// `a: type` = b -> @c AssignStmt +/// `a = b` -> @c AssignStmt or @c UpdateStmt (see below) +Stmt *TypecheckVisitor::transformAssignment(AssignStmt *stmt, bool mustExist) { + if (auto idx = cast(stmt->getLhs())) { + // Case: a[x] = b + seqassert(!stmt->type, "unexpected type annotation"); + if (auto b = cast(stmt->getRhs())) { + // Case: a[x] += b (inplace operator) + if (mustExist && b->isInPlace() && !cast(b->getRhs())) { + auto var = getTemporaryVar("assign"); + return transform(N( + N(N(var), idx->getIndex()), + N(N( + N(idx->getExpr(), "__setitem__"), N(var), + N(N(clone(idx->getExpr()), N(var)), + b->getOp(), b->getRhs(), true))))); + } + } + return transform(N(N(N(idx->getExpr(), "__setitem__"), + idx->getIndex(), stmt->getRhs()))); + } + + if (auto dot = cast(stmt->getLhs())) { + // Case: a.x = b + seqassert(!stmt->type, "unexpected type annotation"); + dot->expr = transform(dot->getExpr(), true); + return transform( + N(dot->getExpr(), dot->member, transform(stmt->getRhs()))); + } + + // Case: a (: t) = b + auto e = cast(stmt->getLhs()); + if (!e) + E(Error::ASSIGN_INVALID, stmt->getLhs()); + + // Make sure that existing values that cannot be shadowed are only updated + // mustExist |= val && !ctx->isOuter(val); + if (mustExist) { + auto val = ctx->find(e->getValue(), getTime()); + if (!val) + E(Error::ASSIGN_LOCAL_REFERENCE, e, e->getValue(), e->getSrcInfo()); + + auto s = N(stmt->getLhs(), stmt->getRhs()); + if (!ctx->getBase()->isType() && ctx->getBase()->func->hasAttribute(Attr::Atomic)) + s->setAtomicUpdate(); + else + s->setUpdate(); + if (auto u = transformUpdate(s)) + return u; + else + return s; // delay + } + + stmt->rhs = transform(stmt->getRhs(), true); + stmt->type = transformType(stmt->getTypeExpr(), false); + + // Generate new canonical variable name for this assignment and add it to the context + auto canonical = ctx->generateCanonicalName(e->getValue()); + auto assign = + N(N(canonical), stmt->getRhs(), stmt->getTypeExpr()); + assign->getLhs()->cloneAttributesFrom(stmt->getLhs()); + assign->getLhs()->setType(stmt->getLhs()->getType() + ? stmt->getLhs()->getType()->shared_from_this() + : instantiateUnbound(assign->getLhs()->getSrcInfo())); + if (!stmt->getRhs() && !stmt->getTypeExpr() && ctx->find("NoneType")) { + // All declarations that are not handled are to be marked with NoneType later on + // (useful for dangling declarations that are not initialized afterwards due to + // static check) + assign->getLhs()->getType()->getLink()->defaultType = + getStdLibType("NoneType")->shared_from_this(); + ctx->getBase()->pendingDefaults[1].insert( + assign->getLhs()->getType()->shared_from_this()); + } + if (stmt->getTypeExpr()) { + unify(assign->getLhs()->getType(), + instantiateType(stmt->getTypeExpr()->getSrcInfo(), + extractType(stmt->getTypeExpr()))); + } + auto val = std::make_shared( + canonical, ctx->getBaseName(), ctx->getModule(), + assign->getLhs()->getType()->shared_from_this(), ctx->getScope()); + val->time = getTime(); + val->setSrcInfo(getSrcInfo()); + ctx->add(e->getValue(), val); + ctx->addAlwaysVisible(val); + + if (assign->getRhs()) { // not a declaration! // Check if we can wrap the expression (e.g., `a: float = 3` -> `a = float(3)`) - if (wrapExpr(stmt->rhs, stmt->lhs->getType())) - unify(stmt->lhs->type, stmt->rhs->type); - auto type = stmt->lhs->getType(); - auto kind = TypecheckItem::Var; - if (stmt->rhs->isType()) - kind = TypecheckItem::Type; - else if (type->getFunc()) - kind = TypecheckItem::Func; + if (wrapExpr(&assign->rhs, assign->getLhs()->getType())) + unify(assign->getLhs()->getType(), assign->getRhs()->getType()); + // Generalize non-variable types. That way we can support cases like: // `a = foo(x, ...); a(1); a('s')` - auto val = std::make_shared( - kind, - kind != TypecheckItem::Var ? type->generalize(ctx->typecheckLevel - 1) : type); - - if (in(ctx->cache->globals, lhs)) { - // Make globals always visible! - ctx->cache->globals[lhs].first = true; - ctx->addToplevel(lhs, val); - if (kind != TypecheckItem::Var) - ctx->cache->globals.erase(lhs); - } else if (startswith(ctx->getRealizationBase()->name, "._import_") && - kind == TypecheckItem::Type) { - // Make import toplevel type aliases (e.g., `a = Ptr[byte]`) visible - ctx->addToplevel(lhs, val); - } else { - ctx->add(lhs, val); + if (!val->isVar()) { + val->type = val->type->generalize(ctx->typecheckLevel - 1); + // See capture_function_partial_proper_realize test + assign->getLhs()->setType(val->type); + assign->getRhs()->setType(val->type); } + } - if (stmt->lhs->getId() && kind != TypecheckItem::Var) { - // Special case: type/function renames - stmt->rhs->type = nullptr; - stmt->setDone(); - } else if (stmt->rhs->isDone() && realize(stmt->lhs->type)) { - stmt->setDone(); + // Mark declarations or generalizedtype/functions as done + if ((!assign->getRhs() || assign->getRhs()->isDone()) && + assign->getLhs()->getType()->canRealize()) { + if (auto r = realize(assign->getLhs()->getType())) { + // overwrite types to remove dangling unbounds with some partials... + assign->getLhs()->setType(r->shared_from_this()); + if (assign->getRhs()) + assign->getRhs()->setType(r->shared_from_this()); + assign->setDone(); } + } else if (assign->getRhs() && !val->isVar() && !val->type->hasUnbounds()) { + assign->setDone(); } + + // Register all toplevel variables as global in JIT mode + bool isGlobal = (ctx->cache->isJit && val->isGlobal() && !val->isGeneric()) || + (canonical == VAR_ARGV); + if (isGlobal && val->isVar()) + registerGlobal(canonical, assign->getRhs()); + + return assign; } /// Transform binding updates. Special handling is done for atomic or in-place /// statements (e.g., `a += b`). /// See @c transformInplaceUpdate and @c wrapExpr for details. -void TypecheckVisitor::transformUpdate(AssignStmt *stmt) { - transform(stmt->lhs); - if (stmt->lhs->isStatic()) - E(Error::ASSIGN_UNEXPECTED_STATIC, stmt->lhs); +Stmt *TypecheckVisitor::transformUpdate(AssignStmt *stmt) { + stmt->lhs = transform(stmt->getLhs()); // Check inplace updates auto [inPlace, inPlaceExpr] = transformInplaceUpdate(stmt); if (inPlace) { if (inPlaceExpr) { - resultStmt = N(inPlaceExpr); + auto s = N(inPlaceExpr); if (inPlaceExpr->isDone()) - resultStmt->setDone(); + s->setDone(); + return s; } - return; + return nullptr; } - transform(stmt->rhs); + stmt->rhs = transform(stmt->getRhs()); + // Case: wrap expressions if needed (e.g. floats or optionals) - if (wrapExpr(stmt->rhs, stmt->lhs->getType())) - unify(stmt->rhs->type, stmt->lhs->type); - if (stmt->rhs->done && realize(stmt->lhs->type)) + if (wrapExpr(&stmt->rhs, stmt->getLhs()->getType())) + unify(stmt->getRhs()->getType(), stmt->getLhs()->getType()); + if (stmt->getRhs()->isDone() && realize(stmt->getLhs()->getType())) stmt->setDone(); + return nullptr; } /// Typecheck instance member assignments (e.g., `a.b = c`) and handle optional @@ -170,23 +325,24 @@ void TypecheckVisitor::transformUpdate(AssignStmt *stmt) { /// `opt.foo = bar` -> `unwrap(opt).foo = wrap(bar)` /// See @c wrapExpr for more examples. void TypecheckVisitor::visit(AssignMemberStmt *stmt) { - transform(stmt->lhs); - - if (auto lhsClass = stmt->lhs->getType()->getClass()) { - auto member = ctx->findMember(lhsClass, stmt->member); + stmt->lhs = transform(stmt->getLhs()); + if (auto lhsClass = extractClassType(stmt->getLhs())) { + auto member = findMember(lhsClass, stmt->getMember()); if (!member) { - // Case: setters - auto setters = ctx->findMethod(lhsClass.get(), format(".set_{}", stmt->member)); + // Case: property setters + auto setters = + findMethod(lhsClass, format("{}{}", FN_SETTER_SUFFIX, stmt->getMember())); if (!setters.empty()) { - resultStmt = transform(N( - N(N(setters[0]->ast->name), stmt->lhs, stmt->rhs))); + resultStmt = + transform(N(N(N(setters.front()->getFuncName()), + stmt->getLhs(), stmt->getRhs()))); return; } // Case: class variables - if (auto cls = in(ctx->cache->classes, lhsClass->name)) - if (auto var = in(cls->classVars, stmt->member)) { - auto a = N(N(*var), transform(stmt->rhs)); + if (auto cls = getClass(lhsClass)) + if (auto var = in(cls->classVars, stmt->getMember())) { + auto a = N(N(*var), transform(stmt->getRhs())); a->setUpdate(); resultStmt = transform(a); return; @@ -194,22 +350,30 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) { } if (!member && lhsClass->is(TYPE_OPTIONAL)) { // Unwrap optional and look up there - resultStmt = transform(N( - N(N(FN_UNWRAP), stmt->lhs), stmt->member, stmt->rhs)); + resultStmt = transform( + N(N(N(FN_UNWRAP), stmt->getLhs()), + stmt->getMember(), stmt->getRhs())); return; } if (!member) - E(Error::DOT_NO_ATTR, stmt->lhs, lhsClass->prettyString(), stmt->member); - if (lhsClass->getRecord()) - E(Error::ASSIGN_UNEXPECTED_FROZEN, stmt->lhs); + E(Error::DOT_NO_ATTR, stmt->getLhs(), lhsClass->prettyString(), + stmt->getMember()); + if (lhsClass->isRecord()) // prevent tuple member assignment + E(Error::ASSIGN_UNEXPECTED_FROZEN, stmt->getLhs()); - transform(stmt->rhs); - auto typ = ctx->instantiate(stmt->lhs->getSrcInfo(), member, lhsClass); - if (!wrapExpr(stmt->rhs, typ)) + stmt->rhs = transform(stmt->getRhs()); + auto ftyp = + instantiateType(stmt->getLhs()->getSrcInfo(), member->getType(), lhsClass); + if (!ftyp->canRealize() && member->typeExpr) { + unify(ftyp.get(), extractType(withClassGenerics(lhsClass, [&]() { + return transform(clean_clone(member->typeExpr)); + }))); + } + if (!wrapExpr(&stmt->rhs, ftyp.get())) return; - unify(stmt->rhs->type, typ); - if (stmt->rhs->isDone()) + unify(stmt->getRhs()->getType(), ftyp.get()); + if (stmt->getRhs()->isDone()) stmt->setDone(); } } @@ -223,59 +387,71 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) { /// `a = min(a, b)` -> `type(a).__atomic_min__(__ptr__(a), b)` (same for `max`) /// @return a tuple indicating whether (1) the update statement can be replaced with an /// expression, and (2) the replacement expression. -std::pair TypecheckVisitor::transformInplaceUpdate(AssignStmt *stmt) { +std::pair TypecheckVisitor::transformInplaceUpdate(AssignStmt *stmt) { // Case: in-place updates (e.g., `a += b`). // They are stored as `Update(a, Binary(a + b, inPlace=true))` - auto bin = stmt->rhs->getBinary(); - if (bin && bin->inPlace) { - transform(bin->lexpr); - transform(bin->rexpr); - if (bin->lexpr->type->getClass() && bin->rexpr->type->getClass()) { + + auto bin = cast(stmt->getRhs()); + if (bin && bin->isInPlace()) { + bin->lexpr = transform(bin->getLhs()); + bin->rexpr = transform(bin->getRhs()); + + if (!stmt->getRhs()->getType()) + stmt->getRhs()->setType(instantiateUnbound()); + if (bin->getLhs()->getClassType() && bin->getRhs()->getClassType()) { if (auto transformed = transformBinaryInplaceMagic(bin, stmt->isAtomicUpdate())) { - unify(stmt->rhs->type, transformed->type); + unify(stmt->getRhs()->getType(), transformed->getType()); return {true, transformed}; } else if (!stmt->isAtomicUpdate()) { // If atomic, call normal magic and then use __atomic_xchg__ below return {false, nullptr}; } - } else { // Not yet completed - unify(stmt->lhs->type, unify(stmt->rhs->type, ctx->getUnbound())); + } else { // Delay + unify(stmt->lhs->getType(), + unify(stmt->getRhs()->getType(), instantiateUnbound())); return {true, nullptr}; } } // Case: atomic min/max operations. // Note: check only `a = min(a, b)`; does NOT check `a = min(b, a)` - auto lhsClass = stmt->lhs->getType()->getClass(); - auto call = stmt->rhs->getCall(); - if (stmt->isAtomicUpdate() && call && stmt->lhs->getId() && - (call->expr->isId("min") || call->expr->isId("max")) && call->args.size() == 2 && - call->args[0].value->isId(std::string(stmt->lhs->getId()->value))) { - // `type(a).__atomic_min__(__ptr__(a), b)` - auto ptrTyp = ctx->instantiateGeneric(stmt->lhs->getSrcInfo(), ctx->getType("Ptr"), - {lhsClass}); - call->args[1].value = transform(call->args[1].value); - auto rhsTyp = call->args[1].value->getType()->getClass(); - if (auto method = findBestMethod( - lhsClass, format("__atomic_{}__", call->expr->getId()->value), - {ptrTyp, rhsTyp})) { - return {true, transform(N(N(method->ast->name), - N(N("__ptr__"), stmt->lhs), - call->args[1].value))}; + auto lhsClass = extractClassType(stmt->getLhs()); + auto call = cast(stmt->getRhs()); + auto lei = cast(stmt->getLhs()); + auto cei = call ? cast(call->getExpr()) : nullptr; + if (stmt->isAtomicUpdate() && call && lei && cei && + (cei->getValue() == "min" || cei->getValue() == "max") && call->size() == 2) { + call->front().value = transform(call->front()); + if (cast(call->front()) && + cast(call->front())->getValue() == lei->getValue()) { + // `type(a).__atomic_min__(__ptr__(a), b)` + auto ptrTyp = instantiateType(stmt->getLhs()->getSrcInfo(), getStdLibType("Ptr"), + std::vector{lhsClass}); + (*call)[1].value = transform((*call)[1]); + auto rhsTyp = extractClassType((*call)[1].value); + if (auto method = + findBestMethod(lhsClass, format("__atomic_{}__", cei->getValue()), + {ptrTyp.get(), rhsTyp})) { + return {true, + transform(N(N(method->getFuncName()), + N(N("__ptr__"), stmt->getLhs()), + (*call)[1]))}; + } } } // Case: atomic assignments if (stmt->isAtomicUpdate() && lhsClass) { // `type(a).__atomic_xchg__(__ptr__(a), b)` - transform(stmt->rhs); - if (auto rhsClass = stmt->rhs->getType()->getClass()) { - auto ptrType = ctx->instantiateGeneric(stmt->lhs->getSrcInfo(), - ctx->getType("Ptr"), {lhsClass}); - if (auto m = findBestMethod(lhsClass, "__atomic_xchg__", {ptrType, rhsClass})) { - return {true, - N(N(m->ast->name), - N(N("__ptr__"), stmt->lhs), stmt->rhs)}; + stmt->rhs = transform(stmt->getRhs()); + if (auto rhsClass = stmt->getRhs()->getClassType()) { + auto ptrType = instantiateType(stmt->getLhs()->getSrcInfo(), getStdLibType("Ptr"), + std::vector{lhsClass}); + if (auto m = + findBestMethod(lhsClass, "__atomic_xchg__", {ptrType.get(), rhsClass})) { + return {true, N(N(m->getFuncName()), + N(N("__ptr__"), stmt->getLhs()), + stmt->getRhs())}; } } } diff --git a/codon/parser/visitors/typecheck/basic.cpp b/codon/parser/visitors/typecheck/basic.cpp index 94f99e7f..1777871b 100644 --- a/codon/parser/visitors/typecheck/basic.cpp +++ b/codon/parser/visitors/typecheck/basic.cpp @@ -3,10 +3,11 @@ #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/peg/peg.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; +using namespace codon::error; namespace codon::ast { @@ -14,38 +15,157 @@ using namespace types; /// Set type to `Optional[?]` void TypecheckVisitor::visit(NoneExpr *expr) { - unify(expr->type, ctx->instantiate(ctx->getType(TYPE_OPTIONAL))); - if (realize(expr->type)) { + unify(expr->getType(), instantiateType(getStdLibType(TYPE_OPTIONAL))); + if (realize(expr->getType())) { // Realize the appropriate `Optional.__new__` for the translation stage - auto cls = expr->type->getClass(); - auto f = ctx->forceFind(TYPE_OPTIONAL ".__new__:0")->type; - auto t = realize(ctx->instantiate(f, cls)->getFunc()); + auto f = ctx->forceFind(TYPE_OPTIONAL ".__new__:0")->getType(); + auto t = realize(instantiateType(f, extractClassType(expr))); expr->setDone(); } } /// Set type to `bool` void TypecheckVisitor::visit(BoolExpr *expr) { - unify(expr->type, ctx->getType("bool")); + unify(expr->getType(), instantiateStatic(expr->getValue())); expr->setDone(); } /// Set type to `int` -void TypecheckVisitor::visit(IntExpr *expr) { - unify(expr->type, ctx->getType("int")); - expr->setDone(); -} +void TypecheckVisitor::visit(IntExpr *expr) { resultExpr = transformInt(expr); } /// Set type to `float` -void TypecheckVisitor::visit(FloatExpr *expr) { - unify(expr->type, ctx->getType("float")); - expr->setDone(); -} +void TypecheckVisitor::visit(FloatExpr *expr) { resultExpr = transformFloat(expr); } -/// Set type to `str` +/// Set type to `str`. Concatinate strings in list and apply appropriate transformations +/// (e.g., `str` wrap). void TypecheckVisitor::visit(StringExpr *expr) { - unify(expr->type, ctx->getType("str")); - expr->setDone(); + if (expr->isSimple()) { + unify(expr->getType(), instantiateStatic(expr->getValue())); + expr->setDone(); + } else { + std::vector items; + for (auto &p : *expr) { + if (p.expr) { + if (!p.format.conversion.empty()) { + switch (p.format.conversion[0]) { + case 'r': + p.expr = N(N("repr"), p.expr); + break; + case 's': + p.expr = N(N("str"), p.expr); + break; + case 'a': + p.expr = N(N("ascii"), p.expr); + break; + default: + // TODO: error? + break; + } + } + if (!p.format.spec.empty()) { + p.expr = N(N(p.expr, "__format__"), + N(p.format.spec)); + } + p.expr = N(N("str"), p.expr); + if (!p.format.text.empty()) { + p.expr = N(N(N("str"), "cat"), + N(p.format.text), p.expr); + } + items.emplace_back(p.expr); + } else if (!p.prefix.empty()) { + /// Custom prefix strings: + /// call `str.__prefsix_[prefix]__(str, [static length of str])` + items.emplace_back( + N(N(N("str"), format("__prefix_{}__", p.prefix)), + N(p.value), N(p.value.size()))); + } else { + items.emplace_back(N(p.value)); + } + } + if (items.size() == 1) + resultExpr = transform(items.front()); + else + resultExpr = transform(N(N(N("str"), "cat"), items)); + } +} + +/// Parse various integer representations depending on the integer suffix. +/// @example +/// `123u` -> `UInt[64](123)` +/// `123i56` -> `Int[56](123)` +/// `123pf` -> `int.__suffix_pf__(123)` +Expr *TypecheckVisitor::transformInt(IntExpr *expr) { + auto [value, suffix] = expr->getRawData(); + Expr *holder = nullptr; + if (!expr->hasStoredValue()) { + holder = N(value); + if (suffix.empty()) + suffix = "i64"; + } else { + holder = N(expr->getValue()); + } + + /// Handle fixed-width integers: suffixValue is a pointer to NN if the suffix + /// is `uNNN` or `iNNN`. + std::unique_ptr suffixValue = nullptr; + if (suffix.size() > 1 && (suffix[0] == 'u' || suffix[0] == 'i') && + isdigit(suffix.substr(1))) { + try { + suffixValue = std::make_unique(std::stoi(suffix.substr(1))); + } catch (...) { + } + if (suffixValue && *suffixValue > MAX_INT_WIDTH) + suffixValue = nullptr; + } + + if (suffix.empty()) { + // A normal integer (int64_t) + unify(expr->getType(), instantiateStatic(expr->getValue())); + expr->setDone(); + return nullptr; + } else if (suffix == "u") { + // Unsigned integer: call `UInt[64](value)` + return transform( + N(N(N("UInt"), N(64)), holder)); + } else if (suffixValue) { + // Fixed-width numbers (with `uNNN` and `iNNN` suffixes): + // call `UInt[NNN](value)` or `Int[NNN](value)` + return transform( + N(N(N(suffix[0] == 'u' ? "UInt" : "Int"), + N(*suffixValue)), + holder)); + } else { + // Custom suffix: call `int.__suffix_[suffix]__(value)` + return transform(N( + N(N("int"), format("__suffix_{}__", suffix)), holder)); + } +} + +/// Parse various float representations depending on the suffix. +/// @example +/// `123.4pf` -> `float.__suffix_pf__(123.4)` +Expr *TypecheckVisitor::transformFloat(FloatExpr *expr) { + auto [value, suffix] = expr->getRawData(); + + Expr *holder = nullptr; + if (!expr->hasStoredValue()) { + holder = N(value); + } else { + holder = N(expr->getValue()); + } + + if (suffix.empty() && expr->hasStoredValue()) { + // A normal float (double) + unify(expr->getType(), getStdLibType("float")); + expr->setDone(); + return nullptr; + } else if (suffix.empty()) { + return transform(N(N(N("float"), "__new__"), holder)); + } else { + // Custom suffix: call `float.__suffix_[suffix]__(value)` + return transform(N( + N(N("float"), format("__suffix_{}__", suffix)), holder)); + } } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 68dcb407..92c0a5e7 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -6,7 +6,7 @@ #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -14,6 +14,21 @@ using namespace codon::error; namespace codon::ast { using namespace types; +using namespace matcher; + +/// Transform print statement. +/// @example +/// `print a, b` -> `print(a, b)` +/// `print a, b,` -> `print(a, b, end=' ')` +void TypecheckVisitor::visit(PrintStmt *stmt) { + std::vector args; + args.reserve(stmt->size()); + for (auto &i : *stmt) + args.emplace_back(i); + if (!stmt->hasNewline()) + args.emplace_back("end", N(" ")); + resultStmt = transform(N(N(N("print"), args))); +} /// Just ensure that this expression is not independent of CallExpr where it is handled. void TypecheckVisitor::visit(StarExpr *expr) { @@ -28,14 +43,11 @@ void TypecheckVisitor::visit(KeywordStarExpr *expr) { /// Typechecks an ellipsis. Ellipses are typically replaced during the typechecking; the /// only remaining ellipses are those that belong to PipeExprs. void TypecheckVisitor::visit(EllipsisExpr *expr) { - unify(expr->type, ctx->getUnbound()); - if (expr->mode == EllipsisExpr::PIPE && realize(expr->type)) { + if (expr->isPipe() && realize(expr->getType())) { expr->setDone(); - } - - if (expr->mode == EllipsisExpr::STANDALONE) { + } else if (expr->isStandalone()) { resultExpr = transform(N(N("ellipsis"))); - unify(expr->type, resultExpr->type); + unify(expr->getType(), resultExpr->getType()); } } @@ -46,169 +58,293 @@ void TypecheckVisitor::visit(EllipsisExpr *expr) { /// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments , /// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details. void TypecheckVisitor::visit(CallExpr *expr) { - if (expr->expr->isId("__internal__.undef") && expr->args.size() == 2 && - expr->args[0].value->getId()) { - auto val = expr->args[0].value->getId()->value; - val = val.substr(0, val.size() - 9); - if (auto changed = in(ctx->cache->replacements, val)) { - while (auto s = in(ctx->cache->replacements, val)) - val = changed->first, changed = s; - if (!changed->second) { - // TODO: add no-op expr - resultExpr = transform(N(false)); - return; - } - } - } + auto orig = expr->toString(0); + if (match(expr->getExpr(), M("tuple")) && expr->size() == 1) + expr->setAttribute(Attr::TupleCall); - // Transform and expand arguments. Return early if it cannot be done yet - if (!transformCallArgs(expr->args)) - return; + validateCall(expr); // Check if this call is partial call - PartialCallData part{!expr->args.empty() && expr->args.back().value->getEllipsis() && - expr->args.back().value->getEllipsis()->mode == - EllipsisExpr::PARTIAL}; - // Transform the callee - if (!part.isPartial) { - // Intercept method calls (e.g. `obj.method`) for faster compilation (because it - // avoids partial calls). This intercept passes the call arguments to - // @c transformDot to select the best overload as well - if (auto dot = expr->expr->getDot()) { - // Pick the best method overload - if (auto edt = transformDot(dot, &expr->args)) - expr->expr = edt; - } else if (auto id = expr->expr->getId()) { - // Pick the best function overload - auto overloads = in(ctx->cache->overloads, id->value); - if (overloads && overloads->size() > 1) { - if (auto bestMethod = getBestOverload(id, &expr->args)) { - auto t = id->type; - expr->expr = N(bestMethod->ast->name); - expr->expr->setType(unify(t, ctx->instantiate(bestMethod))); - } - } + PartialCallData part; + if (!expr->empty()) { + if (auto el = cast(expr->back().getExpr())) { + if (expr->back().getName().empty() && !el->isPipe()) + el->mode = EllipsisExpr::PARTIAL; + if (el->mode == EllipsisExpr::PARTIAL) + part.isPartial = true; } } - transform(expr->expr); + + // Do not allow realization here (function will be realized later); + // used to prevent early realization of compile_error + expr->setAttribute(Attr::ParentCallExpr); + if (part.isPartial) + expr->getExpr()->setAttribute(Attr::ExprDoNotRealize); + expr->expr = transform(expr->getExpr()); + expr->eraseAttribute(Attr::ParentCallExpr); + if (isUnbound(expr->getExpr())) + return; // delay + auto [calleeFn, newExpr] = getCalleeFn(expr, part); - if ((resultExpr = newExpr)) + // Transform `tuple(i for i in tup)` into a GeneratorExpr that will be handled during + // the type checking. + if (!calleeFn && expr->hasAttribute(Attr::TupleCall)) { + if (cast(expr->begin()->getExpr())) { + auto g = cast(expr->begin()->getExpr()); + if (!g || g->kind != GeneratorExpr::Generator || g->loopCount() != 1) + E(Error::CALL_TUPLE_COMPREHENSION, expr->begin()->getExpr()); + g->kind = GeneratorExpr::TupleGenerator; + resultExpr = transform(g); + return; + } else { + resultExpr = transformTupleFn(expr); + return; + } + } else if ((resultExpr = newExpr)) { + return; + } else if (!calleeFn) { return; - if (!calleeFn) + } + + if (!withClassGenerics( + calleeFn.get(), [&]() { return transformCallArgs(expr); }, true, true)) return; + // Early dispatch modifier + if (isDispatch(calleeFn.get())) { + if (startswith(calleeFn->getFuncName(), "Tuple.__new__")) { + generateTuple(expr->size()); + } + std::unique_ptr> m = nullptr; + if (auto id = cast(expr->getExpr())) { + // Case: function overloads (IdExpr) + std::vector methods; + auto key = id->getValue(); + if (isDispatch(key)) + key = key.substr(0, key.size() - std::string(FN_DISPATCH_SUFFIX).size()); + for (auto &m : getOverloads(key)) { + if (!isDispatch(m)) + methods.push_back(getFunction(m)->getType()); + } + std::reverse(methods.begin(), methods.end()); + m = std::make_unique>(findMatchingMethods( + calleeFn->funcParent ? calleeFn->funcParent->getClass() : nullptr, methods, + expr->items, expr->getExpr()->getType()->getPartial())); + } + // partials have dangling ellipsis that messes up with the unbound check below + bool doDispatch = !m || m->size() == 0 || part.isPartial; + if (!doDispatch && m && m->size() > 1) { + auto unbounds = 0; + for (auto &a : *expr) { + if (isUnbound(a.getExpr())) + return; // typecheck this later once we know the argument + } + if (unbounds) + return; + } + if (!doDispatch) { + calleeFn = instantiateType(m->front(), calleeFn->funcParent + ? calleeFn->funcParent->getClass() + : nullptr); + auto e = N(calleeFn->getFuncName()); + e->setType(calleeFn); + if (cast(expr->getExpr())) { + expr->expr = e; + } else { + expr->expr = N(N(expr->getExpr()), e); + } + expr->getExpr()->setType(calleeFn); + } else if (m && m->empty()) { + std::vector a; + for (auto &t : *expr) + a.emplace_back(fmt::format("{}", t.getExpr()->getType()->getStatic() + ? t.getExpr()->getClassType()->name + : t.getExpr()->getType()->prettyString())); + auto argsNice = fmt::format("({})", fmt::join(a, ", ")); + auto name = getUnmangledName(calleeFn->getFuncName()); + if (calleeFn->getParentType() && calleeFn->getParentType()->getClass()) + name = format("{}.{}", calleeFn->getParentType()->getClass()->niceName, name); + E(Error::FN_NO_ATTR_ARGS, expr, name, argsNice); + } + } + + bool isVirtual = false; + if (auto dot = cast(expr->getExpr()->getOrigExpr())) { + if (auto baseTyp = dot->getExpr()->getClassType()) { + auto cls = getClass(baseTyp); + isVirtual = bool(in(cls->virtuals, dot->getMember())) && cls->rtti && + !isTypeExpr(expr->getExpr()) && !isDispatch(calleeFn.get()) && + !calleeFn->ast->hasAttribute(Attr::StaticMethod) && + !calleeFn->ast->hasAttribute(Attr::Property); + } + } // Handle named and default arguments - if ((resultExpr = callReorderArguments(calleeFn, expr, part))) + if ((resultExpr = callReorderArguments(calleeFn.get(), expr, part))) return; // Handle special calls if (!part.isPartial) { auto [isSpecial, specialExpr] = transformSpecialCall(expr); if (isSpecial) { - unify(expr->type, ctx->getUnbound()); resultExpr = specialExpr; return; } } // Typecheck arguments with the function signature - bool done = typecheckCallArgs(calleeFn, expr->args); - if (!part.isPartial && realize(calleeFn)) { + bool done = typecheckCallArgs(calleeFn.get(), expr->items, part.isPartial); + if (!part.isPartial && calleeFn->canRealize()) { // Previous unifications can qualify existing identifiers. // Transform again to get the full identifier - transform(expr->expr); + expr->expr = transform(expr->expr); } done &= expr->expr->isDone(); // Emit the final call if (part.isPartial) { - // Case: partial call. `calleeFn(args...)` -> `Partial.N.(args...)` - auto partialTypeName = generatePartialStub(part.known, calleeFn->getFunc().get()); - std::vector newArgs; - for (auto &r : expr->args) - if (!r.value->getEllipsis()) { - newArgs.push_back(r.value); - newArgs.back()->setAttr(ExprAttr::SequenceItem); + // Case: partial call. `calleeFn(args...)` -> `Partial(args..., fn, mask)` + std::vector newArgs; + for (auto &r : *expr) + if (!cast(r.getExpr())) { + newArgs.push_back(r.getExpr()); + newArgs.back()->setAttribute(Attr::ExprSequenceItem); } newArgs.push_back(part.args); - newArgs.push_back(part.kwArgs); - - std::string var = ctx->cache->getTemporaryVar("part"); - ExprPtr call = nullptr; + auto partialCall = generatePartialCall(part.known, calleeFn->getFunc(), + N(newArgs), part.kwArgs); + std::string var = getTemporaryVar("part"); + Expr *call = nullptr; if (!part.var.empty()) { // Callee is already a partial call - auto stmts = expr->expr->getStmtExpr()->stmts; - stmts.push_back(N(N(var), - N(N(partialTypeName), newArgs))); + auto stmts = cast(expr->expr)->items; + stmts.push_back(N(N(var), partialCall)); call = N(stmts, N(var)); } else { - // New partial call: `(part = Partial.N.(stored_args...); part)` - call = - N(N(N(var), - N(N(partialTypeName), newArgs)), - N(var)); + // New partial call: `(part = Partial(stored_args...); part)` + call = N(N(N(var), partialCall), N(var)); } - call->setAttr(ExprAttr::Partial); + call->setAttribute(Attr::ExprPartial); resultExpr = transform(call); + } else if (isVirtual) { + if (!realize(calleeFn)) + return; + auto vid = getRealizationID(calleeFn->getParentType()->getClass(), calleeFn.get()); + + // Function[Tuple[TArg1, TArg2, ...], TRet] + std::vector ids; + for (auto &t : *calleeFn) + ids.push_back(N(t.getType()->realizedName())); + auto fnType = N( + N("Function"), + std::vector{N(N(TYPE_TUPLE), ids), + N(calleeFn->getRetType()->realizedName())}); + // Function[Tuple[TArg1, TArg2, ...],TRet]( + // __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]] + // ) + auto e = N(fnType, + N(N(N(N("__internal__"), + "class_get_rtti_vtable"), + expr->front().getExpr()), + N(vid))); + std::vector args; + for (auto &a : *expr) + args.emplace_back(a.getExpr()); + resultExpr = transform(N(e, args)); } else { // Case: normal function call - unify(expr->type, calleeFn->getRetType()); + unify(expr->getType(), calleeFn->getRetType()); if (done) expr->setDone(); } } -/// Transform call arguments. Expand *args and **kwargs to the list of @c CallExpr::Arg +void TypecheckVisitor::validateCall(CallExpr *expr) { + if (expr->hasAttribute(Attr::Validated)) + return; + bool namesStarted = false, foundEllipsis = false; + for (auto &a : *expr) { + if (a.name.empty() && namesStarted && + !(cast(a.value) || cast(a.value))) + E(Error::CALL_NAME_ORDER, a.value); + if (!a.name.empty() && (cast(a.value) || cast(a.value))) + E(Error::CALL_NAME_STAR, a.value); + if (cast(a.value) && foundEllipsis) + E(Error::CALL_ELLIPSIS, a.value); + foundEllipsis |= bool(cast(a.value)); + namesStarted |= !a.name.empty(); + } + expr->setAttribute(Attr::Validated); +} + +/// Transform call arguments. Expand *args and **kwargs to the list of @c CallArg /// objects. /// @return false if expansion could not be completed; true otherwise -bool TypecheckVisitor::transformCallArgs(std::vector &args) { - for (auto ai = 0; ai < args.size();) { - if (auto star = args[ai].value->getStar()) { +bool TypecheckVisitor::transformCallArgs(CallExpr *expr) { + for (auto ai = 0; ai < expr->size();) { + if (auto star = cast((*expr)[ai].getExpr())) { // Case: *args expansion - transform(star->what); - auto typ = star->what->type->getClass(); + star->expr = transform(star->getExpr()); + auto typ = star->getExpr()->getClassType(); while (typ && typ->is(TYPE_OPTIONAL)) { - star->what = transform(N(N(FN_UNWRAP), star->what)); - typ = star->what->type->getClass(); + star->expr = transform(N(N(FN_UNWRAP), star->getExpr())); + typ = star->getExpr()->getClassType(); } if (!typ) // Process later return false; - if (!typ->getRecord()) - E(Error::CALL_BAD_UNPACK, args[ai], typ->prettyString()); - auto fields = getClassFields(typ.get()); - for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) { - args.insert(args.begin() + ai, - {"", transform(N(clone(star->what), fields[i].name))}); + if (!typ->isRecord()) + E(Error::CALL_BAD_UNPACK, (*expr)[ai], typ->prettyString()); + auto fields = getClassFields(typ); + for (size_t i = 0; i < fields.size(); i++, ai++) { + expr->items.insert( + expr->items.begin() + ai, + {"", transform(N(clone(star->getExpr()), fields[i].name))}); } - args.erase(args.begin() + ai); - } else if (auto kwstar = CAST(args[ai].value, KeywordStarExpr)) { + expr->items.erase(expr->items.begin() + ai); + } else if (auto kwstar = cast((*expr)[ai].getExpr())) { // Case: **kwargs expansion - kwstar->what = transform(kwstar->what); - auto typ = kwstar->what->type->getClass(); + kwstar->expr = transform(kwstar->getExpr()); + auto typ = kwstar->getExpr()->getClassType(); while (typ && typ->is(TYPE_OPTIONAL)) { - kwstar->what = transform(N(N(FN_UNWRAP), kwstar->what)); - typ = kwstar->what->type->getClass(); + kwstar->expr = transform(N(N(FN_UNWRAP), kwstar->getExpr())); + typ = kwstar->getExpr()->getClassType(); } if (!typ) return false; - if (!typ->getRecord() || typ->name == TYPE_TUPLE) - E(Error::CALL_BAD_KWUNPACK, args[ai], typ->prettyString()); - auto fields = getClassFields(typ.get()); - for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) { - args.insert(args.begin() + ai, - {fields[i].name, - transform(N(clone(kwstar->what), fields[i].name))}); + if (typ->is("NamedTuple")) { + auto id = getIntLiteral(typ); + seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", + id); + auto names = ctx->cache->generatedTupleNames[id]; + for (size_t i = 0; i < names.size(); i++, ai++) { + expr->items.insert( + expr->items.begin() + ai, + CallArg{names[i], + transform(N(N(kwstar->getExpr(), "args"), + format("item{}", i + 1)))}); + } + expr->items.erase(expr->items.begin() + ai); + } else if (typ->isRecord()) { + auto fields = getClassFields(typ); + for (size_t i = 0; i < fields.size(); i++, ai++) { + expr->items.insert( + expr->items.begin() + ai, + CallArg{fields[i].name, + transform(N(kwstar->expr, fields[i].name))}); + } + expr->items.erase(expr->items.begin() + ai); + } else { + E(Error::CALL_BAD_KWUNPACK, (*expr)[ai], typ->prettyString()); } - args.erase(args.begin() + ai); } else { // Case: normal argument (no expansion) - transform(args[ai++].value); + (*expr)[ai].value = transform((*expr)[ai].getExpr()); + ai++; } } // Check if some argument names are reused after the expansion std::set seen; - for (auto &a : args) + for (auto &a : *expr) if (!a.name.empty()) { if (in(seen, a.name)) E(Error::CALL_REPEATED_NAME, a, a.name); @@ -222,85 +358,108 @@ bool TypecheckVisitor::transformCallArgs(std::vector &args) { /// Also handle special callees: constructors and partial functions. /// @return a pair with the callee's @c FuncType and the replacement expression /// (when needed; otherwise nullptr). -std::pair TypecheckVisitor::getCalleeFn(CallExpr *expr, - PartialCallData &part) { - auto callee = expr->expr->type->getClass(); +std::pair, Expr *> +TypecheckVisitor::getCalleeFn(CallExpr *expr, PartialCallData &part) { + auto callee = expr->getExpr()->getClassType(); if (!callee) { // Case: unknown callee, wait until it becomes known - unify(expr->type, ctx->getUnbound()); return {nullptr, nullptr}; } - if (expr->expr->isType() && callee->getRecord()) { - // Case: tuple constructor. Transform to: `T.__new__(args)` - return {nullptr, - transform(N(N(expr->expr, "__new__"), expr->args))}; - } + if (isTypeExpr(expr->getExpr())) { + auto typ = expr->getExpr()->getClassType(); + if (!isId(expr->getExpr(), TYPE_TYPE)) + typ = extractClassGeneric(typ)->getClass(); + if (!typ) + return {nullptr, nullptr}; + auto clsName = typ->name; + if (typ->isRecord()) { + if (expr->hasAttribute(Attr::TupleCall)) { + if (extractType(expr->getExpr())->is(TYPE_TUPLE)) + return {nullptr, nullptr}; + expr->eraseAttribute(Attr::TupleCall); + } + // Case: tuple constructor. Transform to: `T.__new__(args)` + auto e = + transform(N(N(expr->getExpr(), "__new__"), expr->items)); + return {nullptr, e}; + } - if (expr->expr->isType()) { // Case: reference type constructor. Transform to // `ctr = T.__new__(); v.__init__(args)` - ExprPtr var = N(ctx->cache->getTemporaryVar("ctr")); - auto clsName = expr->expr->type->getClass()->name; + Expr *var = N(getTemporaryVar("ctr")); auto newInit = - N(clone(var), N(N(expr->expr, "__new__"))); + N(clone(var), N(N(expr->getExpr(), "__new__"))); auto e = N(N(newInit), clone(var)); auto init = - N(N(N(clone(var), "__init__"), expr->args)); - e->stmts.emplace_back(init); + N(N(N(clone(var), "__init__"), expr->items)); + e->items.emplace_back(init); return {nullptr, transform(e)}; } - auto calleeFn = callee->getFunc(); if (auto partType = callee->getPartial()) { - // Case: calling partial object `p`. Transform roughly to - // `part = callee; partial_fn(*part.args, args...)` - ExprPtr var = N(part.var = ctx->cache->getTemporaryVar("partcall")); - expr->expr = transform(N(N(clone(var), expr->expr), - N(partType->func->ast->name))); - - // Ensure that we got a function - calleeFn = expr->expr->type->getFunc(); - seqassert(calleeFn, "not a function: {}", expr->expr->type); + auto mask = partType->getPartialMask(); + auto genFn = partType->getPartialFunc()->generalize(0); + auto calleeFn = + std::static_pointer_cast(instantiateType(genFn.get())); + + if (!partType->isPartialEmpty() || + std::any_of(mask.begin(), mask.end(), [](char c) { return c; })) { + // Case: calling partial object `p`. Transform roughly to + // `part = callee; partial_fn(*part.args, args...)` + Expr *var = N(part.var = getTemporaryVar("partcall")); + expr->expr = transform(N(N(clone(var), expr->getExpr()), + N(calleeFn->getFuncName()))); + part.known = mask; + } else { + expr->expr = transform(N(calleeFn->getFuncName())); + } + seqassert(expr->getExpr()->getType()->getFunc(), "not a function: {}", + *(expr->getExpr()->getType())); + unify(expr->getExpr()->getType(), calleeFn); // Unify partial generics with types known thus far - for (size_t i = 0, j = 0, k = 0; i < partType->known.size(); i++) - if (partType->func->ast->args[i].status == Param::Generic) { - if (partType->known[i]) - unify(calleeFn->funcGenerics[j].type, - ctx->instantiate(partType->func->funcGenerics[j].type)); + auto knownArgTypes = extractClassGeneric(partType, 1)->getClass(); + for (size_t i = 0, j = 0, k = 0; i < mask.size(); i++) + if ((*calleeFn->ast)[i].isGeneric()) { j++; - } else if (partType->known[i]) { - unify(calleeFn->getArgTypes()[i - j], partType->generics[k].type); + } else if (mask[i]) { + unify(extractFuncArgType(calleeFn.get(), i - j), + extractClassGeneric(knownArgTypes, k)); k++; } - part.known = partType->known; return {calleeFn, nullptr}; } else if (!callee->getFunc()) { // Case: callee is not a function. Try __call__ method instead - return {nullptr, - transform(N(N(expr->expr, "__call__"), expr->args))}; + return {nullptr, transform(N(N(expr->getExpr(), "__call__"), + expr->items))}; + } else { + return {std::static_pointer_cast( + callee->getFunc()->shared_from_this()), + nullptr}; } - return {calleeFn, nullptr}; } /// Reorder the call arguments to match the signature order. Ensure that every @c -/// CallExpr::Arg has a set name. Form *args/**kwargs tuples if needed, and use partial +/// CallArg has a set name. Form *args/**kwargs tuples if needed, and use partial /// and default values where needed. /// @example /// `foo(1, 2, baz=3, baf=4)` -> `foo(a=1, baz=2, args=(3, ), kwargs=KwArgs(baf=4))` -ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *expr, - PartialCallData &part) { - std::vector args; // stores ordered and processed arguments - std::vector typeArgs; // stores type and static arguments (e.g., `T: type`) - auto newMask = std::vector(calleeFn->ast->args.size(), 1); +Expr *TypecheckVisitor::callReorderArguments(FuncType *calleeFn, CallExpr *expr, + PartialCallData &part) { + if (calleeFn->ast->hasAttribute(Attr::NoArgReorder)) + return nullptr; + + std::vector args; // stores ordered and processed arguments + std::vector typeArgs; // stores type and static arguments (e.g., `T: type`) + auto newMask = std::vector(calleeFn->ast->size(), 1); // Extract pi-th partial argument from a partial object auto getPartialArg = [&](size_t pi) { - auto id = transform(N(part.var)); + auto id = transform(N(N(part.var), "args")); // Manually call @c transformStaticTupleIndex to avoid spurious InstantiateExpr - auto ex = transformStaticTupleIndex(id->type->getClass(), id, N(pi)); - seqassert(ex.first && ex.second, "partial indexing failed: {}", id->type); + auto ex = transformStaticTupleIndex(id->getClassType(), id, N(pi)); + seqassert(ex.first && ex.second, "partial indexing failed: {}", *(id->getType())); return ex.second; }; @@ -309,116 +468,113 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e auto reorderFn = [&](int starArgIndex, int kwstarArgIndex, const std::vector> &slots, bool _partial) { partial = _partial; - ctx->addBlock(); // add function generics to typecheck default arguments - addFunctionGenerics(calleeFn->getFunc().get()); - for (size_t si = 0, pi = 0; si < slots.size(); si++) { - // Get the argument name to be used later - auto rn = calleeFn->ast->args[si].name; - trimStars(rn); - auto realName = ctx->cache->rev(rn); - - if (calleeFn->ast->args[si].status == Param::Generic) { - // Case: generic arguments. Populate typeArgs - typeArgs.push_back(slots[si].empty() ? nullptr - : expr->args[slots[si][0]].value); - newMask[si] = slots[si].empty() ? 0 : 1; - } else if (si == starArgIndex && - !(slots[si].size() == 1 && - expr->args[slots[si][0]].value->hasAttr(ExprAttr::StarArgument))) { - // Case: *args. Build the tuple that holds them all - std::vector extra; - if (!part.known.empty()) - extra.push_back(N(getPartialArg(-2))); - for (auto &e : slots[si]) { - extra.push_back(expr->args[e].value); - } - ExprPtr e = N(extra); - e->setAttr(ExprAttr::StarArgument); - if (!expr->expr->isId("hasattr")) - e = transform(e); - if (partial) { - part.args = e; - args.push_back({realName, transform(N(EllipsisExpr::PARTIAL))}); - newMask[si] = 0; - } else { - args.push_back({realName, e}); - } - } else if (si == kwstarArgIndex && - !(slots[si].size() == 1 && - expr->args[slots[si][0]].value->hasAttr(ExprAttr::KwStarArgument))) { - // Case: **kwargs. Build the named tuple that holds them all - std::vector names; - std::vector values; - if (!part.known.empty()) { - auto e = getPartialArg(-1); - auto t = e->getType()->getRecord(); - seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple", e); - auto ff = getClassFields(t.get()); - for (int i = 0; i < t->getRecord()->args.size(); i++) { - names.emplace_back(ff[i].name); - values.emplace_back( - CallExpr::Arg(transform(N(clone(e), ff[i].name)))); + return withClassGenerics( + calleeFn, + [&]() { + for (size_t si = 0, pi = 0; si < slots.size(); si++) { + // Get the argument name to be used later + auto [_, rn] = (*calleeFn->ast)[si].getNameWithStars(); + auto realName = getUnmangledName(rn); + + if ((*calleeFn->ast)[si].isGeneric()) { + // Case: generic arguments. Populate typeArgs + typeArgs.push_back(slots[si].empty() ? nullptr + : (*expr)[slots[si][0]].getExpr()); + newMask[si] = slots[si].empty() ? 0 : 1; + } else if (si == starArgIndex && + !(slots[si].size() == 1 && + (*expr)[slots[si][0]].getExpr()->hasAttribute( + Attr::ExprStarArgument))) { + // Case: *args. Build the tuple that holds them all + std::vector extra; + if (!part.known.empty()) + extra.push_back(N(getPartialArg(-1))); + for (auto &e : slots[si]) { + extra.push_back((*expr)[e].getExpr()); + } + Expr *e = N(extra); + e->setAttribute(Attr::ExprStarArgument); + if (!match(expr->getExpr(), M("hasattr"))) + e = transform(e); + if (partial) { + part.args = e; + args.emplace_back(realName, + transform(N(EllipsisExpr::PARTIAL))); + newMask[si] = 0; + } else { + args.emplace_back(realName, e); + } + } else if (si == kwstarArgIndex && + !(slots[si].size() == 1 && + (*expr)[slots[si][0]].getExpr()->hasAttribute( + Attr::ExprKwStarArgument))) { + // Case: **kwargs. Build the named tuple that holds them all + std::vector names; + std::vector values; + if (!part.known.empty()) { + auto e = transform(N(N(part.var), "kwargs")); + for (auto &[n, ne] : extractNamedTuple(e)) { + names.emplace_back(n); + values.emplace_back(transform(ne)); + } + } + for (auto &e : slots[si]) { + names.emplace_back((*expr)[e].getName()); + values.emplace_back((*expr)[e].getExpr()); + } + + auto kwid = generateKwId(names); + auto e = transform(N(N("NamedTuple"), + N(values), N(kwid))); + e->setAttribute(Attr::ExprKwStarArgument); + if (partial) { + part.kwArgs = e; + args.emplace_back(realName, + transform(N(EllipsisExpr::PARTIAL))); + newMask[si] = 0; + } else { + args.emplace_back(realName, e); + } + } else if (slots[si].empty()) { + // Case: no argument. Check if the arguments is provided by the partial + // type (if calling it) or if a default argument can be used + if (!part.known.empty() && part.known[si]) { + args.emplace_back(realName, getPartialArg(pi++)); + } else if (partial) { + args.emplace_back(realName, + transform(N(EllipsisExpr::PARTIAL))); + newMask[si] = 0; + } else { + if (cast((*calleeFn->ast)[si].getDefault()) && + !(*calleeFn->ast)[si].type) { + args.push_back( + {realName, transform(N(N( + N("Optional"), N("NoneType"))))}); + } else { + args.push_back({realName, transform(clean_clone( + (*calleeFn->ast)[si].getDefault()))}); + } + } + } else { + // Case: argument provided + seqassert(slots[si].size() == 1, "call transformation failed"); + args.emplace_back(realName, (*expr)[slots[si][0]].getExpr()); + } } - } - for (auto &e : slots[si]) { - names.emplace_back(expr->args[e].name); - values.emplace_back(CallExpr::Arg(expr->args[e].value)); - } - auto kwName = generateTuple(names.size(), TYPE_KWTUPLE, names); - auto e = transform(N(N(kwName), values)); - e->setAttr(ExprAttr::KwStarArgument); - if (partial) { - part.kwArgs = e; - args.push_back({realName, transform(N(EllipsisExpr::PARTIAL))}); - newMask[si] = 0; - } else { - args.push_back({realName, e}); - } - } else if (slots[si].empty()) { - // Case: no argument. Check if the arguments is provided by the partial type (if - // calling it) or if a default argument can be used - if (!part.known.empty() && part.known[si]) { - args.push_back({realName, getPartialArg(pi++)}); - } else if (partial) { - args.push_back({realName, transform(N(EllipsisExpr::PARTIAL))}); - newMask[si] = 0; - } else { - auto es = calleeFn->ast->args[si].defaultValue->toString(); - if (in(ctx->defaultCallDepth, es)) - E(Error::CALL_RECURSIVE_DEFAULT, expr, - ctx->cache->rev(calleeFn->ast->args[si].name)); - ctx->defaultCallDepth.insert(es); - - if (calleeFn->ast->args[si].defaultValue->getNone() && - !calleeFn->ast->args[si].type) { - args.push_back( - {realName, transform(N(N( - N("Optional"), N("NoneType"))))}); - } else { - args.push_back( - {realName, transform(clone(calleeFn->ast->args[si].defaultValue))}); - } - ctx->defaultCallDepth.erase(es); - } - } else { - // Case: argument provided - seqassert(slots[si].size() == 1, "call transformation failed"); - args.push_back({realName, expr->args[slots[si][0]].value}); - } - } - ctx->popBlock(); - return 0; + return 0; + }, + true); }; // Reorder arguments if needed part.args = part.kwArgs = nullptr; // Stores partial *args/**kwargs expression - if (expr->hasAttr(ExprAttr::OrderedCall) || expr->expr->isId("superf")) { - args = expr->args; + if (expr->hasAttribute(Attr::ExprOrderedCall)) { + args = expr->items; } else { - ctx->reorderNamedArgs( - calleeFn.get(), expr->args, reorderFn, + reorderNamedArgs( + calleeFn, expr->items, reorderFn, [&](error::Error e, const SrcInfo &o, const std::string &errorMsg) { - error::raise_error(e, o, errorMsg); + E(Error::CUSTOM, o, errorMsg.c_str()); return -1; }, part.known); @@ -426,62 +582,45 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e // Populate partial data if (part.args != nullptr) - part.args->setAttr(ExprAttr::SequenceItem); + part.args->setAttribute(Attr::ExprSequenceItem); if (part.kwArgs != nullptr) - part.kwArgs->setAttr(ExprAttr::SequenceItem); + part.kwArgs->setAttribute(Attr::ExprSequenceItem); if (part.isPartial) { - expr->args.pop_back(); + expr->items.pop_back(); if (!part.args) part.args = transform(N()); // use () - if (!part.kwArgs) { - auto kwName = generateTuple(0, TYPE_KWTUPLE, {}); - part.kwArgs = transform(N(N(kwName))); // use KwTuple() - } + if (!part.kwArgs) + part.kwArgs = transform(N(N("NamedTuple"))); // use NamedTuple() } // Unify function type generics with the provided generics - seqassert((expr->hasAttr(ExprAttr::OrderedCall) && typeArgs.empty()) || - (!expr->hasAttr(ExprAttr::OrderedCall) && + seqassert((expr->hasAttribute(Attr::ExprOrderedCall) && typeArgs.empty()) || + (!expr->hasAttribute(Attr::ExprOrderedCall) && typeArgs.size() == calleeFn->funcGenerics.size()), "bad vector sizes"); if (!calleeFn->funcGenerics.empty()) { auto niGenerics = calleeFn->ast->getNonInferrableGenerics(); - for (size_t si = 0; - !expr->hasAttr(ExprAttr::OrderedCall) && si < calleeFn->funcGenerics.size(); + for (size_t si = 0; !expr->hasAttribute(Attr::ExprOrderedCall) && + si < calleeFn->funcGenerics.size(); si++) { + const auto &gen = calleeFn->funcGenerics[si]; if (typeArgs[si]) { - auto typ = typeArgs[si]->type; - if (calleeFn->funcGenerics[si].type->isStaticType()) { - if (!typeArgs[si]->isStatic()) { + auto typ = extractType(typeArgs[si]); + if (gen.isStatic) + if (!typ->isStaticType()) E(Error::EXPECTED_STATIC, typeArgs[si]); - } - typ = Type::makeStatic(ctx->cache, typeArgs[si]); - } - unify(typ, calleeFn->funcGenerics[si].type); + unify(typ, gen.getType()); } else { - if (calleeFn->funcGenerics[si].type->getUnbound() && - !calleeFn->ast->args[si].defaultValue && !partial && - in(niGenerics, calleeFn->funcGenerics[si].name)) { - error("generic '{}' not provided", calleeFn->funcGenerics[si].niceName); + if (isUnbound(gen.getType()) && !(*calleeFn->ast)[si].getDefault() && + !partial && in(niGenerics, gen.name)) { + E(Error::CUSTOM, getSrcInfo(), "generic '{}' not provided", gen.niceName); } } } } - // Special case: function instantiation (e.g., `foo(T=int)`) - auto cnt = 0; - for (auto &t : typeArgs) - if (t) - cnt++; - if (part.isPartial && cnt && cnt == expr->args.size()) { - transform(expr->expr); // transform again because it might have been changed - unify(expr->type, expr->expr->getType()); - // Return the callee with the corrected type and do not go further - return expr->expr; - } - - expr->args = args; - expr->setAttr(ExprAttr::OrderedCall); + expr->items = args; + expr->setAttribute(Attr::ExprOrderedCall); part.known = newMask; return nullptr; } @@ -491,67 +630,97 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e /// default generics. /// @example /// `foo(1, 2)` -> `foo(1, Optional(2), T=int)` -bool TypecheckVisitor::typecheckCallArgs(const FuncTypePtr &calleeFn, - std::vector &args) { - bool wrappingDone = true; // tracks whether all arguments are wrapped - std::vector replacements; // list of replacement arguments - for (size_t si = 0; si < calleeFn->getArgTypes().size(); si++) { - if (startswith(calleeFn->ast->args[si].name, "*") && calleeFn->ast->args[si].type && - args[si].value->getCall()) { - // Special case: `*args: type` and `**kwargs: type` - auto typ = transform(clone(calleeFn->ast->args[si].type))->type; - for (auto &ca : args[si].value->getCall()->args) { - if (wrapExpr(ca.value, typ, calleeFn)) { - unify(ca.value->type, typ); - } else { - wrappingDone = false; +bool TypecheckVisitor::typecheckCallArgs(FuncType *calleeFn, std::vector &args, + bool isPartial) { + bool wrappingDone = true; // tracks whether all arguments are wrapped + std::vector replacements; // list of replacement arguments + + withClassGenerics( + calleeFn, + [&]() { + for (size_t i = 0, si = 0; i < calleeFn->ast->size(); i++) { + if ((*calleeFn->ast)[i].isGeneric()) + continue; + + if (startswith((*calleeFn->ast)[i].getName(), "*") && + (*calleeFn->ast)[i].getType()) { + // Special case: `*args: type` and `**kwargs: type` + if (auto callExpr = cast(args[si].getExpr())) { + auto typ = extractType(transform(clone((*calleeFn->ast)[i].getType()))); + if (startswith((*calleeFn->ast)[i].getName(), "**")) + callExpr = cast(callExpr->front().getExpr()); + for (auto &ca : *callExpr) { + if (wrapExpr(&ca.value, typ, calleeFn)) { + unify(ca.getExpr()->getType(), typ); + } else { + wrappingDone = false; + } + } + auto name = callExpr->getClassType()->name; + auto tup = transform(N(N(name), callExpr->items)); + if (startswith((*calleeFn->ast)[i].getName(), "**")) { + args[si].value = transform(N( + N(N("NamedTuple"), "__new__"), tup, + N(extractClassGeneric(args[si].getExpr()->getType()) + ->getIntStatic() + ->value))); + } else { + args[si].value = tup; + } + } + replacements.push_back(args[si].getExpr()->getType()); + // else this is empty and is a partial call; leave it for later + } else { + if (wrapExpr(&args[si].value, extractFuncArgType(calleeFn, si), calleeFn)) { + unify(args[si].getExpr()->getType(), extractFuncArgType(calleeFn, si)); + } else { + wrappingDone = false; + } + replacements.push_back(!extractFuncArgType(calleeFn, si)->getClass() + ? args[si].getExpr()->getType() + : extractFuncArgType(calleeFn, si)); + } + si++; } - } - auto name = args[si].value->type->getClass()->name; - args[si].value = - transform(N(N(name), args[si].value->getCall()->args)); - replacements.push_back(args[si].value->type); - } else { - if (wrapExpr(args[si].value, calleeFn->getArgTypes()[si], calleeFn)) { - unify(args[si].value->type, calleeFn->getArgTypes()[si]); - } else { - wrappingDone = false; - } - replacements.push_back(!calleeFn->getArgTypes()[si]->getClass() - ? args[si].value->type - : calleeFn->getArgTypes()[si]); - } - } + return true; + }, + true); // Realize arguments bool done = true; for (auto &a : args) { // Previous unifications can qualify existing identifiers. // Transform again to get the full identifier - if (realize(a.value->type)) - transform(a.value); - done &= a.value->isDone(); + if (realize(a.getExpr()->getType())) + a.value = transform(a.getExpr()); + done &= a.getExpr()->isDone(); } // Handle default generics - for (size_t i = 0, j = 0; wrappingDone && i < calleeFn->ast->args.size(); i++) - if (calleeFn->ast->args[i].status == Param::Generic) { - if (calleeFn->ast->args[i].defaultValue && - calleeFn->funcGenerics[j].type->getUnbound()) { - ctx->addBlock(); // add function generics to typecheck default arguments - addFunctionGenerics(calleeFn->getFunc().get()); - auto def = transform(clone(calleeFn->ast->args[i].defaultValue)); - ctx->popBlock(); - unify(calleeFn->funcGenerics[j].type, - def->isStatic() ? Type::makeStatic(ctx->cache, def) : def->getType()); + if (!isPartial) + for (size_t i = 0, j = 0; wrappingDone && i < calleeFn->ast->size(); i++) + if ((*calleeFn->ast)[i].isGeneric()) { + if ((*calleeFn->ast)[i].getDefault() && + isUnbound(extractFuncGeneric(calleeFn, j))) { + auto def = extractType(withClassGenerics( + calleeFn, + [&]() { + return transform(clean_clone((*calleeFn->ast)[i].getDefault())); + }, + true)); + unify(extractFuncGeneric(calleeFn, j), def); + } + j++; } - j++; - } // Replace the arguments for (size_t si = 0; si < replacements.size(); si++) { - if (replacements[si]) - calleeFn->getArgTypes()[si] = replacements[si]; + if (replacements[si]) { + extractClassGeneric(calleeFn)->getClass()->generics[si].type = + replacements[si]->shared_from_this(); + extractClassGeneric(calleeFn)->getClass()->_rn = ""; + calleeFn->getClass()->_rn = ""; /// TODO: TERRIBLE! + } } return done; @@ -569,610 +738,64 @@ bool TypecheckVisitor::typecheckCallArgs(const FuncTypePtr &calleeFn, /// `type(obj)` /// `compile_err("msg")` /// See below for more details. -std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) { - if (!expr->expr->getId()) +std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) { + auto ei = cast(expr->expr); + if (!ei) return {false, nullptr}; - auto val = expr->expr->getId()->value; - if (val == "superf") { + auto isF = [](IdExpr *val, const std::string &s) { + return val->getValue() == s || val->getValue() == s + ":0" || + val->getValue() == s + ".0:0"; + }; + if (isF(ei, "superf")) { return {true, transformSuperF(expr)}; - } else if (val == "super:0") { + } else if (isF(ei, "super")) { return {true, transformSuper()}; - } else if (val == "__ptr__") { + } else if (isF(ei, "__ptr__")) { return {true, transformPtr(expr)}; - } else if (val == "__array__.__new__:0") { + } else if (isF(ei, "__array__.__new__")) { return {true, transformArray(expr)}; - } else if (val == "isinstance") { + } else if (isF(ei, "isinstance")) { // static return {true, transformIsInstance(expr)}; - } else if (val == "staticlen") { + } else if (isF(ei, "staticlen")) { // static return {true, transformStaticLen(expr)}; - } else if (val == "hasattr") { + } else if (isF(ei, "hasattr")) { // static return {true, transformHasAttr(expr)}; - } else if (val == "getattr") { + } else if (isF(ei, "getattr")) { return {true, transformGetAttr(expr)}; - } else if (val == "setattr") { + } else if (isF(ei, "setattr")) { return {true, transformSetAttr(expr)}; - } else if (val == "type.__new__:0") { + } else if (isF(ei, "type.__new__")) { return {true, transformTypeFn(expr)}; - } else if (val == "compile_error") { + } else if (isF(ei, "compile_error")) { return {true, transformCompileError(expr)}; - } else if (val == "tuple") { - return {true, transformTupleFn(expr)}; - } else if (val == "__realized__") { + } else if (isF(ei, "__realized__")) { return {true, transformRealizedFn(expr)}; - } else if (val == "std.internal.static.static_print") { + } else if (isF(ei, "std.internal.static.static_print")) { return {false, transformStaticPrintFn(expr)}; - } else if (val == "__has_rtti__") { + } else if (isF(ei, "__has_rtti__")) { // static return {true, transformHasRttiFn(expr)}; - } else { - return transformInternalStaticFn(expr); - } -} - -/// Typecheck superf method. This method provides the access to the previous matching -/// overload. -/// @example -/// ```class cls: -/// def foo(): print('foo 1') -/// def foo(): -/// superf() # access the previous foo -/// print('foo 2') -/// cls.foo()``` -/// prints "foo 1" followed by "foo 2" -ExprPtr TypecheckVisitor::transformSuperF(CallExpr *expr) { - auto func = ctx->getRealizationBase()->type->getFunc(); - - // Find list of matching superf methods - std::vector supers; - if (!func->ast->attributes.parentClass.empty() && - !endswith(func->ast->name, ":dispatch")) { - auto p = ctx->find(func->ast->attributes.parentClass)->type; - if (p && p->getClass()) { - if (auto c = in(ctx->cache->classes, p->getClass()->name)) { - if (auto m = in(c->methods, ctx->cache->rev(func->ast->name))) { - for (auto &overload : ctx->cache->overloads[*m]) { - if (endswith(overload.name, ":dispatch")) - continue; - if (overload.name == func->ast->name) - break; - supers.emplace_back(ctx->cache->functions[overload.name].type); - } - } - } - std::reverse(supers.begin(), supers.end()); - } - } - if (supers.empty()) - E(Error::CALL_SUPERF, expr); - auto m = findMatchingMethods( - func->funcParent ? func->funcParent->getClass() : nullptr, supers, expr->args); - if (m.empty()) - E(Error::CALL_SUPERF, expr); - return transform(N(N(m[0]->ast->name), expr->args)); -} - -/// Typecheck and transform super method. Replace it with the current self object cast -/// to the first inherited type. -/// TODO: only an empty super() is currently supported. -ExprPtr TypecheckVisitor::transformSuper() { - if (!ctx->getRealizationBase()->type) - E(Error::CALL_SUPER_PARENT, getSrcInfo()); - auto funcTyp = ctx->getRealizationBase()->type->getFunc(); - if (!funcTyp || !funcTyp->ast->hasAttr(Attr::Method)) - E(Error::CALL_SUPER_PARENT, getSrcInfo()); - if (funcTyp->getArgTypes().empty()) - E(Error::CALL_SUPER_PARENT, getSrcInfo()); - - ClassTypePtr typ = funcTyp->getArgTypes()[0]->getClass(); - auto cands = ctx->cache->classes[typ->name].staticParentClasses; - if (cands.empty()) { - // Dynamic inheritance: use MRO - // TODO: maybe super() should be split into two separate functions... - auto vCands = ctx->cache->classes[typ->name].mro; - if (vCands.size() < 2) - E(Error::CALL_SUPER_PARENT, getSrcInfo()); - - auto superTyp = ctx->instantiate(vCands[1]->type, typ)->getClass(); - auto self = N(funcTyp->ast->args[0].name); - self->type = typ; - - auto typExpr = N(superTyp->name); - typExpr->setType(superTyp); - return transform(N(N(N("__internal__"), "class_super"), - self, typExpr, N(1))); - } - - auto name = cands.front(); // the first inherited type - auto superTyp = ctx->instantiate(ctx->forceFind(name)->type)->getClass(); - if (typ->getRecord()) { - // Case: tuple types. Return `tuple(obj.args...)` - std::vector members; - for (auto &field : getClassFields(superTyp.get())) - members.push_back(N(N(funcTyp->ast->args[0].name), field.name)); - ExprPtr e = transform(N(members)); - e->type = unify(superTyp, e->type); // see super_tuple test for this line - return e; - } else { - // Case: reference types. Return `__internal__.class_super(self, T)` - auto self = N(funcTyp->ast->args[0].name); - self->type = typ; - return castToSuperClass(self, superTyp); - } -} - -/// Typecheck __ptr__ method. This method creates a pointer to an object. Ensure that -/// the argument is a variable binding. -ExprPtr TypecheckVisitor::transformPtr(CallExpr *expr) { - auto id = expr->args[0].value->getId(); - auto val = id ? ctx->find(id->value) : nullptr; - if (!val || val->kind != TypecheckItem::Var) - E(Error::CALL_PTR_VAR, expr->args[0]); - - transform(expr->args[0].value); - unify(expr->type, - ctx->instantiateGeneric(ctx->getType("Ptr"), {expr->args[0].value->type})); - if (expr->args[0].value->isDone()) - expr->setDone(); - return nullptr; -} - -/// Typecheck __array__ method. This method creates a stack-allocated array via alloca. -ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) { - auto arrTyp = expr->expr->type->getFunc(); - unify(expr->type, - ctx->instantiateGeneric(ctx->getType("Array"), - {arrTyp->funcParent->getClass()->generics[0].type})); - if (realize(expr->type)) - expr->setDone(); - return nullptr; -} - -/// Transform isinstance method to a static boolean expression. -/// Special cases: -/// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type -/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type -ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) { - expr->setType(unify(expr->type, ctx->getType("bool"))); - expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved - transform(expr->args[0].value); - auto typ = expr->args[0].value->type->getClass(); - if (!typ || !typ->canRealize()) - return nullptr; - - transform(expr->args[0].value); // transform again to realize it - - auto &typExpr = expr->args[1].value; - if (auto c = typExpr->getCall()) { - // Handle `isinstance(obj, (type1, type2, ...))` - if (typExpr->origExpr && typExpr->origExpr->getTuple()) { - ExprPtr result = transform(N(false)); - for (auto &i : typExpr->origExpr->getTuple()->items) { - result = transform(N( - result, "||", - N(N("isinstance"), expr->args[0].value, i))); - } - return result; - } - } - - expr->staticValue.type = StaticValue::INT; - if (typExpr->isId(TYPE_TUPLE) || typExpr->isId("tuple")) { - return transform(N(typ->name == TYPE_TUPLE)); - } else if (typExpr->isId("ByVal")) { - return transform(N(typ->getRecord() != nullptr)); - } else if (typExpr->isId("ByRef")) { - return transform(N(typ->getRecord() == nullptr)); - } else if (typExpr->isId("Union")) { - return transform(N(typ->getUnion() != nullptr)); - } else if (!typExpr->type->getUnion() && typ->getUnion()) { - auto unionTypes = typ->getUnion()->getRealizationTypes(); - int tag = -1; - for (size_t ui = 0; ui < unionTypes.size(); ui++) { - if (typExpr->type->unify(unionTypes[ui].get(), nullptr) >= 0) { - tag = ui; - break; - } - } - if (tag == -1) - return transform(N(false)); - return transform(N( - N(N("__internal__.union_get_tag:0"), expr->args[0].value), - "==", N(tag))); - } else if (typExpr->type->is("pyobj") && !typExpr->isType()) { - if (typ->is("pyobj")) { - expr->staticValue.type = StaticValue::NOT_STATIC; - return transform(N(N("std.internal.python._isinstance:0"), - expr->args[0].value, expr->args[1].value)); - } else { - return transform(N(false)); - } - } - - transformType(typExpr); - - // Check super types (i.e., statically inherited) as well - for (auto &tx : getSuperTypes(typ->getClass())) { - types::Type::Unification us; - auto s = tx->unify(typExpr->type.get(), &us); - us.undo(); - if (s >= 0) - return transform(N(true)); - } - return transform(N(false)); -} - -/// Transform staticlen method to a static integer expression. This method supports only -/// static strings and tuple types. -ExprPtr TypecheckVisitor::transformStaticLen(CallExpr *expr) { - expr->staticValue.type = StaticValue::INT; - transform(expr->args[0].value); - auto typ = expr->args[0].value->getType(); - - if (auto s = typ->getStatic()) { - // Case: staticlen on static strings - if (s->expr->staticValue.type != StaticValue::STRING) - E(Error::EXPECTED_STATIC_SPECIFIED, expr->args[0].value, "string"); - if (!s->expr->staticValue.evaluated) - return nullptr; - return transform(N(s->expr->staticValue.getString().size())); - } - if (!typ->getClass()) - return nullptr; - if (typ->getUnion()) { - if (realize(typ)) - return transform(N(typ->getUnion()->getRealizationTypes().size())); - return nullptr; - } - if (!typ->getRecord()) - E(Error::EXPECTED_TUPLE, expr->args[0].value); - return transform(N(typ->getRecord()->args.size())); -} - -/// Transform hasattr method to a static boolean expression. -/// This method also supports additional argument types that are used to check -/// for a matching overload (not available in Python). -ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) { - expr->staticValue.type = StaticValue::INT; - auto typ = expr->args[0].value->getType()->getClass(); - if (!typ) - return nullptr; - auto member = expr->expr->type->getFunc() - ->funcGenerics[0] - .type->getStatic() - ->evaluate() - .getString(); - std::vector> args{{"", typ}}; - - // Case: passing argument types via *args - auto tup = expr->args[1].value->getTuple(); - seqassert(tup, "not a tuple"); - for (auto &a : tup->items) { - transform(a); - if (!a->getType()->getClass()) - return nullptr; - args.emplace_back("", a->getType()); - } - auto kwtup = expr->args[2].value->origExpr->getCall(); - seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(), - "expected call: {}", expr->args[2].value->origExpr); - auto kw = expr->args[2].value->origExpr->getCall(); - auto kwCls = - in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name); - seqassert(kwCls, "cannot find {}", expr->args[2].value->getType()->getClass()->name); - for (size_t i = 0; i < kw->args.size(); i++) { - auto &a = kw->args[i].value; - transform(a); - if (!a->getType()->getClass()) - return nullptr; - args.emplace_back(kwCls->fields[i].name, a->getType()); - } - - if (typ->getUnion()) { - ExprPtr cond = nullptr; - auto unionTypes = typ->getUnion()->getRealizationTypes(); - int tag = -1; - for (size_t ui = 0; ui < unionTypes.size(); ui++) { - auto tu = realize(unionTypes[ui]); - if (!tu) - return nullptr; - auto te = N(tu->getClass()->realizedTypeName()); - auto e = N( - N(N("isinstance"), expr->args[0].value, te), "&&", - N(N("hasattr"), te, N(member))); - cond = !cond ? e : N(cond, "||", e); - } - if (!cond) - return transform(N(false)); - return transform(cond); - } - - bool exists = !ctx->findMethod(typ->getClass().get(), member).empty() || - ctx->findMember(typ->getClass(), member); - if (exists && args.size() > 1) - exists &= findBestMethod(typ, member, args) != nullptr; - return transform(N(exists)); -} - -/// Transform getattr method to a DotExpr. -ExprPtr TypecheckVisitor::transformGetAttr(CallExpr *expr) { - auto funcTyp = expr->expr->type->getFunc(); - auto staticTyp = funcTyp->funcGenerics[0].type->getStatic(); - if (!staticTyp->canRealize()) - return nullptr; - return transform(N(expr->args[0].value, staticTyp->evaluate().getString())); -} - -/// Transform setattr method to a AssignMemberStmt. -ExprPtr TypecheckVisitor::transformSetAttr(CallExpr *expr) { - auto funcTyp = expr->expr->type->getFunc(); - auto staticTyp = funcTyp->funcGenerics[0].type->getStatic(); - if (!staticTyp->canRealize()) - return nullptr; - return transform(N(N(expr->args[0].value, - staticTyp->evaluate().getString(), - expr->args[1].value), - N(N("NoneType")))); -} - -/// Raise a compiler error. -ExprPtr TypecheckVisitor::transformCompileError(CallExpr *expr) { - auto funcTyp = expr->expr->type->getFunc(); - auto staticTyp = funcTyp->funcGenerics[0].type->getStatic(); - if (staticTyp->canRealize()) - E(Error::CUSTOM, expr, staticTyp->evaluate().getString()); - return nullptr; -} - -/// Convert a class to a tuple. -ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) { - auto cls = expr->args.front().value->type->getClass(); - if (!cls) - return nullptr; - - // tuple(ClassType) is a tuple type that corresponds to a class - if (expr->args.front().value->isType()) { - if (!realize(cls)) - return expr->clone(); - - std::vector items; - for (auto &ft : getClassFields(cls.get())) { - auto t = ctx->instantiate(ft.type, cls); - auto rt = realize(t); - seqassert(rt, "cannot realize '{}' in {}", t, ft.name); - items.push_back(NT(t->realizedName())); - } - auto e = transform(NT(N(TYPE_TUPLE), items)); - return e; - } - - std::vector args; - std::string var = ctx->cache->getTemporaryVar("tup"); - for (auto &field : getClassFields(cls.get())) - args.emplace_back(N(N(var), field.name)); - - return transform(N(N(N(var), expr->args.front().value), - N(args))); -} - -/// Transform type function to a type IdExpr identifier. -ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) { - expr->markType(); - transform(expr->args[0].value); - - unify(expr->type, expr->args[0].value->getType()); - - if (!realize(expr->type)) - return nullptr; - - auto e = NT(expr->type->realizedName()); - e->setType(expr->type); - e->setDone(); - return e; -} - -/// Transform __realized__ function to a fully realized type identifier. -ExprPtr TypecheckVisitor::transformRealizedFn(CallExpr *expr) { - auto call = - transform(N(expr->args[0].value, N(expr->args[1].value))); - if (!call->getCall()->expr->type->getFunc()) - E(Error::CALL_REALIZED_FN, expr->args[0].value); - if (auto f = realize(call->getCall()->expr->type)) { - auto e = N(f->getFunc()->realizedName()); - e->setType(f); - e->setDone(); - return e; - } - return nullptr; -} - -/// Transform __static_print__ function to a fully realized type identifier. -ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { - auto &args = expr->args[0].value->getCall()->args; - for (size_t i = 0; i < args.size(); i++) { - realize(args[i].value->type); - fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(), - FormatVisitor::apply(args[i].value), - args[i].value->type ? args[i].value->type->debugString(1) : "-", - args[i].value->isStatic() ? " [static]" : "", - ctx->getRealizationBase()->iteration); - } - return nullptr; -} - -/// Transform __has_rtti__ to a static boolean that indicates RTTI status of a type. -ExprPtr TypecheckVisitor::transformHasRttiFn(CallExpr *expr) { - expr->staticValue.type = StaticValue::INT; - auto funcTyp = expr->expr->type->getFunc(); - auto t = funcTyp->funcGenerics[0].type->getClass(); - if (!t) - return nullptr; - auto c = in(ctx->cache->classes, t->name); - seqassert(c, "bad class {}", t->name); - return transform(N(const_cast(c)->rtti)); -} - -// Transform internal.static calls -std::pair TypecheckVisitor::transformInternalStaticFn(CallExpr *expr) { - unify(expr->type, ctx->getUnbound()); - if (expr->expr->isId("std.internal.static.fn_can_call")) { - expr->staticValue.type = StaticValue::INT; - auto typ = expr->args[0].value->getType()->getClass(); - if (!typ) - return {true, nullptr}; - - auto inargs = unpackTupleTypes(expr->args[1].value); - auto kwargs = unpackTupleTypes(expr->args[2].value); - seqassert(inargs && kwargs, "bad call to fn_can_call"); - - std::vector callArgs; - for (auto &a : *inargs) { - callArgs.push_back({a.first, std::make_shared()}); // dummy expression - callArgs.back().value->setType(a.second); - } - for (auto &a : *kwargs) { - callArgs.push_back({a.first, std::make_shared()}); // dummy expression - callArgs.back().value->setType(a.second); - } - - if (auto fn = expr->args[0].value->type->getFunc()) { - return {true, transform(N(canCall(fn, callArgs) >= 0))}; - } else if (auto pt = expr->args[0].value->type->getPartial()) { - return {true, transform(N(canCall(pt->func, callArgs, pt) >= 0))}; - } else { - compilationWarning("cannot use fn_can_call on non-functions", getSrcInfo().file, - getSrcInfo().line, getSrcInfo().col); - return {true, transform(N(false))}; - } - } else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) { - expr->staticValue.type = StaticValue::INT; - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); - seqassert(idx, "expected a static integer"); - auto &args = fn->getArgTypes(); - return {true, transform(N(*idx >= 0 && *idx < args.size() && - args[*idx]->canRealize()))}; - } else if (expr->expr->isId("std.internal.static.fn_arg_get_type")) { - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); - seqassert(idx, "expected a static integer"); - auto &args = fn->getArgTypes(); - if (*idx < 0 || *idx >= args.size() || !args[*idx]->canRealize()) - error("argument does not have type"); - return {true, transform(NT(args[*idx]->realizedName()))}; - } else if (expr->expr->isId("std.internal.static.fn_args")) { - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - std::vector v; - for (size_t i = 0; i < fn->ast->args.size(); i++) { - auto n = fn->ast->args[i].name; - trimStars(n); - n = ctx->cache->rev(n); - v.push_back(N(n)); - } - return {true, transform(N(v))}; - } else if (expr->expr->isId("std.internal.static.fn_has_default")) { - expr->staticValue.type = StaticValue::INT; - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); - seqassert(idx, "expected a static integer"); - auto &args = fn->ast->args; - if (*idx < 0 || *idx >= args.size()) - error("argument out of bounds"); - return {true, transform(N(args[*idx].defaultValue != nullptr))}; - } else if (expr->expr->isId("std.internal.static.fn_get_default")) { - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - auto idx = ctx->getStaticInt(expr->expr->type->getFunc()->funcGenerics[0].type); - seqassert(idx, "expected a static integer"); - auto &args = fn->ast->args; - if (*idx < 0 || *idx >= args.size()) - error("argument out of bounds"); - return {true, transform(args[*idx].defaultValue)}; - } else if (expr->expr->isId("std.internal.static.fn_wrap_call_args")) { - auto typ = expr->args[0].value->getType()->getClass(); - if (!typ) - return {true, nullptr}; - - auto fn = ctx->extractFunction(expr->args[0].value->type); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - - std::vector callArgs; - if (auto tup = expr->args[1].value->origExpr->getTuple()) { - for (auto &a : tup->items) { - callArgs.push_back({"", a}); - } - } - if (auto kw = expr->args[1].value->origExpr->getCall()) { - auto kwCls = in(ctx->cache->classes, expr->getType()->getClass()->name); - seqassert(kwCls, "cannot find {}", expr->getType()->getClass()->name); - for (size_t i = 0; i < kw->args.size(); i++) { - callArgs.push_back({kwCls->fields[i].name, kw->args[i].value}); - } - } - auto zzz = transform(N(N(fn->ast->name), callArgs)); - if (!zzz->isDone()) - return {true, nullptr}; - - std::vector tupArgs; - for (auto &a : zzz->getCall()->args) - tupArgs.push_back(a.value); - return {true, transform(N(tupArgs))}; - } else if (expr->expr->isId("std.internal.static.vars")) { - auto funcTyp = expr->expr->type->getFunc(); - auto t = funcTyp->funcGenerics[0].type->getStatic(); - if (!t) - return {true, nullptr}; - auto withIdx = t->evaluate().getInt(); - - types::ClassTypePtr typ = nullptr; - std::vector tupleItems; - auto e = transform(expr->args[0].value); - if (!(typ = e->type->getClass())) - return {true, nullptr}; - - size_t idx = 0; - for (auto &f : getClassFields(typ.get())) { - auto k = N(f.name); - auto v = N(expr->args[0].value, f.name); - if (withIdx) { - auto i = N(idx); - tupleItems.push_back(N(std::vector{i, k, v})); - } else { - tupleItems.push_back(N(std::vector{k, v})); - } - idx++; - } - return {true, transform(N(tupleItems))}; - } else if (expr->expr->isId("std.internal.static.tuple_type")) { - auto funcTyp = expr->expr->type->getFunc(); - auto t = funcTyp->funcGenerics[0].type; - if (!t || !realize(t)) - return {true, nullptr}; - auto tn = funcTyp->funcGenerics[1].type->getStatic(); - if (!tn) - return {true, nullptr}; - auto n = tn->evaluate().getInt(); - types::TypePtr typ = nullptr; - if (t->getRecord()) { - if (n < 0 || n >= t->getRecord()->args.size()) - error("invalid index"); - typ = t->getRecord()->args[n]; - } else { - auto f = getClassFields(t->getClass().get()); - if (n < 0 || n >= f.size()) - error("invalid index"); - typ = ctx->instantiate(f[n].type, t->getClass()); - } - typ = realize(typ); - return {true, transform(NT(typ->realizedName()))}; + } else if (isF(ei, "std.collections.namedtuple")) { + return {true, transformNamedTuple(expr)}; + } else if (isF(ei, "std.functools.partial")) { + return {true, transformFunctoolsPartial(expr)}; + } else if (isF(ei, "std.internal.static.fn_can_call")) { // static + return {true, transformStaticFnCanCall(expr)}; + } else if (isF(ei, "std.internal.static.fn_arg_has_type")) { // static + return {true, transformStaticFnArgHasType(expr)}; + } else if (isF(ei, "std.internal.static.fn_arg_get_type")) { + return {true, transformStaticFnArgGetType(expr)}; + } else if (isF(ei, "std.internal.static.fn_args")) { + return {true, transformStaticFnArgs(expr)}; + } else if (isF(ei, "std.internal.static.fn_has_default")) { // static + return {true, transformStaticFnHasDefault(expr)}; + } else if (isF(ei, "std.internal.static.fn_get_default")) { + return {true, transformStaticFnGetDefault(expr)}; + } else if (isF(ei, "std.internal.static.fn_wrap_call_args")) { + return {true, transformStaticFnWrapCallArgs(expr)}; + } else if (isF(ei, "std.internal.static.vars")) { + return {true, transformStaticVars(expr)}; + } else if (isF(ei, "std.internal.static.tuple_type")) { + return {true, transformStaticTupleType(expr)}; } else { return {false, nullptr}; } @@ -1180,87 +803,60 @@ std::pair TypecheckVisitor::transformInternalStaticFn(CallExpr *e /// Get the list that describes the inheritance hierarchy of a given type. /// The first type in the list is the most recently inherited type. -std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) { - std::vector result; +std::vector TypecheckVisitor::getSuperTypes(ClassType *cls) { + std::vector result; if (!cls) return result; - result.push_back(cls); - for (auto &name : ctx->cache->classes[cls->name].staticParentClasses) { - auto parentTyp = ctx->instantiate(ctx->forceFind(name)->type)->getClass(); - for (auto &field : getClassFields(cls.get())) { - for (auto &parentField : getClassFields(parentTyp.get())) + result.push_back(cls->shared_from_this()); + auto c = getClass(cls); + auto fields = getClassFields(cls); + for (auto &name : c->staticParentClasses) { + auto parentTyp = instantiateType(extractClassType(name)); + auto parentFields = getClassFields(parentTyp->getClass()); + for (auto &field : fields) { + for (auto &parentField : parentFields) if (field.name == parentField.name) { - unify(ctx->instantiate(field.type, cls), - ctx->instantiate(parentField.type, parentTyp)); + auto t = instantiateType(field.getType(), cls); + unify(t.get(), instantiateType(parentField.getType(), parentTyp->getClass())); break; } } - for (auto &t : getSuperTypes(parentTyp)) + for (auto &t : getSuperTypes(parentTyp->getClass())) result.push_back(t); } return result; } -/// Find all generics on which a function depends on and add them to the current -/// context. -void TypecheckVisitor::addFunctionGenerics(const FuncType *t) { - for (auto parent = t->funcParent; parent;) { - if (auto f = parent->getFunc()) { - // Add parent function generics - for (auto &g : f->funcGenerics) { - // LOG(" -> {} := {}", g.name, g.type->debugString(true)); - ctx->add(TypecheckItem::Type, g.name, g.type); - } - parent = f->funcParent; - } else { - // Add parent class generics - seqassert(parent->getClass(), "not a class: {}", parent); - for (auto &g : parent->getClass()->generics) { - // LOG(" => {} := {}", g.name, g.type->debugString(true)); - ctx->add(TypecheckItem::Type, g.name, g.type); - } - for (auto &g : parent->getClass()->hiddenGenerics) { - // LOG(" :> {} := {}", g.name, g.type->debugString(true)); - ctx->add(TypecheckItem::Type, g.name, g.type); - } - break; - } - } - // Add function generics - for (auto &g : t->funcGenerics) { - // LOG(" >> {} := {}", g.name, g.type->debugString(true)); - ctx->add(TypecheckItem::Type, g.name, g.type); - } -} - -/// Generate a partial type `Partial.N` for a given function. +/// Return a partial type call `Partial(args, kwargs, fn, mask)` for a given function +/// and a mask. /// @param mask a 0-1 vector whose size matches the number of function arguments. /// 1 indicates that the argument has been provided and is cached within /// the partial object. -/// @example -/// ```@tuple -/// class Partial.N101[T0, T2]: -/// item0: T0 # the first cached argument -/// item2: T2 # the third cached argument -std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, - types::FuncType *fn) { +Expr *TypecheckVisitor::generatePartialCall(const std::vector &mask, + types::FuncType *fn, Expr *args, + Expr *kwargs) { std::string strMask(mask.size(), '1'); - int tupleSize = 0, genericSize = 0; for (size_t i = 0; i < mask.size(); i++) { if (!mask[i]) strMask[i] = '0'; - else if (fn->ast->args[i].status == Param::Normal) - tupleSize++; - else - genericSize++; - } - auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->toString()); - if (!ctx->find(typeName)) { - ctx->cache->partials[typeName] = {fn->generalize(0)->getFunc(), mask}; - generateTuple(tupleSize + 2, typeName, {}, false); } - return typeName; + + if (!args) + args = N(std::vector{N()}); + if (!kwargs) + kwargs = N(N("NamedTuple")); + + auto efn = N(fn->getFuncName()); + efn->setType(instantiateType(getStdLibType("unrealized_type"), + std::vector{fn->getFunc()})); + efn->setDone(); + Expr *call = N(N("Partial"), + std::vector{{"args", args}, + {"kwargs", kwargs}, + {"M", N(strMask)}, + {"F", efn}}); + return call; } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index e8801f79..c3564e84 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -3,203 +3,772 @@ #include #include +#include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" +#include "codon/parser/visitors/format/format.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; +using namespace codon::error; namespace codon::ast { using namespace types; +using namespace matcher; /// Parse a class (type) declaration and add a (generic) type to the context. void TypecheckVisitor::visit(ClassStmt *stmt) { - // Extensions are not possible after the simplification - seqassert(!stmt->attributes.has(Attr::Extend), "invalid extension '{}'", stmt->name); - // Type should be constructed only once - stmt->setDone(); - - // Generate the type and add it to the context - auto typ = Type::makeType(ctx->cache, stmt->name, ctx->cache->rev(stmt->name), - stmt->isRecord()) - ->getClass(); - if (stmt->isRecord() && stmt->hasAttr("__notuple__")) - typ->getRecord()->noTuple = true; - if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) { - // Special handling of partial types (e.g., `Partial.0001.foo`) - if (auto p = in(ctx->cache->partials, stmt->name)) - typ = std::make_shared(typ->getRecord(), p->first, p->second); + // Get root name + std::string name = stmt->getName(); + + // Generate/find class' canonical name (unique ID) and AST + std::string canonicalName; + std::vector &argsToParse = stmt->items; + + // classItem will be added later when the scope is different + auto classItem = std::make_shared("", "", ctx->getModule(), nullptr, + ctx->getScope()); + classItem->setSrcInfo(stmt->getSrcInfo()); + std::shared_ptr timedItem = nullptr; + types::ClassType *typ = nullptr; + if (!stmt->hasAttribute(Attr::Extend)) { + classItem->canonicalName = canonicalName = + ctx->generateCanonicalName(name, !stmt->hasAttribute(Attr::Internal), + /* noSuffix*/ stmt->hasAttribute(Attr::Internal)); + + if (canonicalName == "Union") + classItem->type = std::make_shared(ctx->cache); + else + classItem->type = + std::make_shared(ctx->cache, canonicalName, name); + if (stmt->isRecord()) + classItem->type->getClass()->isTuple = true; + classItem->type->setSrcInfo(stmt->getSrcInfo()); + + typ = classItem->getType()->getClass(); + if (canonicalName != TYPE_TYPE) + classItem->type = instantiateTypeVar(classItem->getType()); + + timedItem = std::make_shared(*classItem); + // timedItem->time = getTime(); + + // Reference types are added to the context here. + // Tuple types are added after class contents are parsed to prevent + // recursive record types (note: these are allowed for reference types) + if (!stmt->hasAttribute(Attr::Tuple)) { + ctx->add(name, timedItem); + ctx->addAlwaysVisible(classItem); + } + } else { + // Find the canonical name and AST of the class that is to be extended + if (!ctx->isGlobal() || ctx->isConditional()) + E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "class extension"); + auto val = ctx->find(name, getTime()); + if (!val || !val->isType()) + E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), name); + typ = val->getName() == TYPE_TYPE ? val->getType()->getClass() + : extractClassType(val->getType()); + canonicalName = typ->name; + argsToParse = getClass(typ)->ast->items; } - typ->setSrcInfo(stmt->getSrcInfo()); - // Classes should always be visible, so add them to the toplevel - ctx->addToplevel(stmt->name, - std::make_shared(TypecheckItem::Type, typ)); - - // Handle generics - for (const auto &a : stmt->args) { - if (a.status != Param::Normal) { - // Generic and static types - auto generic = ctx->getUnbound(); - generic->isStatic = getStaticGeneric(a.type.get()); - auto typId = generic->id; - generic->getLink()->genericName = ctx->cache->rev(a.name); - if (a.defaultValue) { - auto defType = transformType(clone(a.defaultValue)); - if (a.status == Param::Generic) { - generic->defaultType = defType->type; + auto &cls = ctx->cache->classes[canonicalName]; + + std::vector clsStmts; // Will be filled later! + std::vector varStmts; // Will be filled later! + std::vector fnStmts; // Will be filled later! + std::vector addLater; + try { + // Add the class base + TypeContext::BaseGuard br(ctx.get(), canonicalName); + ctx->getBase()->type = typ->shared_from_this(); + + // Parse and add class generics + std::vector args; + if (stmt->hasAttribute(Attr::Extend)) { + for (auto &a : argsToParse) { + if (!a.isGeneric()) + continue; + auto val = ctx->forceFind(a.name); + val->type->getLink()->kind = LinkType::Unbound; + ctx->add(getUnmangledName(val->canonicalName), val); + args.emplace_back(val->canonicalName, nullptr, nullptr, a.status); + } + } else { + if (stmt->hasAttribute(Attr::ClassDeduce) && args.empty()) { + autoDeduceMembers(stmt, argsToParse); + stmt->eraseAttribute(Attr::ClassDeduce); + } + + // Add all generics before parent classes, fields and methods + for (auto &a : argsToParse) { + if (!a.isGeneric()) + continue; + + auto varName = ctx->generateCanonicalName(a.getName()), genName = a.getName(); + auto generic = instantiateUnbound(); + auto typId = generic->id; + generic->getLink()->genericName = genName; + auto defType = transformType(clone(a.getDefault())); + if (defType) + generic->defaultType = extractType(defType)->shared_from_this(); + if (auto st = getStaticGeneric(a.getType())) { + if (st > 3) + a.type = transform(a.getType()); // error check + generic->isStatic = st; + auto val = ctx->addVar(genName, varName, generic); + val->generic = true; } else { - // Hidden generics can be outright replaced (e.g., `T=int`). - // Unify them immediately. - unify(defType->type, generic); + if (cast(a.getType())) { // Parse TraitVar + a.type = transform(a.getType()); + auto ti = cast(a.getType()); + seqassert(ti && isId(ti->getExpr(), TYPE_TYPEVAR), + "not a TypeVar instantiation: {}", *(a.getType())); + auto l = extractType(ti->front()); + if (l->getLink() && l->getLink()->trait) + generic->getLink()->trait = l->getLink()->trait; + else + generic->getLink()->trait = + std::make_shared(l->shared_from_this()); + } + ctx->addType(genName, varName, generic)->generic = true; } + typ->generics.emplace_back(varName, genName, + generic->generalize(ctx->typecheckLevel), typId, + generic->isStatic); + args.emplace_back(varName, a.getType(), defType, a.status); } - if (auto ti = CAST(a.type, InstantiateExpr)) { - // Parse TraitVar - seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation"); - auto l = transformType(ti->typeParams[0])->type; - if (l->getLink() && l->getLink()->trait) - generic->getLink()->trait = l->getLink()->trait; - else - generic->getLink()->trait = std::make_shared(l); + } + + // Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes) + Expr *typeAst = nullptr, *transformedTypeAst = nullptr; + if (!stmt->hasAttribute(Attr::Extend)) { + typeAst = N(name); + transformedTypeAst = N(canonicalName); + for (auto &a : args) { + if (a.isGeneric()) { + if (!cast(typeAst)) { + typeAst = N(N(name), N()); + transformedTypeAst = + N(N(canonicalName), std::vector{}); + } + cast(cast(typeAst)->getIndex()) + ->items.push_back(N(a.getName())); + cast(transformedTypeAst) + ->items.push_back(transform(N(a.getName()), true)); + } } - ctx->add(TypecheckItem::Type, a.name, generic); - ClassType::Generic g{a.name, ctx->cache->rev(a.name), - generic->generalize(ctx->typecheckLevel), typId}; - if (a.status == Param::Generic) { - typ->generics.push_back(g); + } + + // Collect classes (and their fields) that are to be statically inherited + std::vector staticBaseASTs; + if (!stmt->hasAttribute(Attr::Extend)) { + staticBaseASTs = parseBaseClasses(stmt->staticBaseClasses, args, stmt, + canonicalName, nullptr, typ); + if (ctx->cache->isJit && !stmt->baseClasses.empty()) + E(Error::CUSTOM, stmt->baseClasses[0], + "inheritance is not yet supported in JIT mode"); + parseBaseClasses(stmt->baseClasses, args, stmt, canonicalName, transformedTypeAst, + typ); + } + + // A ClassStmt will be separated into class variable assignments, method-free + // ClassStmts (that include nested classes) and method FunctionStmts + transformNestedClasses(stmt, clsStmts, varStmts, fnStmts); + + // Collect class fields + for (auto &a : argsToParse) { + if (a.isValue()) { + if (ClassStmt::isClassVar(a)) { + // Handle class variables. Transform them later to allow self-references + auto name = format("{}.{}", canonicalName, a.getName()); + auto h = transform(N(N(name), nullptr, nullptr)); + preamble->push_back(h); + auto val = ctx->forceFind(name); + val->baseName = ""; + val->scope = {0}; + registerGlobal(val->canonicalName); + auto assign = N( + N(name), a.getDefault(), + a.getType() ? cast(a.getType())->getIndex() : nullptr); + assign->setUpdate(); + varStmts.push_back(assign); + cls.classVars[a.getName()] = name; + } else if (!stmt->hasAttribute(Attr::Extend)) { + std::string varName = a.getName(); + args.emplace_back(varName, transformType(clean_clone(a.getType())), + transform(clone(a.getDefault()), true)); + cls.fields.emplace_back(varName, nullptr, canonicalName); + } + } + } + + // ASTs for member arguments to be used for populating magic methods + std::vector memberArgs; + for (auto &a : args) { + if (a.isValue()) + memberArgs.emplace_back(clone(a)); + } + + // Handle class members + if (!stmt->hasAttribute(Attr::Extend)) { + ctx->typecheckLevel++; // to avoid unifying generics early + if (canonicalName == TYPE_TUPLE) { + // Special tuple handling! + for (auto aj = 0; aj < MAX_TUPLE; aj++) { + auto genName = fmt::format("T{}", aj + 1); + auto genCanName = ctx->generateCanonicalName(genName); + auto generic = instantiateUnbound(); + generic->getLink()->genericName = genName; + Expr *te = N(genCanName); + cls.fields.emplace_back(fmt::format("item{}", aj + 1), + generic->generalize(ctx->typecheckLevel), "", te); + } } else { - typ->hiddenGenerics.push_back(g); + for (auto ai = 0, aj = 0; ai < args.size(); ai++) { + if (args[ai].isValue() && !ClassStmt::isClassVar(args[ai])) { + cls.fields[aj].typeExpr = clean_clone(args[ai].getType()); + cls.fields[aj].type = + extractType(args[ai].getType())->generalize(ctx->typecheckLevel - 1); + cls.fields[aj].type->setSrcInfo(args[ai].getType()->getSrcInfo()); + aj++; + } + } + } + ctx->typecheckLevel--; + } + + // Parse class members (arguments) and methods + if (!stmt->hasAttribute(Attr::Extend)) { + // Now that we are done with arguments, add record type to the context + if (stmt->hasAttribute(Attr::Tuple)) { + ctx->add(name, timedItem); + ctx->addAlwaysVisible(classItem); + } + // Create a cached AST. + stmt->setAttribute(Attr::Module, ctx->moduleName.status == ImportFile::STDLIB + ? STDLIB_IMPORT + : ctx->moduleName.path); + cls.ast = N(canonicalName, args, N()); + cls.ast->cloneAttributesFrom(stmt); + cls.ast->baseClasses = stmt->baseClasses; + for (auto &b : staticBaseASTs) + cls.staticParentClasses.emplace_back(b->getClass()->name); + cls.module = ctx->moduleName.path; + + // Codegen default magic methods + // __new__ must be the first + if (auto aa = stmt->getAttribute(Attr::ClassMagic)) + for (auto &m : aa->values) { + fnStmts.push_back(transform( + codegenMagic(m, typeAst, memberArgs, stmt->hasAttribute(Attr::Tuple)))); + } + // Add inherited methods + for (auto &base : staticBaseASTs) { + for (auto &mm : getClass(base->getClass())->methods) + for (auto &mf : getOverloads(mm.second)) { + const auto &fp = getFunction(mf); + auto f = fp->origAst; + if (f && !f->hasAttribute(Attr::AutoGenerated)) { + fnStmts.push_back( + cast(withClassGenerics(base->getClass(), [&]() { + auto cf = clean_clone(f); + // since functions can come from other modules + // make sure to transform them in their respective module + // however makle sure to add/pop generics :/ + if (!ctx->isStdlibLoading && fp->module != ctx->moduleName.path) { + auto ictx = getImport(fp->module)->ctx; + TypeContext::BaseGuard br(ictx.get(), canonicalName); + ictx->getBase()->type = typ->shared_from_this(); + auto tv = TypecheckVisitor(ictx); + auto e = tv.withClassGenerics( + typ, [&]() { return tv.transform(clean_clone(f)); }, false, + false, + /*instantiate*/ true); + return e; + } else { + return transform(clean_clone(f)); + } + }))); + } + } + } + } + + // Add class methods + for (const auto &sp : getClassMethods(stmt->getSuite())) + if (auto fp = cast(sp)) { + for (auto *&dc : fp->decorators) { + // Handle @setter setters + if (match(dc, M(M(fp->getName()), "setter")) && + fp->size() == 2) { + fp->name = format("{}{}", FN_SETTER_SUFFIX, fp->getName()); + dc = nullptr; + break; + } + } + fnStmts.emplace_back(transform(sp)); + } + + // After popping context block, record types and nested classes will disappear. + // Store their references and re-add them to the context after popping + addLater.reserve(clsStmts.size() + 1); + for (auto &c : clsStmts) + addLater.emplace_back(ctx->find(cast(c)->getName())); + if (stmt->hasAttribute(Attr::Tuple)) + addLater.emplace_back(ctx->forceFind(name)); + + // Mark functions as virtual: + auto banned = + std::set{"__init__", "__new__", "__raw__", "__tuplesize__"}; + for (auto &m : cls.methods) { + auto method = m.first; + for (size_t mi = 1; mi < cls.mro.size(); mi++) { + // ... in the current class + auto b = cls.mro[mi]->name; + if (in(getClass(b)->methods, method) && !in(banned, method)) { + cls.virtuals.insert(method); + } + } + for (auto &v : cls.virtuals) { + for (size_t mi = 1; mi < cls.mro.size(); mi++) { + // ... and in parent classes + auto b = cls.mro[mi]->name; + getClass(b)->virtuals.insert(v); + } } } + + // Generalize generics and remove them from the context + for (const auto &g : args) + if (!g.isValue()) { + auto generic = ctx->forceFind(g.name)->type; + if (g.status == Param::Generic) { + // Generalize generics. Hidden generics are linked to the class generics so + // ignore them + seqassert(generic && generic->getLink() && + generic->getLink()->kind != types::LinkType::Link, + "generic has been unified"); + generic->getLink()->kind = LinkType::Generic; + } + ctx->remove(g.name); + } + + // Debug information + LOG_REALIZE("[class] {} -> {:c} / {}", canonicalName, *typ, cls.fields.size()); + for (auto &m : cls.fields) + LOG_REALIZE(" - member: {}: {:c}", m.name, *(m.type)); + for (auto &m : cls.methods) + LOG_REALIZE(" - method: {}: {}", m.first, m.second); + for (auto &m : cls.mro) + LOG_REALIZE(" - mro: {:c}", *m); + } catch (const exc::ParserException &) { + if (!stmt->hasAttribute(Attr::Tuple)) + ctx->remove(name); + ctx->cache->classes.erase(name); + throw; + } + for (auto &i : addLater) + ctx->add(getUnmangledName(i->canonicalName), i); + + // Extensions are not needed as the cache is already populated + if (!stmt->hasAttribute(Attr::Extend)) { + auto c = cls.ast; + seqassert(c, "not a class AST for {}", canonicalName); + c->setDone(); + clsStmts.push_back(c); } - // Handle class members - ctx->typecheckLevel++; // to avoid unifying generics early - auto &fields = ctx->cache->classes[stmt->name].fields; - for (auto ai = 0, aj = 0; ai < stmt->args.size(); ai++) - if (stmt->args[ai].status == Param::Normal) { - fields[aj].type = transformType(stmt->args[ai].type) - ->getType() - ->generalize(ctx->typecheckLevel - 1); - fields[aj].type->setSrcInfo(stmt->args[ai].type->getSrcInfo()); - if (stmt->isRecord()) - typ->getRecord()->args.push_back(fields[aj].type); - aj++; - } - ctx->typecheckLevel--; - - // Handle MRO - for (auto &m : ctx->cache->classes[stmt->name].mro) { - m = transformType(m); + clsStmts.insert(clsStmts.end(), fnStmts.begin(), fnStmts.end()); + for (auto &a : varStmts) { + // Transform class variables here to allow self-references + clsStmts.push_back(transform(a)); } + resultStmt = N(clsStmts); +} + +/// Parse statically inherited classes. +/// Returns a list of their ASTs. Also updates the class fields. +/// @param args Class fields that are to be updated with base classes' fields. +/// @param typeAst Transformed AST for base class type (e.g., `A[T]`). +/// Only set when dealing with dynamic polymorphism. +std::vector TypecheckVisitor::parseBaseClasses( + std::vector &baseClasses, std::vector &args, Stmt *attr, + const std::string &canonicalName, Expr *typeAst, types::ClassType *typ) { + std::vector asts; + + // TODO)) fix MRO it to work with generic classes (maybe replacements? IDK...) + std::vector> mro{{typ->shared_from_this()}}; + for (auto &cls : baseClasses) { + std::vector subs; + + // Get the base class and generic replacements (e.g., if there is Bar[T], + // Bar in Foo(Bar[int]) will have `T = int`) + cls = transformType(cls); + if (!cls->getClassType()) + E(Error::CLASS_ID_NOT_FOUND, getSrcInfo(), FormatVisitor::apply(cls)); + + auto clsTyp = extractClassType(cls); + asts.push_back(clsTyp->shared_from_this()); + auto cachedCls = getClass(clsTyp); + if (!cachedCls->ast) + E(Error::CLASS_NO_INHERIT, getSrcInfo(), "nested", "surrounding"); + std::vector rootMro; + for (auto &t : cachedCls->mro) + rootMro.push_back(instantiateType(t.get(), clsTyp)); + mro.push_back(rootMro); + + // Sanity checks + if (attr->hasAttribute(Attr::Tuple) && typeAst) + E(Error::CLASS_NO_INHERIT, getSrcInfo(), "tuple", "other"); + if (!attr->hasAttribute(Attr::Tuple) && cachedCls->ast->hasAttribute(Attr::Tuple)) + E(Error::CLASS_TUPLE_INHERIT, getSrcInfo()); + if (cachedCls->ast->hasAttribute(Attr::Internal)) + E(Error::CLASS_NO_INHERIT, getSrcInfo(), "internal", "other"); - // Generalize generics and remove them from the context - for (const auto &g : stmt->args) - if (g.status != Param::Normal) { - auto generic = ctx->forceFind(g.name)->type; - if (g.status == Param::Generic) { - // Generalize generics. Hidden generics are linked to the class generics so - // ignore them - seqassert(generic && generic->getLink() && - generic->getLink()->kind != types::LinkType::Link, - "generic has been unified"); - generic->getLink()->kind = LinkType::Generic; + // Mark parent classes as polymorphic as well. + if (typeAst) + cachedCls->rtti = true; + + // Add hidden generics + addClassGenerics(clsTyp); + for (auto &g : clsTyp->generics) + typ->hiddenGenerics.push_back(g); + for (auto &g : clsTyp->hiddenGenerics) + typ->hiddenGenerics.push_back(g); + } + // Add normal fields + auto cls = getClass(canonicalName); + for (auto &clsTyp : asts) { + withClassGenerics(clsTyp->getClass(), [&]() { + int ai = 0; + auto ast = getClass(clsTyp->getClass())->ast; + for (auto &a : *ast) { + auto acls = getClass(ast->name); + if (a.isValue() && !ClassStmt::isClassVar(a)) { + auto name = a.getName(); + int i = 0; + for (auto &aa : args) + i += aa.getName() == a.getName() || + startswith(aa.getName(), a.getName() + "#"); + if (i) + name = format("{}#{}", name, i); + seqassert(acls->fields[ai].name == a.getName(), "bad class fields: {} vs {}", + acls->fields[ai].name, a.getName()); + args.emplace_back(name, transformType(clean_clone(a.getType())), + transform(clean_clone(a.getDefault()))); + cls->fields.emplace_back( + name, extractType(args.back().getType())->shared_from_this(), + acls->fields[ai].baseClass); + ai++; + } } - ctx->remove(g.name); + return true; + }); + } + if (typeAst) { + if (!asts.empty()) { + mro.push_back(asts); + cls->rtti = true; + } + cls->mro = Cache::mergeC3(mro); + if (cls->mro.empty()) { + E(Error::CLASS_BAD_MRO, getSrcInfo()); } + } + return asts; +} - // Debug information - LOG_REALIZE("[class] {} -> {}", stmt->name, typ); - for (auto &m : ctx->cache->classes[stmt->name].fields) - LOG_REALIZE(" - member: {}: {}", m.name, m.type); +/// Find the first __init__ with self parameter and use it to deduce class members. +/// Each deduced member will be treated as generic. +/// @example +/// ```@deduce +/// class Foo: +/// def __init__(self): +/// self.x, self.y = 1, 2``` +/// will result in +/// ```class Foo[T1, T2]: +/// x: T1 +/// y: T2``` +/// @return the transformed init and the pointer to the original function. +void TypecheckVisitor::autoDeduceMembers(ClassStmt *stmt, std::vector &args) { + std::set members; + for (const auto &sp : getClassMethods(stmt->suite)) + if (auto f = cast(sp)) { + if (f->name == "__init__") + if (auto b = f->getAttribute(Attr::ClassDeduce)) { + f->setAttribute(Attr::RealizeWithoutSelf); + for (auto m : b->values) + members.insert(m); + } + } + for (auto m : members) { + auto genericName = fmt::format("T_{}", m); + args.emplace_back(genericName, N(TYPE_TYPE), N("NoneType"), + Param::Generic); + args.emplace_back(m, N(genericName)); + } } -/// Generate a tuple class `Tuple[T1,...,TN]`. -/// @param len Tuple length (`N`) -/// @param name Tuple name. `Tuple` by default. -/// Can be something else (e.g., `KwTuple`) -/// @param names Member names. By default `item1`...`itemN`. -/// @param hasSuffix Set if the tuple name should have `.N` suffix. -std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name, - std::vector names, - bool hasSuffix) { +/// Return a list of all statements within a given class suite. +/// Checks each suite recursively, and assumes that each statement is either +/// a function, a class or a docstring. +std::vector TypecheckVisitor::getClassMethods(Stmt *s) { + std::vector v; + if (!s) + return v; + if (auto sp = cast(s)) { + for (auto *ss : *sp) + for (auto *u : getClassMethods(ss)) + v.push_back(u); + } else if (cast(s) || cast(s)) { + v.push_back(s); + } else if (!match(s, M(M()))) { + E(Error::CLASS_BAD_ATTR, s); + } + return v; +} + +/// Extract nested classes and transform them before the main class. +void TypecheckVisitor::transformNestedClasses(ClassStmt *stmt, + std::vector &clsStmts, + std::vector &varStmts, + std::vector &fnStmts) { + for (const auto &sp : getClassMethods(stmt->suite)) + if (auto cp = cast(sp)) { + auto origName = cp->getName(); + // If class B is nested within A, it's name is always A.B, never B itself. + // Ensure that parent class name is appended + auto parentName = stmt->getName(); + cp->name = fmt::format("{}.{}", parentName, origName); + auto tsp = transform(cp); + std::string name; + if (auto tss = cast(tsp)) { + for (auto &s : *tss) + if (auto c = cast(s)) { + clsStmts.push_back(s); + name = c->getName(); + } else if (auto a = cast(s)) { + varStmts.push_back(s); + } else { + fnStmts.push_back(s); + } + ctx->add(origName, ctx->forceFind(name)); + } + } +} + +/// Generate a magic method `__op__` for each magic `op` +/// described by @param typExpr and its arguments. +/// Currently generate: +/// @li Constructors: __new__, __init__ +/// @li Utilities: __raw__, __hash__, __repr__, __tuplesize__, __add__, __mul__, __len__ +/// @li Iteration: __iter__, __getitem__, __len__, __contains__ +/// @li Comparisons: __eq__, __ne__, __lt__, __le__, __gt__, __ge__ +/// @li Pickling: __pickle__, __unpickle__ +/// @li Python: __to_py__, __from_py__ +/// @li GPU: __to_gpu__, __from_gpu__, __from_gpu_new__ +/// TODO: move to Codon as much as possible +Stmt *TypecheckVisitor::codegenMagic(const std::string &op, Expr *typExpr, + const std::vector &allArgs, bool isRecord) { +#define I(s) N(s) +#define NS(x) N(N("__magic__"), (x)) + seqassert(typExpr, "typExpr is null"); + Expr *ret = nullptr; + std::vector fargs; + std::vector stmts; + std::vector attrs{"autogenerated"}; + + std::vector args; + args.reserve(allArgs.size()); + for (auto &a : allArgs) + args.push_back(clone(a)); + + if (op == "new") { + ret = clone(typExpr); + if (isRecord) { + // Tuples: def __new__() -> T (internal) + for (auto &a : args) + fargs.emplace_back(a.getName(), clone(a.getType()), clone(a.getDefault())); + attrs.push_back(Attr::Internal); + } else { + // Classes: def __new__() -> T + stmts.emplace_back(N(N(NS(op), clone(typExpr)))); + } + } else if (op == "init") { + // Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> None: + // self.aI = aI ... + ret = I("NoneType"); + fargs.emplace_back("self", clone(typExpr)); + for (auto &a : args) { + fargs.emplace_back(a.getName(), clean_clone(a.getType()), clone(a.getDefault())); + stmts.push_back( + N(N(I("self"), a.getName()), I(a.getName()))); + } + } else if (op == "raw" || op == "dict") { + // Classes: def __raw__(self: T) + fargs.emplace_back("self", clone(typExpr)); + stmts.emplace_back(N(N(NS(op), I("self")))); + } else if (op == "tuplesize") { + // def __tuplesize__() -> int + ret = I("int"); + stmts.emplace_back(N(N(NS(op)))); + } else if (op == "getitem") { + // Tuples: def __getitem__(self: T, index: int) + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("index", I("int")); + stmts.emplace_back(N(N(NS(op), I("self"), I("index")))); + } else if (op == "iter") { + // Tuples: def __iter__(self: T) + fargs.emplace_back("self", clone(typExpr)); + stmts.emplace_back(N(N(NS(op), I("self")))); + } else if (op == "contains") { + // Tuples: def __contains__(self: T, what) -> bool + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("what", nullptr); + ret = I("bool"); + stmts.emplace_back(N(N(NS(op), I("self"), I("what")))); + } else if (op == "eq" || op == "ne" || op == "lt" || op == "le" || op == "gt" || + op == "ge") { + // def __op__(self: T, obj: T) -> bool + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("obj", clone(typExpr)); + ret = I("bool"); + stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); + } else if (op == "hash" || op == "len") { + // def __hash__(self: T) -> int + fargs.emplace_back("self", clone(typExpr)); + ret = I("int"); + stmts.emplace_back(N(N(NS(op), I("self")))); + } else if (op == "pickle") { + // def __pickle__(self: T, dest: Ptr[byte]) + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("dest", N(I("Ptr"), I("byte"))); + stmts.emplace_back(N(N(NS(op), I("self"), I("dest")))); + } else if (op == "unpickle" || op == "from_py") { + // def __unpickle__(src: Ptr[byte]) -> T + fargs.emplace_back("src", N(I("Ptr"), I("byte"))); + ret = clone(typExpr); + stmts.emplace_back(N(N(NS(op), I("src"), clone(typExpr)))); + } else if (op == "to_py") { + // def __to_py__(self: T) -> Ptr[byte] + fargs.emplace_back("self", clone(typExpr)); + ret = N(I("Ptr"), I("byte")); + stmts.emplace_back(N(N(NS(op), I("self")))); + } else if (op == "to_gpu") { + // def __to_gpu__(self: T, cache) -> T + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("cache"); + ret = clone(typExpr); + stmts.emplace_back(N(N(NS(op), I("self"), I("cache")))); + } else if (op == "from_gpu") { + // def __from_gpu__(self: T, other: T) + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("other", clone(typExpr)); + stmts.emplace_back(N(N(NS(op), I("self"), I("other")))); + } else if (op == "from_gpu_new") { + // def __from_gpu_new__(other: T) -> T + fargs.emplace_back("other", clone(typExpr)); + ret = clone(typExpr); + stmts.emplace_back(N(N(NS(op), I("other")))); + } else if (op == "repr") { + // def __repr__(self: T) -> str + fargs.emplace_back("self", clone(typExpr)); + ret = I("str"); + stmts.emplace_back(N(N(NS(op), I("self")))); + } else if (op == "add") { + // def __add__(self, obj) + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("obj", nullptr); + stmts.emplace_back(N(N(NS(op), I("self"), I("obj")))); + } else if (op == "mul") { + // def __mul__(self, i: Static[int]) + fargs.emplace_back("self", clone(typExpr)); + fargs.emplace_back("i", N(I("Static"), I("int"))); + stmts.emplace_back(N(N(NS(op), I("self"), I("i")))); + } else { + seqassert(false, "invalid magic {}", op); + } +#undef I +#undef NS + auto t = NC(format("__{}__", op), ret, fargs, NC(stmts)); + for (auto &a : attrs) + t->setAttribute(a); + t->setSrcInfo(ctx->cache->generateSrcInfo()); + return t; +} + +int TypecheckVisitor::generateKwId(const std::vector &names) { auto key = join(names, ";"); std::string suffix; if (!names.empty()) { // Each set of names generates different tuple (i.e., `KwArgs[foo, bar]` is not the // same as `KwArgs[bar, baz]`). Cache the names and use an integer for each name // combination. - if (!in(ctx->cache->generatedTuples, key)) - ctx->cache->generatedTuples[key] = int(ctx->cache->generatedTuples.size()); - suffix = format("_{}", ctx->cache->generatedTuples[key]); + if (!in(ctx->cache->generatedTuples, key)) { + ctx->cache->generatedTupleNames.push_back(names); + ctx->cache->generatedTuples[key] = int(ctx->cache->generatedTuples.size()) + 1; + } + return ctx->cache->generatedTuples[key]; } else { - for (size_t i = 1; i <= len; i++) - names.push_back(format("item{}", i)); + return 0; } +} - auto typeName = format("{}{}", name, hasSuffix ? format("{}{}", len, suffix) : ""); - if (!ctx->find(typeName)) { - // Generate the appropriate ClassStmt - std::vector args; - for (size_t i = 0; i < len; i++) - args.emplace_back(Param(names[i], N(format("T{}", i + 1)), nullptr)); - for (size_t i = 0; i < len; i++) - args.emplace_back(Param(format("T{}", i + 1), N("type"), nullptr, true)); - StmtPtr stmt = N(ctx->cache->generateSrcInfo(), typeName, args, nullptr, - std::vector{N("tuple")}); - - // Add helpers for KwArgs: - // `def __getitem__(self, key: Static[str]): return getattr(self, key)` - // `def __contains__(self, key: Static[str]): return hasattr(self, key)` - auto getItem = N( - "__getitem__", nullptr, - std::vector{Param{"self"}, Param{"key", N(N("Static"), - N("str"))}}, - N(N( - N(N("getattr"), N("self"), N("key"))))); - auto contains = N( - "__contains__", nullptr, - std::vector{Param{"self"}, Param{"key", N(N("Static"), - N("str"))}}, - N(N( - N(N("hasattr"), N("self"), N("key"))))); - auto getDef = N( - "get", nullptr, - std::vector{ - Param{"self"}, - Param{"key", N(N("Static"), N("str"))}, - Param{"default", nullptr, N(N("NoneType"))}}, - N(N( - N(N(N("__internal__"), "kwargs_get"), - N("self"), N("key"), N("default"))))); - if (startswith(typeName, TYPE_KWTUPLE)) - stmt->getClass()->suite = N(getItem, contains, getDef); - - // Add repr and call for partials: - // `def __repr__(self): return __magic__.repr_partial(self)` - auto repr = N( - "__repr__", nullptr, std::vector{Param{"self"}}, - N(N(N( - N(N("__magic__"), "repr_partial"), N("self"))))); - auto pcall = N( - "__call__", nullptr, - std::vector{Param{"self"}, Param{"*args"}, Param{"**kwargs"}}, - N( - N(N(N("self"), N(N("args")), - N(N("kwargs")))))); - if (startswith(typeName, TYPE_PARTIAL)) - stmt->getClass()->suite = N(repr, pcall); - - // Simplify in the standard library context and type check - stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt, - FILE_GENERATED, 0); - stmt = TypecheckVisitor(ctx).transform(stmt); - prependStmts->push_back(stmt); +types::ClassType *TypecheckVisitor::generateTuple(size_t n, bool generateNew) { + static std::unordered_set funcArgTypes; + + if (n > MAX_TUPLE) + E(Error::CUSTOM, getSrcInfo(), "tuple too large ({})", n); + + auto key = fmt::format("{}.{}", TYPE_TUPLE, n); + auto val = getImport(STDLIB_IMPORT)->ctx->find(key); + if (!val) { + auto t = std::make_shared(ctx->cache, TYPE_TUPLE, TYPE_TUPLE); + t->isTuple = true; + auto cls = getClass(t.get()); + seqassert(n <= cls->fields.size(), "tuple too large"); + for (size_t i = 0; i < n; i++) { + const auto &f = cls->fields[i]; + auto gt = f.getType()->getLink(); + t->generics.emplace_back(cast(f.typeExpr)->getValue(), gt->genericName, + f.type, gt->id, 0); + } + val = getImport(STDLIB_IMPORT)->ctx->addType(key, key, t); + } + auto t = val->getType()->getClass(); + if (generateNew && !in(funcArgTypes, n)) { + funcArgTypes.insert(n); + std::vector newFnArgs; + std::vector typeArgs; + for (size_t i = 0; i < n; i++) { + newFnArgs.emplace_back(format("item{}", i + 1), N(format("T{}", i + 1))); + typeArgs.emplace_back(N(format("T{}", i + 1))); + } + for (size_t i = 0; i < n; i++) { + newFnArgs.emplace_back(format("T{}", i + 1), N(TYPE_TYPE)); + } + Stmt *fn = N( + "__new__", N(N(TYPE_TUPLE), N(typeArgs)), + newFnArgs, nullptr); + fn->setAttribute(Attr::Internal); + Stmt *ext = N(TYPE_TUPLE, std::vector{}, fn); + ext->setAttribute(Attr::Extend); + ext = N(ext); + + llvm::cantFail(ScopingVisitor::apply(ctx->cache, ext)); + auto rctx = getImport(STDLIB_IMPORT)->ctx; + auto oldBases = rctx->bases; + rctx->bases.clear(); + rctx->bases.push_back(oldBases[0]); + ext = TypecheckVisitor::apply(rctx, ext); + rctx->bases = oldBases; + preamble->push_back(ext); } - return typeName; + return t; } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/collections.cpp b/codon/parser/visitors/typecheck/collections.cpp index 8f714f32..5d84e951 100644 --- a/codon/parser/visitors/typecheck/collections.cpp +++ b/codon/parser/visitors/typecheck/collections.cpp @@ -3,7 +3,6 @@ #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -13,34 +12,142 @@ namespace codon::ast { using namespace types; +/// Transform tuples. +/// @example +/// `(a1, ..., aN)` -> `Tuple.__new__(a1, ..., aN)` +void TypecheckVisitor::visit(TupleExpr *expr) { + resultExpr = + transform(N(N(N(TYPE_TUPLE), "__new__"), expr->items)); +} + /// Transform a list `[a1, ..., aN]` to the corresponding statement expression. /// See @c transformComprehension void TypecheckVisitor::visit(ListExpr *expr) { - expr->setType(ctx->getUnbound()); - if ((resultExpr = transformComprehension("std.internal.types.ptr.List", "append", - expr->items))) { - resultExpr->setAttr(ExprAttr::List); + expr->setType(instantiateUnbound()); + auto name = getStdLibType("List")->name; + if ((resultExpr = transformComprehension(name, "append", expr->items))) { + resultExpr->setAttribute(Attr::ExprList); } } /// Transform a set `{a1, ..., aN}` to the corresponding statement expression. /// See @c transformComprehension void TypecheckVisitor::visit(SetExpr *expr) { - expr->setType(ctx->getUnbound()); - auto name = ctx->cache->imports[STDLIB_IMPORT].ctx->forceFind("Set"); - if ((resultExpr = transformComprehension(name->canonicalName, "add", expr->items))) { - resultExpr->setAttr(ExprAttr::Set); + expr->setType(instantiateUnbound()); + auto name = getStdLibType("Set")->name; + if ((resultExpr = transformComprehension(name, "add", expr->items))) { + resultExpr->setAttribute(Attr::ExprSet); } } /// Transform a dictionary `{k1: v1, ..., kN: vN}` to a corresponding statement /// expression. See @c transformComprehension void TypecheckVisitor::visit(DictExpr *expr) { - expr->setType(ctx->getUnbound()); - auto name = ctx->cache->imports[STDLIB_IMPORT].ctx->forceFind("Dict"); - if ((resultExpr = - transformComprehension(name->canonicalName, "__setitem__", expr->items))) { - resultExpr->setAttr(ExprAttr::Dict); + expr->setType(instantiateUnbound()); + auto name = getStdLibType("Dict")->name; + if ((resultExpr = transformComprehension(name, "__setitem__", expr->items))) { + resultExpr->setAttribute(Attr::ExprDict); + } +} + +/// Transform a tuple generator expression. +/// @example +/// `tuple(expr for i in tuple_generator)` -> `Tuple.N.__new__(expr...)` +void TypecheckVisitor::visit(GeneratorExpr *expr) { + // List comprehension optimization: + // Use `iter.__len__()` when creating list if there is a single for loop + // without any if conditions in the comprehension + bool canOptimize = + expr->kind == GeneratorExpr::ListGenerator && expr->loopCount() == 1; + if (canOptimize) { + auto iter = transform(clone(cast(expr->getFinalSuite())->getIter())); + auto ce = cast(iter); + IdExpr *id = nullptr; + if (ce && (id = cast(ce->getExpr()))) { + // Turn off this optimization for static items + canOptimize &= + !startswith(id->getValue(), "std.internal.types.range.staticrange"); + canOptimize &= !startswith(id->getValue(), "statictuple"); + } + } + + Expr *var = N(getTemporaryVar("gen")); + if (expr->kind == GeneratorExpr::ListGenerator) { + // List comprehensions + expr->setFinalExpr( + N(N(clone(var), "append"), expr->getFinalExpr())); + auto suite = expr->getFinalSuite(); + auto noOptStmt = + N(N(clone(var), N(N("List"))), suite); + + if (canOptimize) { + auto optimizeVar = getTemporaryVar("i"); + auto origIter = cast(expr->getFinalSuite())->getIter(); + + auto optStmt = clone(noOptStmt); + cast((*cast(optStmt))[1])->iter = N(optimizeVar); + optStmt = N( + N(N(optimizeVar), clone(origIter)), + N( + clone(var), + N(N("List"), + N(N(N(optimizeVar), "__len__")))), + (*cast(optStmt))[1]); + resultExpr = N( + N(N("hasattr"), clone(origIter), N("__len__")), + N(optStmt, clone(var)), N(noOptStmt, var)); + } else { + resultExpr = N(noOptStmt, var); + } + resultExpr = transform(resultExpr); + } else if (expr->kind == GeneratorExpr::SetGenerator) { + // Set comprehensions + auto head = N(clone(var), N(N("Set"))); + expr->setFinalExpr( + N(N(clone(var), "add"), expr->getFinalExpr())); + auto suite = expr->getFinalSuite(); + resultExpr = transform(N(N(head, suite), var)); + } else if (expr->kind == GeneratorExpr::DictGenerator) { + // Dictionary comprehensions + auto head = N(clone(var), N(N("Dict"))); + expr->setFinalExpr(N(N(clone(var), "__setitem__"), + N(expr->getFinalExpr()))); + auto suite = expr->getFinalSuite(); + resultExpr = transform(N(N(head, suite), var)); + } else if (expr->kind == GeneratorExpr::TupleGenerator) { + seqassert(expr->loopCount() == 1, "invalid tuple generator"); + auto gen = transform(cast(expr->getFinalSuite())->getIter()); + if (!gen->getType()->canRealize()) + return; // Wait until the iterator can be realized + + auto block = N(); + // `tuple = tuple_generator` + auto tupleVar = getTemporaryVar("tuple"); + block->addStmt(N(N(tupleVar), gen)); + + auto forStmt = clone(cast(expr->getFinalSuite())); + auto finalExpr = expr->getFinalExpr(); + auto [ok, delay, preamble, staticItems] = transformStaticLoopCall( + cast(expr->getFinalSuite())->getVar(), &forStmt->suite, gen, + [&](Stmt *wrap) { return N(clone(wrap), clone(finalExpr)); }, true); + if (!ok) + E(Error::CALL_BAD_ITER, gen, gen->getType()->prettyString()); + if (delay) + return; + + std::vector tupleItems; + for (auto &i : staticItems) + tupleItems.push_back(cast(i)); + if (preamble) + block->addStmt(preamble); + resultExpr = transform(N(block, N(tupleItems))); + } else { + expr->loops = + transform(expr->getFinalSuite()); // assume: internal data will be changed + unify(expr->getType(), instantiateType(getStdLibType("Generator"), + {expr->getFinalExpr()->getType()})); + if (realize(expr->getType())) + expr->setDone(); } } @@ -60,33 +167,35 @@ void TypecheckVisitor::visit(DictExpr *expr) { /// `{a: 1, **d}` -> ```cont = Dict() /// cont.__setitem__((a, 1)) /// for i in b.items(): cont.__setitem__((i[0], i[i]))``` -ExprPtr TypecheckVisitor::transformComprehension(const std::string &type, - const std::string &fn, - std::vector &items) { +Expr *TypecheckVisitor::transformComprehension(const std::string &type, + const std::string &fn, + std::vector &items) { // Deduce the super type of the collection--- in other words, the least common // ancestor of all types in the collection. For example, `type([1, 1.2]) == type([1.2, // 1]) == float` because float is an "ancestor" of int. - auto superTyp = [&](const ClassTypePtr &collectionCls, - const ClassTypePtr &ti) -> ClassTypePtr { + // TOOD: use wrapExpr... + auto superTyp = [&](ClassType *collectionCls, ClassType *ti) -> TypePtr { if (!collectionCls) - return ti; + return ti->shared_from_this(); if (collectionCls->is("int") && ti->is("float")) { // Rule: int derives from float - return ti; + return ti->shared_from_this(); } else if (collectionCls->name != TYPE_OPTIONAL && ti->name == TYPE_OPTIONAL) { // Rule: T derives from Optional[T] - return ctx->instantiateGeneric(ctx->getType("Optional"), {collectionCls}) - ->getClass(); + return instantiateType(getStdLibType("Optional"), + std::vector{collectionCls}); + } else if (collectionCls->name == TYPE_OPTIONAL && ti->name != TYPE_OPTIONAL) { + return instantiateType(getStdLibType("Optional"), std::vector{ti}); } else if (!collectionCls->is("pyobj") && ti->is("pyobj")) { // Rule: anything derives from pyobj - return ti; + return ti->shared_from_this(); } else if (collectionCls->name != ti->name) { // Rule: subclass derives from superclass - auto &mros = ctx->cache->classes[collectionCls->name].mro; + const auto &mros = getClass(collectionCls)->mro; for (size_t i = 1; i < mros.size(); i++) { - auto t = ctx->instantiate(mros[i]->type, collectionCls); - if (t->unify(ti.get(), nullptr) >= 0) { - return ti; + auto t = instantiateType(mros[i].get(), collectionCls); + if (t->unify(ti, nullptr) >= 0) { + return ti->shared_from_this(); break; } } @@ -94,207 +203,108 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type, return nullptr; }; - TypePtr collectionTyp = ctx->getUnbound(); + TypePtr collectionTyp = instantiateUnbound(); bool done = true; - bool isDict = endswith(type, "Dict"); + bool isDict = type == getStdLibType("Dict")->name; for (auto &i : items) { - ClassTypePtr typ = nullptr; - if (!isDict && i->getStar()) { - auto star = i->getStar(); - star->what = transform(N(N(star->what, "__iter__"))); - if (star->what->type->is("Generator")) - typ = star->what->type->getClass()->generics[0].type->getClass(); - } else if (isDict && CAST(i, KeywordStarExpr)) { - auto star = CAST(i, KeywordStarExpr); - star->what = transform(N(N(star->what, "items"))); - if (star->what->type->is("Generator")) - typ = star->what->type->getClass()->generics[0].type->getClass(); + ClassType *typ = nullptr; + if (!isDict && cast(i)) { + auto star = cast(i); + star->expr = transform(N(N(star->getExpr(), "__iter__"))); + if (star->getExpr()->getType()->is("Generator")) + typ = extractClassGeneric(star->getExpr()->getType())->getClass(); + } else if (isDict && cast(i)) { + auto star = cast(i); + star->expr = transform(N(N(star->getExpr(), "items"))); + if (star->getExpr()->getType()->is("Generator")) + typ = extractClassGeneric(star->getExpr()->getType())->getClass(); } else { i = transform(i); - typ = i->type->getClass(); + typ = i->getClassType(); } if (!typ) { done = false; continue; } if (!collectionTyp->getClass()) { - unify(collectionTyp, typ); + unify(collectionTyp.get(), typ); } else if (!isDict) { if (auto t = superTyp(collectionTyp->getClass(), typ)) collectionTyp = t; } else { - seqassert(collectionTyp->getRecord() && - collectionTyp->getRecord()->args.size() == 2, + auto tt = unify(typ, instantiateType(generateTuple(2)))->getClass(); + seqassert(collectionTyp->getClass() && + collectionTyp->getClass()->generics.size() == 2 && + tt->generics.size() == 2, "bad dict"); - auto tt = unify(typ, ctx->instantiateTuple(2))->getRecord(); - auto nt = collectionTyp->getRecord()->args; + std::vector nt; for (int di = 0; di < 2; di++) { + nt.push_back(extractClassGeneric(collectionTyp.get(), di)->shared_from_this()); if (!nt[di]->getClass()) - unify(nt[di], tt->args[di]); - else if (auto dt = superTyp(nt[di]->getClass(), tt->args[di]->getClass())) + unify(nt[di].get(), extractClassGeneric(tt, di)); + else if (auto dt = superTyp(nt[di]->getClass(), + extractClassGeneric(tt, di)->getClass())) nt[di] = dt; } - collectionTyp = ctx->instantiateTuple(nt); + collectionTyp = + instantiateType(generateTuple(nt.size()), ctx->cache->castVectorPtr(nt)); } } if (!done) return nullptr; - std::vector stmts; - ExprPtr var = N(ctx->cache->getTemporaryVar("cont")); + std::vector stmts; + Expr *var = N(getTemporaryVar("cont")); - std::vector constructorArgs{}; - if (endswith(type, "List") && !items.empty()) { + std::vector constructorArgs{}; + if (type == getStdLibType("List")->name && !items.empty()) { // Optimization: pre-allocate the list with the exact number of elements constructorArgs.push_back(N(items.size())); } - auto t = NT(type); - if (isDict && collectionTyp->getRecord()) { - t->setType( - ctx->instantiateGeneric(ctx->getType(type), collectionTyp->getRecord()->args)); - } else if (isDict) { - t->setType(ctx->instantiate(ctx->getType(type))); - } else { - t->setType(ctx->instantiateGeneric(ctx->getType(type), {collectionTyp})); + auto t = N(type); + auto ta = instantiateType(getStdLibType(type)); + if (isDict && collectionTyp->getClass()) { + seqassert(collectionTyp->getClass()->isRecord(), "bad dict"); + std::vector nt; + for (auto &g : collectionTyp->getClass()->generics) + nt.push_back(g.getType()); + ta = instantiateType(getStdLibType(type), nt); + } else if (!isDict) { + ta = instantiateType(getStdLibType(type), {collectionTyp.get()}); } - stmts.push_back( - transform(N(clone(var), N(t, constructorArgs)))); + t->setType(instantiateTypeVar(ta.get())); + stmts.push_back(N(clone(var), N(t, constructorArgs))); for (const auto &it : items) { - if (!isDict && it->getStar()) { + if (!isDict && cast(it)) { // Unpack star-expression by iterating over it // `*star` -> `for i in star: cont.[fn](i)` - auto star = it->getStar(); - ExprPtr forVar = N(ctx->cache->getTemporaryVar("i")); - star->what->setAttr(ExprAttr::StarSequenceItem); - stmts.push_back(transform(N( - clone(forVar), star->what, - N(N(N(clone(var), fn), clone(forVar)))))); - } else if (isDict && CAST(it, KeywordStarExpr)) { + auto star = cast(it); + Expr *forVar = N(getTemporaryVar("i")); + star->getExpr()->setAttribute(Attr::ExprStarSequenceItem); + stmts.push_back(N( + clone(forVar), star->getExpr(), + N(N(N(clone(var), fn), clone(forVar))))); + } else if (isDict && cast(it)) { // Expand kwstar-expression by iterating over it: see the example above - auto star = CAST(it, KeywordStarExpr); - ExprPtr forVar = N(ctx->cache->getTemporaryVar("it")); - star->what->setAttr(ExprAttr::StarSequenceItem); - stmts.push_back(transform(N( - clone(forVar), star->what, + auto star = cast(it); + Expr *forVar = N(getTemporaryVar("it")); + star->getExpr()->setAttribute(Attr::ExprStarSequenceItem); + stmts.push_back(N( + clone(forVar), star->getExpr(), N(N(N(clone(var), fn), N(clone(forVar), N(0)), - N(clone(forVar), N(1))))))); + N(clone(forVar), N(1)))))); } else { - it->setAttr(ExprAttr::SequenceItem); + it->setAttribute(Attr::ExprSequenceItem); if (isDict) { - stmts.push_back(transform(N( - N(N(clone(var), fn), N(it, N(0)), - N(it, N(1)))))); + stmts.push_back(N(N(N(clone(var), fn), + N(it, N(0)), + N(it, N(1))))); } else { - stmts.push_back( - transform(N(N(N(clone(var), fn), it)))); + stmts.push_back(N(N(N(clone(var), fn), it))); } } } return transform(N(stmts, var)); } -/// Transform tuples. -/// Generate tuple classes (e.g., `Tuple`) if not available. -/// @example -/// `(a1, ..., aN)` -> `Tuple.__new__(a1, ..., aN)` -void TypecheckVisitor::visit(TupleExpr *expr) { - expr->setType(ctx->getUnbound()); - for (int ai = 0; ai < expr->items.size(); ai++) - if (auto star = expr->items[ai]->getStar()) { - // Case: unpack star expressions (e.g., `*arg` -> `arg.item1, arg.item2, ...`) - transform(star->what); - auto typ = star->what->type->getClass(); - while (typ && typ->is(TYPE_OPTIONAL)) { - star->what = transform(N(N(FN_UNWRAP), star->what)); - typ = star->what->type->getClass(); - } - if (!typ) - return; // continue later when the type becomes known - if (!typ->getRecord()) - E(Error::CALL_BAD_UNPACK, star, typ->prettyString()); - auto ff = getClassFields(typ.get()); - for (int i = 0; i < typ->getRecord()->args.size(); i++, ai++) { - expr->items.insert(expr->items.begin() + ai, - transform(N(clone(star->what), ff[i].name))); - } - // Remove the star - expr->items.erase(expr->items.begin() + ai); - ai--; - } else { - expr->items[ai] = transform(expr->items[ai]); - } - auto s = ctx->generateTuple(expr->items.size()); - resultExpr = transform(N(N(s), clone(expr->items))); - unify(expr->type, resultExpr->type); -} - -/// Transform a tuple generator expression. -/// @example -/// `tuple(expr for i in tuple_generator)` -> `Tuple.__new__(expr...)` -void TypecheckVisitor::visit(GeneratorExpr *expr) { - seqassert(expr->kind == GeneratorExpr::Generator && expr->loops.size() == 1 && - expr->loops[0].conds.empty(), - "invalid tuple generator"); - - unify(expr->type, ctx->getUnbound()); - - auto gen = transform(expr->loops[0].gen); - if (!gen->type->canRealize()) - return; // Wait until the iterator can be realized - - auto block = N(); - // `tuple = tuple_generator` - auto tupleVar = ctx->cache->getTemporaryVar("tuple"); - block->stmts.push_back(N(N(tupleVar), gen)); - - seqassert(expr->loops[0].vars->getId(), "tuple() not simplified"); - std::vector vars{expr->loops[0].vars->getId()->value}; - auto suiteVec = expr->expr->getStmtExpr() - ? expr->expr->getStmtExpr()->stmts[0]->getSuite() - : nullptr; - auto oldSuite = suiteVec ? suiteVec->clone() : nullptr; - for (int validI = 0; suiteVec && validI < suiteVec->stmts.size(); validI++) { - if (auto a = suiteVec->stmts[validI]->getAssign()) - if (a->rhs && a->rhs->getIndex()) - if (a->rhs->getIndex()->expr->isId(vars[0])) { - vars.push_back(a->lhs->getId()->value); - suiteVec->stmts[validI] = nullptr; - continue; - } - break; - } - if (vars.size() > 1) - vars.erase(vars.begin()); - auto [ok, staticItems] = - transformStaticLoopCall(vars, expr->loops[0].gen, [&](StmtPtr wrap) { - return N(wrap, clone(expr->expr)); - }); - if (ok) { - std::vector tupleItems; - for (auto &i : staticItems) - tupleItems.push_back(std::dynamic_pointer_cast(i)); - resultExpr = transform(N(block, N(tupleItems))); - return; - } else if (oldSuite) { - expr->expr->getStmtExpr()->stmts[0] = oldSuite; - } - - auto tuple = gen->type->getRecord(); - if (!tuple || !(tuple->name == TYPE_TUPLE || startswith(tuple->name, TYPE_KWTUPLE))) - E(Error::CALL_BAD_ITER, gen, gen->type->prettyString()); - - // `a := tuple[i]; expr...` for each i - std::vector items; - items.reserve(tuple->args.size()); - for (int ai = 0; ai < tuple->args.size(); ai++) { - items.emplace_back( - N(N(clone(expr->loops[0].vars), - N(N(tupleVar), N(ai))), - clone(expr->expr))); - } - - // `((a := tuple[0]; expr), (a := tuple[1]; expr), ...)` - resultExpr = transform(N(block, N(items))); -} - } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/cond.cpp b/codon/parser/visitors/typecheck/cond.cpp index 05d8d23d..2356d3bf 100644 --- a/codon/parser/visitors/typecheck/cond.cpp +++ b/codon/parser/visitors/typecheck/cond.cpp @@ -2,10 +2,10 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; +using namespace codon::error; namespace codon::ast { @@ -14,66 +14,71 @@ using namespace types; /// Call `ready` and `notReady` depending whether the provided static expression can be /// evaluated or not. template -auto evaluateStaticCondition(const ExprPtr &cond, TT ready, TF notReady) { - seqassertn(cond->isStatic(), "not a static condition"); - if (cond->staticValue.evaluated) { +auto evaluateStaticCondition(Expr *cond, TT ready, TF notReady) { + seqassertn(cond->getType()->isStaticType(), "not a static condition"); + if (cond->getType()->canRealize()) { bool isTrue = false; - if (cond->staticValue.type == StaticValue::STRING) - isTrue = !cond->staticValue.getString().empty(); - else - isTrue = cond->staticValue.getInt(); + if (auto as = cond->getType()->getStrStatic()) + isTrue = !as->value.empty(); + else if (auto ai = cond->getType()->getIntStatic()) + isTrue = ai->value; + else if (auto ab = cond->getType()->getBoolStatic()) + isTrue = ab->value; return ready(isTrue); } else { return notReady(); } } +/// Only allowed in @c MatchStmt +void TypecheckVisitor::visit(RangeExpr *expr) { + E(Error::UNEXPECTED_TYPE, expr, "range"); +} + /// Typecheck if expressions. Evaluate static if blocks if possible. /// Also wrap the condition with `__bool__()` if needed and wrap both conditional /// expressions. See @c wrapExpr for more details. void TypecheckVisitor::visit(IfExpr *expr) { - transform(expr->cond); - + // C++ call order is not defined; make sure to transform the conditional first + expr->cond = transform(expr->getCond()); // Static if evaluation - if (expr->cond->isStatic()) { + if (expr->getCond()->getType()->isStaticType()) { resultExpr = evaluateStaticCondition( - expr->cond, + expr->getCond(), [&](bool isTrue) { LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue); - return transform(isTrue ? expr->ifexpr : expr->elsexpr); + return transform(isTrue ? expr->getIf() : expr->getElse()); }, - [&]() -> ExprPtr { - // Check if both subexpressions are static; if so, this if expression is also - // static and should be marked as such - auto i = transform(clone(expr->ifexpr)); - auto e = transform(clone(expr->elsexpr)); - if (i->isStatic() && e->isStatic()) { - expr->staticValue.type = i->staticValue.type; - unify(expr->type, - ctx->getType(expr->staticValue.type == StaticValue::INT ? "int" - : "str")); - } - return nullptr; - }); + [&]() -> Expr * { return nullptr; }); if (resultExpr) - unify(expr->type, resultExpr->getType()); + unify(expr->getType(), resultExpr->getType()); else - unify(expr->type, ctx->getUnbound()); + expr->getType()->getUnbound()->isStatic = 1; // TODO: determine later! return; } - transform(expr->ifexpr); - transform(expr->elsexpr); + expr->ifexpr = transform(expr->getIf()); + expr->elsexpr = transform(expr->getElse()); + // Add __bool__ wrapper - while (expr->cond->type->getClass() && !expr->cond->type->is("bool")) - expr->cond = transform(N(N(expr->cond, "__bool__"))); + while (expr->getCond()->getClassType() && !expr->getCond()->getType()->is("bool")) + expr->cond = transform(N(N(expr->getCond(), "__bool__"))); // Add wrappers and unify both sides - wrapExpr(expr->elsexpr, expr->ifexpr->getType(), nullptr, /*allowUnwrap*/ false); - wrapExpr(expr->ifexpr, expr->elsexpr->getType(), nullptr, /*allowUnwrap*/ false); - unify(expr->type, expr->ifexpr->getType()); - unify(expr->type, expr->elsexpr->getType()); + if (expr->getIf()->getType()->getStatic()) + expr->getIf()->setType( + expr->getIf()->getType()->getStatic()->getNonStaticType()->shared_from_this()); + if (expr->getElse()->getType()->getStatic()) + expr->getElse()->setType(expr->getElse() + ->getType() + ->getStatic() + ->getNonStaticType() + ->shared_from_this()); + wrapExpr(&expr->elsexpr, expr->getIf()->getType(), nullptr, /*allowUnwrap*/ false); + wrapExpr(&expr->ifexpr, expr->getElse()->getType(), nullptr, /*allowUnwrap*/ false); - if (expr->cond->isDone() && expr->ifexpr->isDone() && expr->elsexpr->isDone()) + unify(expr->getType(), expr->getIf()->getType()); + unify(expr->getType(), expr->getElse()->getType()); + if (expr->getCond()->isDone() && expr->getIf()->isDone() && expr->getElse()->isDone()) expr->setDone(); } @@ -81,31 +86,169 @@ void TypecheckVisitor::visit(IfExpr *expr) { /// Also wrap the condition with `__bool__()` if needed. /// See @c wrapExpr for more details. void TypecheckVisitor::visit(IfStmt *stmt) { - transform(stmt->cond); + stmt->cond = transform(stmt->getCond()); // Static if evaluation - if (stmt->cond->isStatic()) { + if (stmt->getCond()->getType()->isStaticType()) { resultStmt = evaluateStaticCondition( - stmt->cond, + stmt->getCond(), [&](bool isTrue) { LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue); - auto t = transform(isTrue ? stmt->ifSuite : stmt->elseSuite); + auto t = transform(isTrue ? stmt->getIf() : stmt->getElse()); return t ? t : transform(N()); }, - [&]() -> StmtPtr { return nullptr; }); + [&]() -> Stmt * { return nullptr; }); return; } - while (stmt->cond->type->getClass() && !stmt->cond->type->is("bool")) - stmt->cond = transform(N(N(stmt->cond, "__bool__"))); + while (stmt->getCond()->getClassType() && !stmt->getCond()->getType()->is("bool")) + stmt->cond = transform(N(N(stmt->getCond(), "__bool__"))); ctx->blockLevel++; - transform(stmt->ifSuite); - transform(stmt->elseSuite); + stmt->ifSuite = SuiteStmt::wrap(transform(stmt->getIf())); + stmt->elseSuite = SuiteStmt::wrap(transform(stmt->getElse())); ctx->blockLevel--; - if (stmt->cond->isDone() && (!stmt->ifSuite || stmt->ifSuite->isDone()) && - (!stmt->elseSuite || stmt->elseSuite->isDone())) + if (stmt->cond->isDone() && (!stmt->getIf() || stmt->getIf()->isDone()) && + (!stmt->getElse() || stmt->getElse()->isDone())) stmt->setDone(); } +/// Simplify match statement by transforming it into a series of conditional statements. +/// @example +/// ```match e: +/// case pattern1: ... +/// case pattern2 if guard: ... +/// ...``` -> +/// ```_match = e +/// while True: # used to simulate goto statement with break +/// [pattern1 transformation]: (...; break) +/// [pattern2 transformation]: if guard: (...; break) +/// ... +/// break # exit the loop no matter what``` +/// The first pattern that matches the given expression will be used; other patterns +/// will not be used (i.e., there is no fall-through). See @c transformPattern for +/// pattern transformations +void TypecheckVisitor::visit(MatchStmt *stmt) { + auto var = getTemporaryVar("match"); + auto result = N(); + result->addStmt(transform(N(N(var), clone(stmt->getExpr())))); + for (auto &c : *stmt) { + Stmt *suite = N(c.getSuite(), N()); + if (c.getGuard()) + suite = N(c.getGuard(), suite); + result->addStmt(transformPattern(N(var), c.getPattern(), suite)); + } + // Make sure to break even if there is no case _ to prevent infinite loop + result->addStmt(N()); + resultStmt = transform(N(N(true), result)); +} + +/// Transform a match pattern into a series of if statements. +/// @example +/// `case True` -> `if isinstance(var, "bool"): if var == True` +/// `case 1` -> `if isinstance(var, "int"): if var == 1` +/// `case 1...3` -> ```if isinstance(var, "int"): +/// if var >= 1: if var <= 3``` +/// `case (1, pat)` -> ```if isinstance(var, "Tuple"): if staticlen(var) == 2: +/// if match(var[0], 1): if match(var[1], pat)``` +/// `case [1, ..., pat]` -> ```if isinstance(var, "List"): if len(var) >= 2: +/// if match(var[0], 1): if match(var[-1], pat)``` +/// `case 1 or pat` -> `if match(var, 1): if match(var, pat)` +/// (note: pattern suite is cloned for each `or`) +/// `case (x := pat)` -> `(x := var; if match(var, pat))` +/// `case x` -> `(x := var)` +/// (only when `x` is not '_') +/// `case expr` -> `if hasattr(typeof(var), "__match__"): if +/// var.__match__(foo())` +/// (any expression that does not fit above patterns) +Stmt *TypecheckVisitor::transformPattern(Expr *var, Expr *pattern, Stmt *suite) { + // Convenience function to generate `isinstance(e, typ)` calls + auto isinstance = [&](Expr *e, const std::string &typ) -> Expr * { + return N(N("isinstance"), clone(e), N(typ)); + }; + // Convenience function to find the index of an ellipsis within a list pattern + auto findEllipsis = [&](const std::vector &items) { + size_t i = items.size(); + for (auto it = 0; it < items.size(); it++) + if (cast(items[it])) { + if (i != items.size()) + E(Error::MATCH_MULTI_ELLIPSIS, items[it], "multiple ellipses in pattern"); + i = it; + } + return i; + }; + + // See the above examples for transformation details + if (cast(pattern) || cast(pattern)) { + // Bool and int patterns + return N(isinstance(var, cast(pattern) ? "bool" : "int"), + N(N(var, "==", pattern), suite)); + } else if (auto er = cast(pattern)) { + // Range pattern + return N( + isinstance(var, "int"), + N(N(var, ">=", er->start), + N(N(clone(var), "<=", er->stop), suite))); + } else if (auto et = cast(pattern)) { + // Tuple pattern + for (auto it = et->items.size(); it-- > 0;) { + suite = + transformPattern(N(clone(var), N(it)), (*et)[it], suite); + } + return N(isinstance(var, "Tuple"), + N(N(N(N("staticlen"), var), + "==", N(et->size())), + suite)); + } else if (auto el = cast(pattern)) { + // List pattern + size_t ellipsis = findEllipsis(el->items), sz = el->size(); + std::string op; + if (ellipsis == el->size()) { + op = "=="; + } else { + op = ">=", sz -= 1; + } + for (auto it = el->size(); it-- > ellipsis + 1;) { + suite = transformPattern(N(clone(var), N(it - el->size())), + (*el)[it], suite); + } + for (auto it = ellipsis; it-- > 0;) { + suite = + transformPattern(N(clone(var), N(it)), (*el)[it], suite); + } + return N( + isinstance(var, "List"), + N(N(N(N("len"), var), op, N(sz)), + suite)); + } else if (auto eb = cast(pattern)) { + // Or pattern + if (eb->op == "|" || eb->op == "||") { + return N(transformPattern(clone(var), eb->lexpr, clone(suite)), + transformPattern(var, eb->rexpr, suite)); + } + } else if (auto ei = cast(pattern)) { + // Wildcard pattern + if (ei->value != "_") { + return N(N(pattern, var), suite); + } else { + return suite; + } + } else if (auto ea = cast(pattern)) { + // Bound pattern + seqassert(cast(ea->getVar()), + "only simple assignment expressions are supported"); + return N(N(ea->getVar(), clone(var)), + transformPattern(var, ea->getExpr(), suite)); + } + pattern = transform(pattern); // transform to check for pattern errors + if (cast(pattern)) + pattern = N(N("ellipsis")); + // Fallback (`__match__`) pattern + auto p = + N(N(N("hasattr"), clone(var), + N("__match__"), clone(pattern)), + N(N(N(var, "__match__"), pattern), suite)); + return p; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 13ace312..f92d5da7 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -7,10 +7,11 @@ #include #include +#include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" #include "codon/parser/visitors/format/format.h" -#include "codon/parser/visitors/simplify/ctx.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -18,458 +19,225 @@ using namespace codon::error; namespace codon::ast { -TypeContext::TypeContext(Cache *cache) - : Context(""), cache(cache), typecheckLevel(0), age(0), - blockLevel(0), returnEarly(false), changedNodes(0) { - realizationBases.push_back({"", nullptr, nullptr}); - pushSrcInfo(cache->generateSrcInfo()); // Always have srcInfo() around -} - -std::shared_ptr TypeContext::add(TypecheckItem::Kind kind, - const std::string &name, - const types::TypePtr &type) { - auto t = std::make_shared(kind, type); - add(name, t); - return t; -} +TypecheckItem::TypecheckItem(std::string canonicalName, std::string baseName, + std::string moduleName, types::TypePtr type, + std::vector scope) + : canonicalName(std::move(canonicalName)), baseName(std::move(baseName)), + moduleName(std::move(moduleName)), type(std::move(type)), + scope(std::move(scope)) {} -std::shared_ptr TypeContext::find(const std::string &name) const { - if (auto t = Context::find(name)) - return t; - if (in(cache->globals, name)) - return std::make_shared(TypecheckItem::Var, getUnbound()); - return nullptr; +TypeContext::TypeContext(Cache *cache, std::string filename) + : Context(std::move(filename)), cache(cache) { + bases.emplace_back(); + scope.emplace_back(0); + auto e = cache->N(); + e->setSrcInfo(cache->generateSrcInfo()); + pushNode(e); // Always have srcInfo() around } -std::shared_ptr TypeContext::forceFind(const std::string &name) const { - auto t = find(name); - seqassert(t, "cannot find '{}'", name); - return t; +void TypeContext::add(const std::string &name, const TypeContext::Item &var) { + seqassert(!var->scope.empty(), "bad scope for '{}'", name); + Context::add(name, var); } -types::TypePtr TypeContext::getType(const std::string &name) const { - return forceFind(name)->type; +void TypeContext::removeFromMap(const std::string &name) { + Context::removeFromMap(name); } -TypeContext::RealizationBase *TypeContext::getRealizationBase() { - return &(realizationBases.back()); +TypeContext::Item TypeContext::addVar(const std::string &name, + const std::string &canonicalName, + const types::TypePtr &type, + const SrcInfo &srcInfo) { + seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); + // seqassert(type->getLink(), "bad var"); + auto t = std::make_shared(canonicalName, getBaseName(), getModule(), + type, getScope()); + t->setSrcInfo(srcInfo); + add(name, t); + addAlwaysVisible(t); + return t; } -size_t TypeContext::getRealizationDepth() const { return realizationBases.size(); } - -std::string TypeContext::getRealizationStackName() const { - if (realizationBases.empty()) - return ""; - std::vector s; - for (auto &b : realizationBases) - if (b.type) - s.push_back(b.type->realizedName()); - return join(s, ":"); +TypeContext::Item TypeContext::addType(const std::string &name, + const std::string &canonicalName, + const types::TypePtr &type, + const SrcInfo &srcInfo) { + seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); + // seqassert(type->getClass(), "bad type"); + auto t = std::make_shared(canonicalName, getBaseName(), getModule(), + type, getScope()); + t->setSrcInfo(srcInfo); + add(name, t); + addAlwaysVisible(t); + return t; } -std::shared_ptr TypeContext::getUnbound(const SrcInfo &srcInfo, - int level) const { - auto typ = std::make_shared(cache, types::LinkType::Unbound, - cache->unboundCount++, level, nullptr); - typ->setSrcInfo(srcInfo); - return typ; +TypeContext::Item TypeContext::addFunc(const std::string &name, + const std::string &canonicalName, + const types::TypePtr &type, + const SrcInfo &srcInfo) { + seqassert(!canonicalName.empty(), "empty canonical name for '{}'", name); + seqassert(type->getFunc(), "bad func"); + auto t = std::make_shared(canonicalName, getBaseName(), getModule(), + type, getScope()); + t->setSrcInfo(srcInfo); + add(name, t); + addAlwaysVisible(t); + return t; } -std::shared_ptr TypeContext::getUnbound(const SrcInfo &srcInfo) const { - return getUnbound(srcInfo, typecheckLevel); -} +TypeContext::Item TypeContext::addAlwaysVisible(const TypeContext::Item &item, + bool pop) { + add(item->canonicalName, item); + if (pop) + stack.front().pop_back(); // do not remove it later! + if (!cache->typeCtx->Context::find(item->canonicalName)) { + cache->typeCtx->add(item->canonicalName, item); + if (pop) + cache->typeCtx->stack.front().pop_back(); // do not remove it later! -std::shared_ptr TypeContext::getUnbound() const { - return getUnbound(getSrcInfo(), typecheckLevel); + // Realizations etc. + if (!in(cache->reverseIdentifierLookup, item->canonicalName)) + cache->reverseIdentifierLookup[item->canonicalName] = item->canonicalName; + } + return item; } -types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo, - const types::TypePtr &type, - const types::ClassTypePtr &generics) { - seqassert(type, "type is null"); - std::unordered_map genericCache; - if (generics) { - for (auto &g : generics->generics) - if (g.type && - !(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) { - genericCache[g.id] = g.type; +TypeContext::Item TypeContext::find(const std::string &name, int64_t time) const { + auto it = map.find(name); + bool isMangled = in(name, "."); + if (it != map.end()) { + for (auto &i : it->second) { + if (!isMangled && !startswith(getBaseName(), i->getBaseName())) { + continue; // avoid middle realizations } - } - auto t = type->instantiate(typecheckLevel, &(cache->unboundCount), &genericCache); - for (auto &i : genericCache) { - if (auto l = i.second->getLink()) { - i.second->setSrcInfo(srcInfo); - if (l->defaultType) { - getRealizationBase()->pendingDefaults.insert(i.second); + if (isMangled || i->getBaseName() != getBaseName() || !time) { + return i; + } else { + if (i->getTime() <= time) + return i; } } } - if (t->getUnion() && !t->getUnion()->isSealed()) { - t->setSrcInfo(srcInfo); - getRealizationBase()->pendingDefaults.insert(t); - } - if (auto r = t->getRecord()) - if (r->repeats && r->repeats->canRealize()) - r->flatten(); + + // Item is not found in the current module. Time to look in the standard library! + // Note: the standard library items cannot be dominated. + TypeContext::Item t = nullptr; + auto stdlib = cache->imports[STDLIB_IMPORT].ctx; + if (stdlib.get() != this) + t = stdlib->Context::find(name); + + // Maybe we are looking for a canonical identifier? + if (!t && cache->typeCtx.get() != this) + t = cache->typeCtx->Context::find(name); + return t; } -types::TypePtr -TypeContext::instantiateGeneric(const SrcInfo &srcInfo, const types::TypePtr &root, - const std::vector &generics) { - auto c = root->getClass(); - seqassert(c, "root class is null"); - // dummy generic type - auto g = std::make_shared(cache, "", ""); - if (generics.size() != c->generics.size()) { - E(Error::GENERICS_MISMATCH, srcInfo, cache->rev(c->name), c->generics.size(), - generics.size()); - } - for (int i = 0; i < c->generics.size(); i++) { - seqassert(c->generics[i].type, "generic is null"); - g->generics.emplace_back("", "", generics[i], c->generics[i].id); - } - return instantiate(srcInfo, root, g); +TypeContext::Item TypeContext::forceFind(const std::string &name) const { + auto f = find(name); + seqassert(f, "cannot find '{}'", name); + return f; } -std::shared_ptr -TypeContext::instantiateTuple(const SrcInfo &srcInfo, - const std::vector &generics) { - auto key = generateTuple(generics.size()); - auto root = forceFind(key)->type->getRecord(); - return instantiateGeneric(srcInfo, root, generics)->getRecord(); -} +/// Getters and setters -std::string TypeContext::generateTuple(size_t n) { - auto key = format("_{}:{}", TYPE_TUPLE, n); - if (!in(cache->classes, key)) { - cache->classes[key].fields.clear(); - cache->classes[key].ast = - std::static_pointer_cast(clone(cache->classes[TYPE_TUPLE].ast)); - auto root = std::make_shared(cache, TYPE_TUPLE, TYPE_TUPLE); - for (size_t i = 0; i < n; i++) { // generate unique ID - auto g = getUnbound()->getLink(); - g->kind = types::LinkType::Generic; - g->genericName = format("T{}", i + 1); - auto gn = cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(g->genericName); - root->generics.emplace_back(gn, g->genericName, g, g->id); - root->args.emplace_back(g); - cache->classes[key].ast->args.emplace_back( - g->genericName, std::make_shared("type"), nullptr, Param::Generic); - cache->classes[key].fields.push_back( - Cache::Class::ClassField{format("item{}", i + 1), g, ""}); - } - std::vector eTypeArgs; - for (size_t i = 0; i < n; i++) - eTypeArgs.push_back(std::make_shared(format("T{}", i + 1))); - auto eType = std::make_shared(std::make_shared(TYPE_TUPLE), - eTypeArgs); - eType->type = root; - cache->classes[key].mro = {eType}; - addToplevel(key, std::make_shared(TypecheckItem::Type, root)); - } - return key; -} +std::string TypeContext::getBaseName() const { return bases.back().name; } -std::shared_ptr TypeContext::instantiateTuple(size_t n) { - std::vector t(n); - for (size_t i = 0; i < n; i++) { - auto g = getUnbound()->getLink(); - g->genericName = format("T{}", i + 1); - t[i] = g; - } - return instantiateTuple(getSrcInfo(), t); +std::string TypeContext::getModule() const { + std::string base = moduleName.status == ImportFile::STDLIB ? "std." : ""; + base += moduleName.module; + if (auto sz = startswith(base, "__main__")) + base = base.substr(sz); + return base; } -std::vector TypeContext::findMethod(types::ClassType *type, - const std::string &method, - bool hideShadowed) { - auto typeName = type->name; - if (type->is(TYPE_TUPLE)) { - auto sz = type->getRecord()->getRepeats(); - if (sz != -1) - type->getRecord()->flatten(); - sz = int64_t(type->getRecord()->args.size()); - typeName = format("_{}:{}", TYPE_TUPLE, sz); - if (in(cache->classes[TYPE_TUPLE].methods, method) && - !in(cache->classes[typeName].methods, method)) { - auto type = forceFind(typeName)->type; - - cache->classes[typeName].methods[method] = - cache->classes[TYPE_TUPLE].methods[method]; - auto &o = cache->overloads[cache->classes[typeName].methods[method]]; - auto f = cache->functions[o[0].name]; - f.realizations.clear(); - - seqassert(f.type, "tuple fn type not yet set"); - f.ast->attributes.parentClass = typeName; - f.ast = std::static_pointer_cast(clone(f.ast)); - f.ast->name = format("{}{}", f.ast->name.substr(0, f.ast->name.size() - 1), sz); - f.ast->attributes.set(Attr::Method); - - auto eType = clone(cache->classes[typeName].mro[0]); - eType->type = nullptr; - for (auto &a : f.ast->args) - if (a.type && a.type->isId(TYPE_TUPLE)) { - a.type = eType; - } - if (f.ast->ret && f.ast->ret->isId(TYPE_TUPLE)) - f.ast->ret = eType; - // TODO: resurrect Tuple[N].__new__(defaults...) - if (method == "__new__") { - for (size_t i = 0; i < sz; i++) { - auto n = format("item{}", i + 1); - f.ast->args.emplace_back( - cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(n), - std::make_shared(format("T{}", i + 1)) - // std::make_shared( - // std::make_shared(format("T{}", i + 1))) - ); - } - } - cache->reverseIdentifierLookup[f.ast->name] = method; - cache->functions[f.ast->name] = f; - cache->functions[f.ast->name].type = - TypecheckVisitor(cache->typeCtx).makeFunctionType(f.ast.get()); - addToplevel(f.ast->name, - std::make_shared(TypecheckItem::Func, - cache->functions[f.ast->name].type)); - o.push_back(Cache::Overload{f.ast->name, 0}); - } - } +std::string TypeContext::getModulePath() const { return moduleName.path; } - std::vector vv; - std::unordered_set signatureLoci; - - auto populate = [&](const auto &cls) { - auto t = in(cls.methods, method); - if (!t) - return; - auto mt = cache->overloads[*t]; - for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { - auto &method = mt[mti]; - if (endswith(method.name, ":dispatch") || !cache->functions[method.name].type) - continue; - if (method.age <= age) { - if (hideShadowed) { - auto sig = cache->functions[method.name].ast->signature(); - if (!in(signatureLoci, sig)) { - signatureLoci.insert(sig); - vv.emplace_back(cache->functions[method.name].type); - } - } else { - vv.emplace_back(cache->functions[method.name].type); - } - } - } - }; - if (auto cls = in(cache->classes, typeName)) { - for (auto &pt : cls->mro) { - if (auto pc = pt->type->getClass()) { - auto mc = in(cache->classes, pc->name); - seqassert(mc, "class '{}' not found", pc->name); - populate(*mc); - } +void TypeContext::dump() { dump(0); } + +std::string TypeContext::generateCanonicalName(const std::string &name, + bool includeBase, bool noSuffix) const { + std::string newName = name; + bool alreadyGenerated = name.find('.') != std::string::npos; + if (alreadyGenerated) + return name; + includeBase &= !(!name.empty() && name[0] == '%'); + if (includeBase && !alreadyGenerated) { + std::string base = getBaseName(); + if (base.empty()) + base = getModule(); + if (base == "std.internal.core") { + noSuffix = true; + base = ""; } + newName = (base.empty() ? "" : (base + ".")) + newName; } - return vv; + auto num = cache->identifierCount[newName]++; + if (!noSuffix && !alreadyGenerated) + newName = format("{}.{}", newName, num); + if (name != newName) + cache->identifierCount[newName]++; + cache->reverseIdentifierLookup[newName] = name; + return newName; } -types::TypePtr TypeContext::findMember(const types::ClassTypePtr &type, - const std::string &member) const { - if (type->is(TYPE_TUPLE)) { - if (!startswith(member, "item") || member.size() < 5) - return nullptr; - int id = 0; - for (int i = 4; i < member.size(); i++) { - if (member[i] >= '0' + (i == 4) && member[i] <= '9') - id = id * 10 + member[i] - '0'; - else - return nullptr; - } - auto sz = type->getRecord()->getRepeats(); - if (sz != -1) - type->getRecord()->flatten(); - if (id < 1 || id > type->getRecord()->args.size()) - return nullptr; - return type->getRecord()->args[id - 1]; - } - if (auto cls = in(cache->classes, type->name)) { - for (auto &pt : cls->mro) { - if (auto pc = pt->type->getClass()) { - auto mc = in(cache->classes, pc->name); - seqassert(mc, "class '{}' not found", pc->name); - for (auto &mm : mc->fields) { - if (mm.name == member) - return mm.type; - } - } - } - } - return nullptr; +bool TypeContext::isGlobal() const { return bases.size() == 1; } + +bool TypeContext::isConditional() const { return scope.size() > 1; } + +TypeContext::Base *TypeContext::getBase() { + return bases.empty() ? nullptr : &(bases.back()); } -int TypeContext::reorderNamedArgs(types::FuncType *func, - const std::vector &args, - const ReorderDoneFn &onDone, - const ReorderErrorFn &onError, - const std::vector &known) { - // See https://docs.python.org/3.6/reference/expressions.html#calls for details. - // Final score: - // - +1 for each matched argument - // - 0 for *args/**kwargs/default arguments - // - -1 for failed match - int score = 0; - - // 0. Find *args and **kwargs - // True if there is a trailing ellipsis (full partial: fn(all_args, ...)) - bool partial = !args.empty() && args.back().value->getEllipsis() && - args.back().value->getEllipsis()->mode != EllipsisExpr::PIPE && - args.back().name.empty(); - - int starArgIndex = -1, kwstarArgIndex = -1; - for (int i = 0; i < func->ast->args.size(); i++) { - if (startswith(func->ast->args[i].name, "**")) - kwstarArgIndex = i, score -= 2; - else if (startswith(func->ast->args[i].name, "*")) - starArgIndex = i, score -= 2; - } +bool TypeContext::inFunction() const { return !isGlobal() && !bases.back().isType(); } - // 1. Assign positional arguments to slots - // Each slot contains a list of arg's indices - std::vector> slots(func->ast->args.size()); - seqassert(known.empty() || func->ast->args.size() == known.size(), - "bad 'known' string"); - std::vector extra; - std::map namedArgs, - extraNamedArgs; // keep the map--- we need it sorted! - for (int ai = 0, si = 0; ai < args.size() - partial; ai++) { - if (args[ai].name.empty()) { - while (!known.empty() && si < slots.size() && known[si]) - si++; - if (si < slots.size() && (starArgIndex == -1 || si < starArgIndex)) - slots[si++] = {ai}; - else - extra.emplace_back(ai); - } else { - namedArgs[args[ai].name] = ai; - } - } - score += 2 * int(slots.size() - func->funcGenerics.size()); +bool TypeContext::inClass() const { return !isGlobal() && bases.back().isType(); } - for (auto ai : std::vector{std::max(starArgIndex, kwstarArgIndex), - std::min(starArgIndex, kwstarArgIndex)}) - if (ai != -1 && !slots[ai].empty()) { - extra.insert(extra.begin(), ai); - slots[ai].clear(); - } +bool TypeContext::isOuter(const Item &val) const { + return getBaseName() != val->getBaseName() || getModule() != val->getModule(); +} - // 2. Assign named arguments to slots - if (!namedArgs.empty()) { - std::map slotNames; - for (int i = 0; i < func->ast->args.size(); i++) - if (known.empty() || !known[i]) { - slotNames[cache->reverseIdentifierLookup[func->ast->args[i].name]] = i; - } - for (auto &n : namedArgs) { - if (!in(slotNames, n.first)) - extraNamedArgs[n.first] = n.second; - else if (slots[slotNames[n.first]].empty()) - slots[slotNames[n.first]].push_back(n.second); - else - return onError(Error::CALL_REPEATED_NAME, args[n.second].value->getSrcInfo(), - Emsg(Error::CALL_REPEATED_NAME, n.first)); - } - } +TypeContext::Base *TypeContext::getClassBase() { + if (bases.size() >= 2 && bases[bases.size() - 2].isType()) + return &(bases[bases.size() - 2]); + return nullptr; +} - // 3. Fill in *args, if present - if (!extra.empty() && starArgIndex == -1) - return onError(Error::CALL_ARGS_MANY, getSrcInfo(), - Emsg(Error::CALL_ARGS_MANY, cache->rev(func->ast->name), - func->ast->args.size(), args.size() - partial)); - - if (starArgIndex != -1) - slots[starArgIndex] = extra; - - // 4. Fill in **kwargs, if present - if (!extraNamedArgs.empty() && kwstarArgIndex == -1) - return onError(Error::CALL_ARGS_INVALID, - args[extraNamedArgs.begin()->second].value->getSrcInfo(), - Emsg(Error::CALL_ARGS_INVALID, extraNamedArgs.begin()->first, - cache->rev(func->ast->name))); - if (kwstarArgIndex != -1) - for (auto &e : extraNamedArgs) - slots[kwstarArgIndex].push_back(e.second); - - // 5. Fill in the default arguments - for (auto i = 0; i < func->ast->args.size(); i++) - if (slots[i].empty() && i != starArgIndex && i != kwstarArgIndex) { - if (func->ast->args[i].status == Param::Normal && - (func->ast->args[i].defaultValue || (!known.empty() && known[i]))) - score -= 2; - else if (!partial && func->ast->args[i].status == Param::Normal) - return onError(Error::CALL_ARGS_MISSING, getSrcInfo(), - Emsg(Error::CALL_ARGS_MISSING, cache->rev(func->ast->name), - cache->reverseIdentifierLookup[func->ast->args[i].name])); - } - auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial); - return s != -1 ? score + s : -1; +size_t TypeContext::getRealizationDepth() const { return bases.size(); } + +std::string TypeContext::getRealizationStackName() const { + if (bases.empty()) + return ""; + std::vector s; + for (auto &b : bases) + if (b.type) + s.push_back(b.type->realizedName()); + return join(s, ":"); } void TypeContext::dump(int pad) { auto ordered = std::map(map.begin(), map.end()); - LOG("base: {}", getRealizationStackName()); + LOG("current module: {} ({})", moduleName.module, moduleName.path); + LOG("current base: {} / {}", getRealizationStackName(), getBase()->name); for (auto &i : ordered) { std::string s; auto t = i.second.front(); - LOG("{}{:.<25} {}", std::string(pad * 2, ' '), i.first, t->type); + LOG("{}{:.<25}", std::string(size_t(pad) * 2, ' '), i.first); + LOG(" ... kind: {}", t->isType() * 100 + t->isFunc() * 10 + t->isVar()); + LOG(" ... canonical: {}", t->canonicalName); + LOG(" ... base: {}", t->baseName); + LOG(" ... module: {}", t->moduleName); + LOG(" ... type: {}", t->type ? t->type->debugString(2) : ""); + LOG(" ... scope: {}", t->scope); + LOG(" ... gnrc/sttc: {} / {}", t->generic, int(t->isStatic())); } } std::string TypeContext::debugInfo() { - return fmt::format("[{}:i{}@{}]", getRealizationBase()->name, - getRealizationBase()->iteration, getSrcInfo()); -} - -std::shared_ptr, std::vector>> -TypeContext::getFunctionArgs(types::TypePtr t) { - if (!t->getFunc()) - return nullptr; - auto fn = t->getFunc(); - auto ret = std::make_shared< - std::pair, std::vector>>(); - for (auto &t : fn->funcGenerics) - ret->first.push_back(t.type); - for (auto &t : fn->generics[0].type->getRecord()->args) - ret->second.push_back(t); - return ret; -} - -std::shared_ptr TypeContext::getStaticString(types::TypePtr t) { - if (auto s = t->getStatic()) { - auto r = s->evaluate(); - if (r.type == StaticValue::STRING) - return std::make_shared(r.getString()); - } - return nullptr; -} - -std::shared_ptr TypeContext::getStaticInt(types::TypePtr t) { - if (auto s = t->getStatic()) { - auto r = s->evaluate(); - if (r.type == StaticValue::INT) - return std::make_shared(r.getInt()); - } - return nullptr; -} - -types::FuncTypePtr TypeContext::extractFunction(types::TypePtr t) { - if (auto f = t->getFunc()) - return f; - if (auto p = t->getPartial()) - return p->func; - return nullptr; + return fmt::format("[{}:i{}@{}]", getBase()->name, getBase()->iteration, + getSrcInfo()); } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 54166b66..f78bb8f9 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -16,21 +16,53 @@ namespace codon::ast { +class TypecheckVisitor; + /** * Typecheck context identifier. * Can be either a function, a class (type), or a variable. */ -struct TypecheckItem { - /// Identifier kind - enum Kind { Func, Type, Var } kind; +struct TypecheckItem : public SrcObject { + /// Unique identifier (canonical name) + std::string canonicalName; + /// Base name (e.g., foo.bar.baz) + std::string baseName; + /// Full module name + std::string moduleName; /// Type - types::TypePtr type; + types::TypePtr type = nullptr; + + /// Full base scope information + std::vector scope = {0}; + /// Specifies at which time the name was added to the context. + /// Used to prevent using later definitions early (can happen in + /// advanced type checking iterations). + int64_t time = 0; + + /// Set if an identifier is a class or a function generic + bool generic = false; - TypecheckItem(Kind k, types::TypePtr type) : kind(k), type(std::move(type)) {} + TypecheckItem(std::string, std::string, std::string, types::TypePtr, + std::vector = {0}); /* Convenience getters */ - bool isType() const { return kind == Type; } - bool isVar() const { return kind == Var; } + std::string getBaseName() const { return baseName; } + std::string getModule() const { return moduleName; } + bool isVar() const { return !generic && !isFunc() && !isType(); } + bool isFunc() const { return type->getFunc() != nullptr; } + bool isType() const { return type->is(TYPE_TYPE); } + + bool isGlobal() const { return scope.size() == 1 && baseName.empty(); } + /// True if an identifier is within a conditional block + /// (i.e., a block that might not be executed during the runtime) + bool isConditional() const { return scope.size() > 1; } + bool isGeneric() const { return generic; } + char isStatic() const { return type->isStaticType(); } + + types::Type *getType() const { return type.get(); } + std::string getName() const { return canonicalName; } + + int64_t getTime() const { return time; } }; /** Context class that tracks identifiers during the typechecking. **/ @@ -38,143 +70,198 @@ struct TypeContext : public Context { /// A pointer to the shared cache. Cache *cache; - /// A realization base definition. Each function realization defines a new base scope. - /// Used to properly realize enclosed functions and to prevent mess with mutually - /// recursive enclosed functions. - struct RealizationBase { - /// Function name + /// Holds the information about current scope. + /// A scope is defined as a stack of conditional blocks + /// (i.e., blocks that might not get executed during the runtime). + /// Used mainly to support Python's variable scoping rules. + struct ScopeBlock { + int id; + std::unordered_map> replacements; + /// List of statements that are to be prepended to a block + /// after its transformation. + std::vector stmts; + ScopeBlock(int id) : id(id) {} + }; + /// Current hierarchy of conditional blocks. + std::vector scope; + std::vector getScope() const { + std::vector result; + result.reserve(scope.size()); + for (const auto &b : scope) + result.emplace_back(b.id); + return result; + } + + /// Holds the information about current base. + /// A base is defined as a function or a class block. + struct Base { + /// Canonical name of a function or a class that owns this base. std::string name; /// Function type types::TypePtr type; /// The return type of currently realized function - types::TypePtr returnType = nullptr; + types::TypePtr returnType; /// Typechecking iteration int iteration = 0; - std::set pendingDefaults; + /// Only set for functions. + FunctionStmt *func = nullptr; + Stmt *suite = nullptr; + /// Index of the parent base + int parent = 0; + + struct { + /// Set if the base is class base and if class is marked with @deduce. + /// Stores the list of class fields in the order of traversal. + std::shared_ptr> deducedMembers = nullptr; + /// Canonical name of `self` parameter that is used to deduce class fields + /// (e.g., self in self.foo). + std::string selfName; + } deduce; + + /// Map of captured identifiers (i.e., identifiers not defined in a function). + /// Captured (canonical) identifiers are mapped to the new canonical names + /// (representing the canonical function argument names that are appended to the + /// function after processing) and their types (indicating if they are a type, a + /// static or a variable). + // std::unordered_set captures; + + /// Map of identifiers that are to be fetched from Python. + std::unordered_set *pyCaptures = nullptr; + + // /// Scope that defines the base. + // std::vector scope; + + /// A stack of nested loops enclosing the current statement used for transforming + /// "break" statement in loop-else constructs. Each loop is defined by a "break" + /// variable created while parsing a loop-else construct. If a loop has no else + /// block, the corresponding loop variable is empty. + struct Loop { + std::string breakVar; + /// False if a loop has continue/break statement. Used for flattening static + /// loops. + bool flat = true; + Loop(const std::string &breakVar) : breakVar(breakVar) {} + }; + std::vector loops; + + std::map> pendingDefaults; + + public: + Loop *getLoop() { return loops.empty() ? nullptr : &(loops.back()); } + bool isType() const { return func == nullptr; } + }; + /// Current base stack (the last enclosing base is the last base in the stack). + std::vector bases; + + struct BaseGuard { + TypeContext *holder; + BaseGuard(TypeContext *holder, const std::string &name) : holder(holder) { + holder->bases.emplace_back(); + holder->bases.back().name = name; + holder->addBlock(); + } + ~BaseGuard() { + holder->bases.pop_back(); + holder->popBlock(); + } }; - std::vector realizationBases; + + /// Current module. The default module is named `__main__`. + ImportFile moduleName = {ImportFile::PACKAGE, "", ""}; + /// Set if the standard library is currently being loaded. + bool isStdlibLoading = false; + /// Allow type() expressions. Currently used to disallow type() in class + /// and function definitions. + bool allowTypeOf = true; /// The current type-checking level (for type instantiation and generalization). - int typecheckLevel; - int changedNodes; + int typecheckLevel = 0; + int changedNodes = 0; - /// The age of the currently parsed statement. - int age; /// Number of nested realizations. Used to prevent infinite instantiations. - int realizationDepth; - /// Nested default argument calls. Used to prevent infinite CallExpr chains - /// (e.g. class A: def __init__(a: A = A())). - std::set defaultCallDepth; + int realizationDepth = 0; /// Number of nested blocks (0 for toplevel) - int blockLevel; + int blockLevel = 0; /// True if an early return is found (anything afterwards won't be typechecked) - bool returnEarly; + bool returnEarly = false; /// Stack of static loop control variables (used to emulate goto statements). - std::vector staticLoops; + std::vector staticLoops = {}; + + /// Current statement time. + int64_t time; public: - explicit TypeContext(Cache *cache); + explicit TypeContext(Cache *cache, std::string filename = ""); + + void add(const std::string &name, const Item &var) override; + /// Convenience method for adding an object to the context. + Item addVar(const std::string &name, const std::string &canonicalName, + const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo()); + Item addType(const std::string &name, const std::string &canonicalName, + const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo()); + Item addFunc(const std::string &name, const std::string &canonicalName, + const types::TypePtr &type, const SrcInfo &srcInfo = SrcInfo()); + /// Add the item to the standard library module, thus ensuring its visibility from all + /// modules. + Item addAlwaysVisible(const Item &item, bool = false); + + /// Get an item from the context before given srcInfo. If the item does not exist, + /// nullptr is returned. + Item find(const std::string &name, int64_t time = 0) const; + /// Get an item that exists in the context. If the item does not exist, assertion is + /// raised. + Item forceFind(const std::string &name) const; + + /// Return a canonical name of the current base. + /// An empty string represents the toplevel base. + std::string getBaseName() const; + /// Return the current module. + std::string getModule() const; + /// Return the current module path. + std::string getModulePath() const; + /// Pretty-print the current context state. + void dump() override; + + /// Generate a unique identifier (name) for a given string. + std::string generateCanonicalName(const std::string &name, bool includeBase = false, + bool noSuffix = false) const; + /// True if we are at the toplevel. + bool isGlobal() const; + /// True if we are within a conditional block. + bool isConditional() const; + /// Get the current base. + Base *getBase(); + /// True if the current base is function. + bool inFunction() const; + /// True if the current base is class. + bool inClass() const; + /// True if an item is defined outside of the current base or a module. + bool isOuter(const Item &val) const; + /// Get the enclosing class base (or nullptr if such does not exist). + Base *getClassBase(); - using Context::add; /// Convenience method for adding an object to the context. - std::shared_ptr add(TypecheckItem::Kind kind, const std::string &name, - const types::TypePtr &type = nullptr); std::shared_ptr addToplevel(const std::string &name, const std::shared_ptr &item) { map[name].push_front(item); return item; } - std::shared_ptr find(const std::string &name) const override; - std::shared_ptr find(const char *name) const { - return find(std::string(name)); - } - /// Find an internal type. Assumes that it exists. - std::shared_ptr forceFind(const std::string &name) const; - types::TypePtr getType(const std::string &name) const; - - /// Pretty-print the current context state. - void dump() override { dump(0); } public: /// Get the current realization depth (i.e., the number of nested realizations). size_t getRealizationDepth() const; - /// Get the current base. - RealizationBase *getRealizationBase(); /// Get the name of the current realization stack (e.g., `fn1:fn2:...`). std::string getRealizationStackName() const; -public: - /// Create an unbound type with the provided typechecking level. - std::shared_ptr getUnbound(const SrcInfo &info, int level) const; - std::shared_ptr getUnbound(const SrcInfo &info) const; - std::shared_ptr getUnbound() const; - - /// Call `type->instantiate`. - /// Prepare the generic instantiation table with the given generics parameter. - /// Example: when instantiating List[T].foo, generics=List[int].foo will ensure that - /// T=int. - /// @param expr Expression that needs the type. Used to set type's srcInfo. - /// @param setActive If True, add unbounds to activeUnbounds. - types::TypePtr instantiate(const SrcInfo &info, const types::TypePtr &type, - const types::ClassTypePtr &generics = nullptr); - types::TypePtr instantiate(types::TypePtr type, - const types::ClassTypePtr &generics = nullptr) { - return instantiate(getSrcInfo(), std::move(type), generics); - } - - /// Instantiate the generic type root with the provided generics. - /// @param expr Expression that needs the type. Used to set type's srcInfo. - types::TypePtr instantiateGeneric(const SrcInfo &info, const types::TypePtr &root, - const std::vector &generics); - types::TypePtr instantiateGeneric(types::TypePtr root, - const std::vector &generics) { - return instantiateGeneric(getSrcInfo(), std::move(root), generics); - } - - std::shared_ptr - instantiateTuple(const SrcInfo &info, const std::vector &generics); - std::shared_ptr - instantiateTuple(const std::vector &generics) { - return instantiateTuple(getSrcInfo(), generics); - } - std::shared_ptr instantiateTuple(size_t n); - std::string generateTuple(size_t n); - - /// Returns the list of generic methods that correspond to typeName.method. - std::vector findMethod(types::ClassType *type, - const std::string &method, - bool hideShadowed = true); - /// Returns the generic type of typeName.member, if it exists (nullptr otherwise). - /// Special cases: __elemsize__ and __atomic__. - types::TypePtr findMember(const types::ClassTypePtr &, const std::string &) const; - - using ReorderDoneFn = - std::function> &, bool)>; - using ReorderErrorFn = std::function; - /// Reorders a given vector or named arguments (consisting of names and the - /// corresponding types) according to the signature of a given function. - /// Returns the reordered vector and an associated reordering score (missing - /// default arguments' score is half of the present arguments). - /// Score is -1 if the given arguments cannot be reordered. - /// @param known Bitmask that indicated if an argument is already provided - /// (partial function) or not. - int reorderNamedArgs(types::FuncType *func, const std::vector &args, - const ReorderDoneFn &onDone, const ReorderErrorFn &onError, - const std::vector &known = std::vector()); - private: /// Pretty-print the current context state. void dump(int pad); /// Pretty-print the current realization context. std::string debugInfo(); -public: - std::shared_ptr, std::vector>> - getFunctionArgs(types::TypePtr t); - std::shared_ptr getStaticString(types::TypePtr t); - std::shared_ptr getStaticInt(types::TypePtr t); - types::FuncTypePtr extractFunction(types::TypePtr t); +protected: + void removeFromMap(const std::string &name) override; }; } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/error.cpp b/codon/parser/visitors/typecheck/error.cpp index 651c0594..73d7baef 100644 --- a/codon/parser/visitors/typecheck/error.cpp +++ b/codon/parser/visitors/typecheck/error.cpp @@ -2,7 +2,7 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -10,6 +10,32 @@ using fmt::format; namespace codon::ast { using namespace types; +using namespace matcher; +using namespace error; + +/// Transform asserts. +/// @example +/// `assert foo()` -> +/// `if not foo(): raise __internal__.seq_assert([file], [line], "")` +/// `assert foo(), msg` -> +/// `if not foo(): raise __internal__.seq_assert([file], [line], str(msg))` +/// Use `seq_assert_test` instead of `seq_assert` and do not raise anything during unit +/// testing (i.e., when the enclosing function is marked with `@test`). +void TypecheckVisitor::visit(AssertStmt *stmt) { + Expr *msg = N(""); + if (stmt->getMessage()) + msg = N(N("str"), stmt->getMessage()); + auto test = ctx->inFunction() && ctx->getBase()->func->hasAttribute(Attr::Test); + auto ex = N( + N(N("__internal__"), test ? "seq_assert_test" : "seq_assert"), + N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), msg); + auto cond = N("!", stmt->getExpr()); + if (test) { + resultStmt = transform(N(cond, N(ex))); + } else { + resultStmt = transform(N(cond, N(ex))); + } +} /// Typecheck try-except statements. Handle Python exceptions separately. /// @example @@ -27,101 +53,117 @@ using namespace types; /// f = exc; ...; break # PyExc /// raise``` void TypecheckVisitor::visit(TryStmt *stmt) { - // TODO: static can-compile check - // if (stmt->catches.size() == 1 && stmt->catches[0].var.empty() && - // stmt->catches[0].exc->isId("std.internal.types.error.StaticCompileError")) { - // /// TODO: this is right now _very_ dangerous; inferred types here will remain! - // bool compiled = true; - // try { - // auto nctx = std::make_shared(*ctx); - // TypecheckVisitor(nctx).transform(clone(stmt->suite)); - // } catch (const exc::ParserException &exc) { - // compiled = false; - // } - // resultStmt = compiled ? transform(stmt->suite) : - // transform(stmt->catches[0].suite); LOG("testing!! {} {}", getSrcInfo(), - // compiled); return; - // } + if (stmt->getElse()) { + auto successVar = getTemporaryVar("try_success"); + prependStmts->push_back( + transform(N(N(successVar), N(false)))); + stmt->getSuite()->addStmt(N(N(successVar), N(true), + nullptr, AssignStmt::UpdateMode::Update)); + stmt->finally = + N(N(N(successVar), stmt->getElse()), stmt->finally); + stmt->elseSuite = nullptr; + } ctx->blockLevel++; - transform(stmt->suite); + stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite())); ctx->blockLevel--; - std::vector catches; - auto pyVar = ctx->cache->getTemporaryVar("pyexc"); + std::vector catches; + auto pyVar = getTemporaryVar("pyexc"); auto pyCatchStmt = N(N(true), N()); - auto done = stmt->suite->isDone(); - for (auto &c : stmt->catches) { - transform(c.exc); - if (c.exc && c.exc->type->is("pyobj")) { + auto done = stmt->getSuite()->isDone(); + for (auto &c : *stmt) { + TypeContext::Item val = nullptr; + if (!c->getVar().empty()) { + if (!c->hasAttribute(Attr::ExprDominated) && + !c->hasAttribute(Attr::ExprDominatedUsed)) { + val = ctx->addVar(c->getVar(), ctx->generateCanonicalName(c->getVar()), + instantiateUnbound()); + val->time = getTime(); + } else if (c->hasAttribute(Attr::ExprDominatedUsed)) { + val = ctx->forceFind(c->getVar()); + c->eraseAttribute(Attr::ExprDominatedUsed); + c->setAttribute(Attr::ExprDominated); + c->suite = N( + N(N(format("{}{}", getUnmangledName(c->getVar()), + VAR_USED_SUFFIX)), + N(true), nullptr, AssignStmt::UpdateMode::Update), + c->getSuite()); + } else { + val = ctx->forceFind(c->getVar()); + } + c->var = val->canonicalName; + } + c->exc = transform(c->getException()); + if (c->getException() && extractClassType(c->getException())->is("pyobj")) { // Transform python.Error exceptions - if (!c.var.empty()) { - c.suite = N( - N(N(c.var), N(N(pyVar), "pytype")), - c.suite); + if (!c->var.empty()) { + c->suite = N( + N(N(c->var), N(N(pyVar), "pytype")), + c->getSuite()); } - c.suite = - N(N(N("isinstance"), - N(N(pyVar), "pytype"), clone(c.exc)), - N(c.suite, N()), nullptr); - pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite); - } else if (c.exc && c.exc->type->is("std.internal.types.error.PyError")) { + c->suite = SuiteStmt::wrap(N( + N(N("isinstance"), N(N(pyVar), "pytype"), + c->getException()), + N(c->getSuite(), N()), nullptr)); + cast(pyCatchStmt->getSuite())->addStmt(c->getSuite()); + } else if (c->getException() && extractClassType(c->getException()) + ->is("std.internal.python.PyError.0")) { // Transform PyExc exceptions - if (!c.var.empty()) { - c.suite = - N(N(N(c.var), N(pyVar)), c.suite); + if (!c->var.empty()) { + c->suite = N(N(N(c->var), N(pyVar)), + c->getSuite()); } - c.suite = N(c.suite, N()); - pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite); + c->suite = N(c->getSuite(), N()); + cast(pyCatchStmt->getSuite())->addStmt(c->getSuite()); } else { // Handle all other exceptions - transformType(c.exc); - if (!c.var.empty()) { - // Handle dominated except bindings - auto changed = in(ctx->cache->replacements, c.var); - while (auto s = in(ctx->cache->replacements, c.var)) - c.var = s->first, changed = s; - if (changed && changed->second) { - auto update = - N(N(format("{}.__used__", c.var)), N(true)); - update->setUpdate(); - c.suite = N(update, c.suite); - } - if (changed) - c.exc->setAttr(ExprAttr::Dominated); - auto val = ctx->find(c.var); - if (!changed) - val = ctx->add(TypecheckItem::Var, c.var, c.exc->getType()); - unify(val->type, c.exc->getType()); + c->exc = transformType(c->getException()); + + if (c->getException()) { + auto t = extractClassType(c->getException()); + bool exceptionOK = false; + for (auto &p : getSuperTypes(t)) + if (p->is("std.internal.types.error.BaseException.0")) { + exceptionOK = true; + break; + } + if (!exceptionOK) + E(Error::CATCH_EXCEPTION_TYPE, c->getException(), t->toString()); + if (val) + unify(val->getType(), extractType(c->getException())); } ctx->blockLevel++; - transform(c.suite); + c->suite = SuiteStmt::wrap(transform(c->getSuite())); ctx->blockLevel--; - done &= (!c.exc || c.exc->isDone()) && c.suite->isDone(); + done &= (!c->getException() || c->getException()->isDone()) && + c->getSuite()->isDone(); catches.push_back(c); } } - if (!pyCatchStmt->suite->getSuite()->stmts.empty()) { + if (!cast(pyCatchStmt->getSuite())->empty()) { // Process PyError catches - auto exc = NT("std.internal.types.error.PyError"); - pyCatchStmt->suite->getSuite()->stmts.push_back(N(nullptr)); - TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt}; + auto exc = N("std.internal.python.PyError.0"); + cast(pyCatchStmt->getSuite())->addStmt(N(nullptr)); + auto c = N(pyVar, transformType(exc), pyCatchStmt); - auto val = ctx->add(TypecheckItem::Var, pyVar, c.exc->getType()); - unify(val->type, c.exc->getType()); + auto val = + ctx->addVar(pyVar, pyVar, extractType(c->getException())->shared_from_this()); + val->time = getTime(); ctx->blockLevel++; - transform(c.suite); + c->suite = SuiteStmt::wrap(transform(c->getSuite())); ctx->blockLevel--; - done &= (!c.exc || c.exc->isDone()) && c.suite->isDone(); + done &= (!c->exc || c->exc->isDone()) && c->getSuite()->isDone(); catches.push_back(c); } - stmt->catches = catches; - if (stmt->finally) { + stmt->items = catches; + + if (stmt->getFinally()) { ctx->blockLevel++; - transform(stmt->finally); + stmt->finally = SuiteStmt::wrap(transform(stmt->getFinally())); ctx->blockLevel--; - done &= stmt->finally->isDone(); + done &= stmt->getFinally()->isDone(); } if (done) @@ -137,19 +179,55 @@ void TypecheckVisitor::visit(ThrowStmt *stmt) { return; } - transform(stmt->expr); - - if (!(stmt->expr->getCall() && stmt->expr->getCall()->expr->getId() && - startswith(stmt->expr->getCall()->expr->getId()->value, - "__internal__.set_header:0"))) { + stmt->expr = transform(stmt->getExpr()); + if (!match(stmt->getExpr(), + M(M("__internal__.set_header:0"), M_))) { stmt->expr = transform(N( - N("__internal__.set_header:0"), stmt->expr, - N(ctx->getRealizationBase()->name), - N(stmt->getSrcInfo().file), N(stmt->getSrcInfo().line), - N(stmt->getSrcInfo().col))); + N("__internal__.set_header:0"), stmt->getExpr(), + N(ctx->getBase()->name), N(stmt->getSrcInfo().file), + N(stmt->getSrcInfo().line), N(stmt->getSrcInfo().col), + stmt->getFrom() + ? (Expr *)N(N(N("__internal__"), "class_super"), + stmt->getFrom(), + N("std.internal.types.error.BaseException.0")) + : N(N("NoneType")))); } - if (stmt->expr->isDone()) + if (stmt->getExpr()->isDone()) stmt->setDone(); } +/// Transform with statements. +/// @example +/// `with foo(), bar() as a: ...` -> +/// ```tmp = foo() +/// tmp.__enter__() +/// try: +/// a = bar() +/// a.__enter__() +/// try: +/// ... +/// finally: +/// a.__exit__() +/// finally: +/// tmp.__exit__()``` +void TypecheckVisitor::visit(WithStmt *stmt) { + seqassert(!stmt->empty(), "stmt->items is empty"); + std::vector content; + for (auto i = stmt->items.size(); i-- > 0;) { + std::string var = stmt->vars[i].empty() ? getTemporaryVar("with") : stmt->vars[i]; + auto as = N(N(var), (*stmt)[i], nullptr, + (*stmt)[i]->hasAttribute(Attr::ExprDominated) + ? AssignStmt::UpdateMode::Update + : AssignStmt::UpdateMode::Assign); + content = std::vector{ + as, N(N(N(N(var), "__enter__"))), + N(!content.empty() ? N(content) : clone(stmt->getSuite()), + std::vector{}, + nullptr, + N(N( + N(N(N(var), "__exit__")))))}; + } + resultStmt = transform(N(content)); +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index aa8ff170..33057489 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -1,11 +1,15 @@ // Copyright (C) 2022-2025 Exaloop Inc. +#include #include #include +#include "codon/cir/attribute.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -14,14 +18,37 @@ using namespace codon::error; namespace codon::ast { using namespace types; +using namespace matcher; + +/// Unify the function return type with `Generator[?]`. +/// The unbound type will be deduced from return/yield statements. +void TypecheckVisitor::visit(LambdaExpr *expr) { + std::vector params; + std::string name = getTemporaryVar("lambda"); + params.reserve(expr->size()); + for (auto &s : *expr) + params.emplace_back(s); + Stmt *f = N(name, nullptr, params, + N(N(expr->getExpr()))); + if (auto err = ScopingVisitor::apply(ctx->cache, N(f))) + throw exc::ParserException(std::move(err)); + f->setAttribute(Attr::ExprTime, getTime()); // to handle captures properly + f = transform(f); + if (auto a = expr->getAttribute(Attr::Bindings)) + f->setAttribute(Attr::Bindings, a->clone()); + prependStmts->push_back(f); + resultExpr = transform(N(N(name), N())); +} /// Unify the function return type with `Generator[?]`. /// The unbound type will be deduced from return/yield statements. void TypecheckVisitor::visit(YieldExpr *expr) { - unify(expr->type, ctx->getUnbound()); - unify(ctx->getRealizationBase()->returnType, - ctx->instantiateGeneric(ctx->getType("Generator"), {expr->type})); - if (realize(expr->type)) + if (!ctx->inFunction()) + E(Error::FN_OUTSIDE_ERROR, expr, "yield"); + + unify(ctx->getBase()->returnType.get(), + instantiateType(getStdLibType("Generator"), {expr->getType()})); + if (realize(expr->getType())) expr->setDone(); } @@ -29,28 +56,43 @@ void TypecheckVisitor::visit(YieldExpr *expr) { /// Also partialize functions if they are being returned. /// See @c wrapExpr for more details. void TypecheckVisitor::visit(ReturnStmt *stmt) { - if (!stmt->expr && ctx->getRealizationBase()->type && - ctx->getRealizationBase()->type->getFunc()->ast->hasAttr(Attr::IsGenerator)) { + if (stmt->hasAttribute(Attr::Internal)) { + stmt->expr = transform(N(N("NoneType.__new__:0"))); + stmt->setDone(); + return; + } + + if (!ctx->inFunction()) + E(Error::FN_OUTSIDE_ERROR, stmt, "return"); + + if (!stmt->expr && ctx->getBase()->func->hasAttribute(Attr::IsGenerator)) { stmt->setDone(); } else { - if (!stmt->expr) { + if (!stmt->expr) stmt->expr = N(N("NoneType")); - } - transform(stmt->expr); + stmt->expr = transform(stmt->getExpr()); + // Wrap expression to match the return type - if (!ctx->getRealizationBase()->returnType->getUnbound()) - if (!wrapExpr(stmt->expr, ctx->getRealizationBase()->returnType)) { + if (!ctx->getBase()->returnType->getUnbound()) + if (!wrapExpr(&stmt->expr, ctx->getBase()->returnType.get())) { return; } // Special case: partialize functions if we are returning them - if (stmt->expr->getType()->getFunc() && - !(ctx->getRealizationBase()->returnType->getClass() && - ctx->getRealizationBase()->returnType->is("Function"))) { - stmt->expr = partializeFunction(stmt->expr->type->getFunc()); + if (stmt->getExpr()->getType()->getFunc() && + !(ctx->getBase()->returnType->getClass() && + ctx->getBase()->returnType->is("Function"))) { + stmt->expr = transform(partializeFunction(stmt->getExpr()->getType()->getFunc())); } - unify(ctx->getRealizationBase()->returnType, stmt->expr->type); + if (!ctx->getBase()->returnType->isStaticType() && + stmt->getExpr()->getType()->getStatic()) + stmt->getExpr()->setType(stmt->getExpr() + ->getType() + ->getStatic() + ->getNonStaticType() + ->shared_from_this()); + unify(ctx->getBase()->returnType.get(), stmt->getExpr()->getType()); } // If we are not within conditional block, ignore later statements in this function. @@ -58,198 +100,551 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { if (!ctx->blockLevel) ctx->returnEarly = true; - if (!stmt->expr || stmt->expr->isDone()) + if (!stmt->getExpr() || stmt->getExpr()->isDone()) stmt->setDone(); } /// Typecheck yield statements. Empty yields assume `NoneType`. void TypecheckVisitor::visit(YieldStmt *stmt) { - stmt->expr = transform(stmt->expr ? stmt->expr : N(N("NoneType"))); + if (!ctx->inFunction()) + E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); - auto t = ctx->instantiateGeneric(ctx->getType("Generator"), {stmt->expr->type}); - unify(ctx->getRealizationBase()->returnType, t); + stmt->expr = + transform(stmt->getExpr() ? stmt->getExpr() : N(N("NoneType"))); + unify(ctx->getBase()->returnType.get(), + instantiateType(getStdLibType("Generator"), {stmt->getExpr()->getType()})); - if (stmt->expr->isDone()) + if (stmt->getExpr()->isDone()) stmt->setDone(); } +/// Transform `yield from` statements. +/// @example +/// `yield from a` -> `for var in a: yield var` +void TypecheckVisitor::visit(YieldFromStmt *stmt) { + auto var = getTemporaryVar("yield"); + resultStmt = transform( + N(N(var), stmt->getExpr(), N(N(var)))); +} + +/// Process `global` statements. Remove them upon completion. +void TypecheckVisitor::visit(GlobalStmt *stmt) { resultStmt = N(); } + /// Parse a function stub and create a corresponding generic function type. /// Also realize built-ins and extern C functions. void TypecheckVisitor::visit(FunctionStmt *stmt) { - // Function should be constructed only once - stmt->setDone(); + if (stmt->hasAttribute(Attr::Python)) { + // Handle Python block + resultStmt = + transformPythonDefinition(stmt->getName(), stmt->items, stmt->getReturn(), + stmt->getSuite()->firstInBlock()); + return; + } + auto origStmt = clean_clone(stmt); + + // Parse attributes + std::vector attributes; + for (auto i = stmt->decorators.size(); i-- > 0;) { + if (!stmt->decorators[i]) + continue; + auto [isAttr, attrName] = getDecorator(stmt->decorators[i]); + if (!attrName.empty()) { + if (isAttr) { + attributes.push_back(attrName); + stmt->setAttribute(attrName); + stmt->decorators[i] = nullptr; // remove it from further consideration + } + } + } + + bool isClassMember = ctx->inClass(); + if (stmt->hasAttribute(Attr::ForceRealize) && (!ctx->isGlobal() || isClassMember)) + E(Error::EXPECTED_TOPLEVEL, getSrcInfo(), "builtin function"); + + // All overloads share the same canonical name except for the number at the + // end (e.g., `foo.1:0`, `foo.1:1` etc.) + std::string rootName; + if (isClassMember) { + // Case 1: method overload + if (auto n = in(getClass(ctx->getBase()->name)->methods, stmt->getName())) + rootName = *n; + } else if (stmt->hasAttribute(Attr::Overload)) { + // Case 2: function overload + if (auto c = ctx->find(stmt->getName(), getTime())) { + if (c->isFunc() && c->getModule() == ctx->getModule() && + c->getBaseName() == ctx->getBaseName()) { + rootName = c->canonicalName; + } + } + } + if (rootName.empty()) + rootName = ctx->generateCanonicalName(stmt->getName(), true, isClassMember); + // Append overload number to the name + auto canonicalName = rootName; + if (!in(ctx->cache->overloads, rootName)) + ctx->cache->overloads.insert({rootName, {}}); + canonicalName += format(":{}", getOverloads(rootName).size()); + ctx->cache->reverseIdentifierLookup[canonicalName] = stmt->getName(); + + if (isClassMember) { + // Set the enclosing class name + stmt->setAttribute(Attr::ParentClass, ctx->getBase()->name); + // Add the method to the class' method list + getClass(ctx->getBase()->name)->methods[stmt->getName()] = rootName; + } + + // Handle captures. Add additional argument to the function for every capture. + // Make sure to account for **kwargs if present + std::map captures; + if (auto b = stmt->getAttribute(Attr::Bindings)) + for (auto &[c, t] : b->captures) { + if (auto v = ctx->find(c, getTime())) { + if (t != BindingsAttribute::CaptureType::Global && !v->isGlobal()) { + bool parentClassGeneric = + ctx->bases.back().isType() && ctx->bases.back().name == v->getBaseName(); + if (v->isGeneric() && parentClassGeneric) { + stmt->setAttribute(Attr::Method); + } + if (!v->isGeneric() || (v->isStatic() && !parentClassGeneric)) { + captures[c] = v; + } + } + } + } + std::vector partialArgs; + if (!captures.empty()) { + std::vector itemKeys; + itemKeys.reserve(captures.size()); + for (const auto &[key, _] : captures) + itemKeys.emplace_back(key); + + // Handle partial arguments (and static special cases) + Param kw; + if (!stmt->empty() && startswith(stmt->back().name, "**")) { + kw = stmt->back(); + stmt->items.pop_back(); + } + std::array op{"", "int", "str", "bool"}; + for (auto &[c, v] : captures) { + if (v->isType()) { + stmt->items.emplace_back(c, N(TYPE_TYPE)); + } else if (auto si = v->isStatic()) { + stmt->items.emplace_back(c, + N(N("Static"), N(op[si]))); + } else { + stmt->items.emplace_back(c); + } + if (v->isFunc()) { + partialArgs.emplace_back(c, + N(N(v->canonicalName), + N(EllipsisExpr::PARTIAL))); + } else { + partialArgs.emplace_back(c, N(v->canonicalName)); + } + } + if (!kw.name.empty()) + stmt->items.emplace_back(kw); + partialArgs.emplace_back("", N(EllipsisExpr::PARTIAL)); + } + + std::vector args; + Stmt *suite = nullptr; + Expr *ret = nullptr; + std::vector explicits; + std::shared_ptr baseType = nullptr; + { + // Set up the base + TypeContext::BaseGuard br(ctx.get(), canonicalName); + ctx->getBase()->func = stmt; + + // Parse arguments and add them to the context + for (auto &a : *stmt) { + auto [stars, varName] = a.getNameWithStars(); + auto name = ctx->generateCanonicalName(varName); + + // Mark as method if the first argument is self + if (isClassMember && stmt->hasAttribute(Attr::HasSelf) && a.getName() == "self") + stmt->setAttribute(Attr::Method); + + // Handle default values + auto defaultValue = a.getDefault(); + if (a.getType() && defaultValue && cast(defaultValue)) { + // Special case: `arg: Callable = None` -> `arg: Callable = NoneType()` + if (match(a.getType(), M(M(TYPE_CALLABLE), M_))) + defaultValue = N(N("NoneType")); + // Special case: `arg: type = None` -> `arg: type = NoneType` + if (match(a.getType(), M(MOr(TYPE_TYPE, TYPE_TYPEVAR)))) + defaultValue = N("NoneType"); + } + /// TODO: Python-style defaults + args.emplace_back(std::string(stars, '*') + name, a.getType(), defaultValue, + a.status); + + // Add generics to the context + if (!a.isValue()) { + // Generic and static types + auto generic = instantiateUnbound(); + auto typId = generic->getLink()->id; + generic->genericName = varName; + auto defType = transform(clone(a.getDefault())); + if (auto st = getStaticGeneric(a.getType())) { + auto val = ctx->addVar(varName, name, generic); + val->generic = true; + generic->isStatic = st; + if (defType) + generic->defaultType = extractType(defType)->shared_from_this(); + } else { + if (match(a.getType(), M(M(TYPE_TYPEVAR), M_))) { + // Parse TraitVar + auto l = + transformType(cast(a.getType())->front())->getType(); + if (l->getLink() && l->getLink()->trait) + generic->getLink()->trait = l->getLink()->trait; + else + generic->getLink()->trait = + std::make_shared(l->shared_from_this()); + } + auto val = ctx->addType(varName, name, generic); + val->generic = true; + if (defType) + generic->defaultType = extractType(defType)->shared_from_this(); + } + auto g = generic->generalize(ctx->typecheckLevel); + explicits.emplace_back(name, varName, g, typId, g->isStaticType()); + } + } + + // Prepare list of all generic types + ClassType *parentClass = nullptr; + if (isClassMember && stmt->hasAttribute(Attr::Method)) { + // Get class generics (e.g., T for `class Cls[T]: def foo:`) + auto aa = stmt->getAttribute(Attr::ParentClass); + parentClass = extractClassType(aa->value); + } + // Add function generics + std::vector generics; + generics.reserve(explicits.size()); + for (const auto &i : explicits) + generics.emplace_back(extractType(i.name)->shared_from_this()); - auto funcTyp = makeFunctionType(stmt); - // If this is a class method, update the method lookup table - bool isClassMember = !stmt->attributes.parentClass.empty(); + // Handle function arguments + // Base type: `Function[[args,...], ret]` + baseType = getFuncTypeBase(stmt->size() - explicits.size()); + ctx->typecheckLevel++; + + // Parse arguments to the context. Needs to be done after adding generics + // to support cases like `foo(a: T, T: type)` + for (auto &a : args) { + a.type = transformType(a.getType(), false); + a.defaultValue = transform(a.getDefault(), true); + } + + // Unify base type generics with argument types. Add non-generic arguments to the + // context. Delayed to prevent cases like `def foo(a, b=a)` + auto argType = extractClassGeneric(baseType.get())->getClass(); + for (int ai = 0, aj = 0; ai < stmt->size(); ai++) { + if (!(*stmt)[ai].isValue()) + continue; + auto [_, canName] = (*stmt)[ai].getNameWithStars(); + if (!(*stmt)[ai].getType()) { + if (parentClass && ai == 0 && (*stmt)[ai].getName() == "self") { + // Special case: self in methods + unify(extractClassGeneric(argType, aj), parentClass); + } else { + generics.push_back(extractClassGeneric(argType, aj)->shared_from_this()); + } + } else if (startswith((*stmt)[ai].getName(), "*")) { + // Special case: `*args: type` and `**kwargs: type`. Do not add this type to the + // signature (as the real type is `Tuple[type, ...]`); it will be used during + // call typechecking + generics.push_back(extractClassGeneric(argType, aj)->shared_from_this()); + } else { + unify(extractClassGeneric(argType, aj), + extractType(transformType((*stmt)[ai].getType()))); + } + aj++; + } + + // Parse the return type + ret = transformType(stmt->getReturn(), false); + auto retType = extractClassGeneric(baseType.get(), 1); + if (ret) { + unify(retType, extractType(ret)); + if (isId(ret, "Union")) + extractClassGeneric(retType)->getUnbound()->kind = LinkType::Generic; + } else { + generics.push_back(unify(retType, instantiateUnbound())->shared_from_this()); + } + ctx->typecheckLevel--; + + // Generalize generics and remove them from the context + for (const auto &g : generics) { + for (auto &u : g->getUnbounds()) + if (u->getUnbound()) { + u->getUnbound()->kind = LinkType::Generic; + } + } + + // Parse function body + if (!stmt->hasAttribute(Attr::Internal) && !stmt->hasAttribute(Attr::C)) { + if (stmt->hasAttribute(Attr::LLVM)) { + suite = transformLLVMDefinition(stmt->getSuite()->firstInBlock()); + } else if (stmt->hasAttribute(Attr::C)) { + // Do nothing + } else { + suite = clone(stmt->getSuite()); + } + } + } + stmt->setAttribute(Attr::Module, ctx->moduleName.path); + + // Make function AST and cache it for later realization + auto f = N(canonicalName, ret, args, suite); + f->cloneAttributesFrom(stmt); + auto &fn = ctx->cache->functions[canonicalName] = + Cache::Function{ctx->getModulePath(), + rootName, + f, + nullptr, + origStmt, + ctx->getModule().empty() && ctx->isGlobal()}; + f->setDone(); + auto aa = stmt->getAttribute(Attr::ParentClass); + auto parentClass = aa ? extractClassType(aa->value) : nullptr; + + // Construct the type + auto funcTyp = + std::make_shared(baseType.get(), fn.ast, 0, explicits); + funcTyp->setSrcInfo(getSrcInfo()); + if (isClassMember && stmt->hasAttribute(Attr::Method)) { + funcTyp->funcParent = parentClass->shared_from_this(); + } + funcTyp = std::static_pointer_cast( + funcTyp->generalize(ctx->typecheckLevel)); + fn.type = funcTyp; + + auto &overloads = ctx->cache->overloads[rootName]; + if (rootName == "Tuple.__new__") { + overloads.insert( + std::upper_bound(overloads.begin(), overloads.end(), canonicalName, + [&](const auto &a, const auto &b) { + return getFunction(a)->getType()->funcGenerics.size() < + getFunction(b)->getType()->funcGenerics.size(); + }), + canonicalName); + } else { + overloads.push_back(canonicalName); + } + + auto val = ctx->addFunc(stmt->name, rootName, funcTyp); + // val->time = getTime(); + ctx->addFunc(canonicalName, canonicalName, funcTyp); + if (stmt->hasAttribute(Attr::Overload) || isClassMember) { + ctx->remove(stmt->name); // first overload will handle it! + } + + // Special method handling if (isClassMember) { - auto m = - ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(), - ctx->cache->rev(stmt->name)); + auto m = getClassMethod(parentClass, getUnmangledName(canonicalName)); bool found = false; - for (auto &i : ctx->cache->overloads[m]) - if (i.name == stmt->name) { - ctx->cache->functions[i.name].type = funcTyp; + for (auto &i : getOverloads(m)) + if (i == canonicalName) { + getFunction(i)->type = funcTyp; found = true; break; } - seqassert(found, "cannot find matching class method for {}", stmt->name); + seqassert(found, "cannot find matching class method for {}", canonicalName); + } else { + // Hack so that we can later use same helpers for class overloads + getClass(VAR_CLASS_TOPLEVEL)->methods[stmt->getName()] = rootName; } - // Update the visited table - // Functions should always be visible, so add them to the toplevel - ctx->addToplevel(stmt->name, - std::make_shared(TypecheckItem::Func, funcTyp)); - ctx->cache->functions[stmt->name].type = funcTyp; - // Ensure that functions with @C, @force_realize, and @export attributes can be // realized - if (stmt->attributes.has(Attr::ForceRealize) || stmt->attributes.has(Attr::Export) || - (stmt->attributes.has(Attr::C) && !stmt->attributes.has(Attr::CVarArg))) { + if (stmt->hasAttribute(Attr::ForceRealize) || stmt->hasAttribute(Attr::Export) || + (stmt->hasAttribute(Attr::C) && !stmt->hasAttribute(Attr::CVarArg))) { if (!funcTyp->canRealize()) E(Error::FN_REALIZE_BUILTIN, stmt); } - // Debug information - LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp); -} + // Expression to be used if function binding is modified by captures or decorators + Expr *finalExpr = nullptr; + Expr *selfAssign = nullptr; + // If there are captures, replace `fn` with `fn(cap1=cap1, cap2=cap2, ...)` + if (!captures.empty()) { + if (isClassMember) + E(Error::ID_CANNOT_CAPTURE, getSrcInfo(), captures.begin()->first); -types::FuncTypePtr TypecheckVisitor::makeFunctionType(FunctionStmt *stmt) { - // Handle generics - bool isClassMember = !stmt->attributes.parentClass.empty(); - auto explicits = std::vector(); - for (const auto &a : stmt->args) { - if (a.status == Param::Generic) { - // Generic and static types - auto generic = ctx->getUnbound(); - generic->isStatic = getStaticGeneric(a.type.get()); - auto typId = generic->getLink()->id; - generic->genericName = ctx->cache->rev(a.name); - if (a.defaultValue) { - auto defType = transformType(clone(a.defaultValue)); - generic->defaultType = defType->type; - } - ctx->add(TypecheckItem::Type, a.name, generic); - explicits.emplace_back(a.name, ctx->cache->rev(a.name), - generic->generalize(ctx->typecheckLevel), typId); - } - } - - // Prepare list of all generic types - std::vector generics; - ClassTypePtr parentClass = nullptr; - if (isClassMember && stmt->attributes.has(Attr::Method)) { - // Get class generics (e.g., T for `class Cls[T]: def foo:`) - auto parentClassAST = ctx->cache->classes[stmt->attributes.parentClass].ast.get(); - parentClass = ctx->forceFind(stmt->attributes.parentClass)->type->getClass(); - parentClass = - parentClass->instantiate(ctx->typecheckLevel - 1, nullptr, nullptr)->getClass(); - seqassert(parentClass, "parent class not set"); - for (int i = 0, j = 0, k = 0; i < parentClassAST->args.size(); i++) { - if (parentClassAST->args[i].status != Param::Normal) { - generics.push_back(parentClassAST->args[i].status == Param::Generic - ? parentClass->generics[j++].type - : parentClass->hiddenGenerics[k++].type); - ctx->add(TypecheckItem::Type, parentClassAST->args[i].name, generics.back()); - } + finalExpr = N(N(canonicalName), partialArgs); + // Add updated self reference in case function is recursive! + auto pa = partialArgs; + for (auto &a : pa) { + if (!a.getName().empty()) + a.value = N(a.getName()); + else + a.value = clone(a.getExpr()); } + // todo)) right now this adds a capture hook for recursive calls + selfAssign = N(N(stmt->getName()), pa); } - // Add function generics - for (const auto &i : explicits) - generics.push_back(ctx->find(i.name)->type); - // Handle function arguments - // Base type: `Function[[args,...], ret]` - auto baseType = getFuncTypeBase(stmt->args.size() - explicits.size()); - ctx->typecheckLevel++; - if (stmt->ret) { - unify(baseType->generics[1].type, transformType(stmt->ret)->getType()); - if (stmt->ret->isId("Union")) { - baseType->generics[1].type->getUnion()->generics[0].type->getUnbound()->kind = - LinkType::Generic; + // Parse remaining decorators + for (auto i = stmt->decorators.size(); i-- > 0;) { + if (stmt->decorators[i]) { + if (isClassMember) + E(Error::FN_NO_DECORATORS, stmt->decorators[i]); + // Replace each decorator with `decorator(finalExpr)` in the reverse order + finalExpr = N(stmt->decorators[i], + finalExpr ? finalExpr : N(canonicalName)); + // selfAssign = N(clone(stmt->decorators[i]), + // selfAssign ? selfAssign : N(canonicalName)); } + } + + if (selfAssign) + f->suite = + N(N(N(stmt->getName()), selfAssign), suite); + if (finalExpr) { + resultStmt = N( + f, transform(N(N(stmt->getName()), finalExpr))); } else { - generics.push_back(unify(baseType->generics[1].type, ctx->getUnbound())); - } - // Unify base type generics with argument types - auto argType = baseType->generics[0].type->getRecord(); - for (int ai = 0, aj = 0; ai < stmt->args.size(); ai++) { - if (stmt->args[ai].status == Param::Normal && !stmt->args[ai].type) { - if (parentClass && ai == 0 && ctx->cache->rev(stmt->args[ai].name) == "self") { - // Special case: self in methods - unify(argType->args[aj], parentClass); + resultStmt = f; + } +} + +/// Transform Python code blocks. +/// @example +/// ```@python +/// def foo(x: int, y) -> int: +/// [code] +/// ``` -> ``` +/// pyobj._exec("def foo(x, y): [code]") +/// from python import __main__.foo(int, _) -> int +/// ``` +Stmt *TypecheckVisitor::transformPythonDefinition(const std::string &name, + const std::vector &args, + Expr *ret, Stmt *codeStmt) { + seqassert(codeStmt && cast(codeStmt) && + cast(cast(codeStmt)->getExpr()), + "invalid Python definition"); + + auto code = cast(cast(codeStmt)->getExpr())->getValue(); + std::vector pyargs; + pyargs.reserve(args.size()); + for (const auto &a : args) + pyargs.emplace_back(a.getName()); + code = format("def {}({}):\n{}\n", name, join(pyargs, ", "), code); + return transform(N( + N( + N(N(N("pyobj"), "_exec"), N(code))), + N(N("python"), N(N("__main__"), name), + clone(args), ret ? clone(ret) : N("pyobj")))); +} + +/// Transform LLVM functions. +/// @example +/// ```@llvm +/// def foo(x: int) -> float: +/// [code] +/// ``` -> ``` +/// def foo(x: int) -> float: +/// StringExpr("[code]") +/// SuiteStmt(referenced_types) +/// ``` +/// As LLVM code can reference types and static expressions in `{=expr}` blocks, +/// all block expression will be stored in the `referenced_types` suite. +/// "[code]" is transformed accordingly: each `{=expr}` block will +/// be replaced with `{}` so that @c fmt::format can fill the gaps. +/// Note that any brace (`{` or `}`) that is not part of a block is +/// escaped (e.g. `{` -> `{{` and `}` -> `}}`) so that @c fmt::format can process them. +Stmt *TypecheckVisitor::transformLLVMDefinition(Stmt *codeStmt) { + StringExpr *codeExpr; + auto m = match(codeStmt, M(MVar(codeExpr))); + seqassert(m, "invalid LLVM definition"); + auto code = codeExpr->getValue(); + + std::vector items; + std::string finalCode; + items.push_back(nullptr); + + // Parse LLVM code and look for expression blocks that start with `{=` + int braceCount = 0, braceStart = 0; + for (int i = 0; i < code.size(); i++) { + if (i < code.size() - 1 && code[i] == '{' && code[i + 1] == '=') { + if (braceStart < i) + finalCode += escapeFStringBraces(code, braceStart, i - braceStart) + '{'; + if (!braceCount) { + braceStart = i + 2; + braceCount++; } else { - unify(argType->args[aj], ctx->getUnbound()); + E(Error::FN_BAD_LLVM, getSrcInfo()); } - generics.push_back(argType->args[aj++]); - } else if (stmt->args[ai].status == Param::Normal && - startswith(stmt->args[ai].name, "*")) { - // Special case: `*args: type` and `**kwargs: type`. Do not add this type to the - // signature (as the real type is `Tuple[type, ...]`); it will be used during call - // typechecking - unify(argType->args[aj], ctx->getUnbound()); - generics.push_back(argType->args[aj++]); - } else if (stmt->args[ai].status == Param::Normal) { - unify(argType->args[aj], transformType(stmt->args[ai].type)->getType()); - generics.push_back(argType->args[aj++]); + } else if (braceCount && code[i] == '}') { + braceCount--; + std::string exprCode = code.substr(braceStart, i - braceStart); + auto offset = getSrcInfo(); + offset.col += i; + auto exprOrErr = parseExpr(ctx->cache, exprCode, offset); + if (!exprOrErr) + throw exc::ParserException(exprOrErr.takeError()); + auto expr = transform(exprOrErr->first, true); + items.push_back(N(expr)); + braceStart = i + 1; + finalCode += '}'; } } - ctx->typecheckLevel--; - - // Generalize generics and remove them from the context - for (const auto &g : generics) { - for (auto &u : g->getUnbounds()) - if (u->getUnbound()) - u->getUnbound()->kind = LinkType::Generic; - } - - // Construct the type - auto funcTyp = std::make_shared( - baseType, ctx->cache->functions[stmt->name].ast.get(), explicits); + if (braceCount) + E(Error::FN_BAD_LLVM, getSrcInfo()); + if (braceStart != code.size()) + finalCode += escapeFStringBraces(code, braceStart, int(code.size()) - braceStart); + items[0] = N(N(finalCode)); + return N(items); +} - funcTyp->setSrcInfo(getSrcInfo()); - if (isClassMember && stmt->attributes.has(Attr::Method)) { - funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type; +/// Fetch a decorator canonical name. The first pair member indicates if a decorator is +/// actually an attribute (a function with `@__attribute__`). +std::pair TypecheckVisitor::getDecorator(Expr *e) { + auto dt = transform(clone(e)); + auto id = cast(cast(dt) ? cast(dt)->getExpr() : dt); + if (id) { + auto ci = ctx->find(id->getValue(), getTime()); + if (ci && ci->isFunc()) { + auto fn = ci->getName(); + auto f = getFunction(fn); + if (!f) { + if (auto o = in(ctx->cache->overloads, fn)) { + if (o->size() == 1) + f = getFunction(o->front()); + } + } + if (f) + return {f->ast->hasAttribute(Attr::Attribute), ci->getName()}; + } } - funcTyp = - std::static_pointer_cast(funcTyp->generalize(ctx->typecheckLevel)); - return funcTyp; + return {false, ""}; } /// Make an empty partial call `fn(...)` for a given function. -ExprPtr TypecheckVisitor::partializeFunction(const types::FuncTypePtr &fn) { +Expr *TypecheckVisitor::partializeFunction(types::FuncType *fn) { // Create function mask - std::vector mask(fn->ast->args.size(), 0); - for (int i = 0, j = 0; i < fn->ast->args.size(); i++) - if (fn->ast->args[i].status == Param::Generic) { - if (!fn->funcGenerics[j].type->getUnbound()) + std::vector mask(fn->ast->size(), 0); + for (int i = 0, j = 0; i < fn->ast->size(); i++) + if ((*fn->ast)[i].isGeneric()) { + if (!extractFuncGeneric(fn, j)->getUnbound()) mask[i] = 1; j++; } // Generate partial class - auto partialTypeName = generatePartialStub(mask, fn.get()); - std::string var = ctx->cache->getTemporaryVar("partial"); - // Generate kwtuple for potential **kwargs - auto kwName = generateTuple(0, TYPE_KWTUPLE, {}); - // `partial = Partial.MASK((), KwTuple())` - // (`()` for *args and `KwTuple()` for **kwargs) - ExprPtr call = - N(N(N(var), - N(N(partialTypeName), N(), - N(N(kwName)))), - N(var)); - call->setAttr(ExprAttr::Partial); - transform(call); - seqassert(call->type->getPartial(), "expected partial type"); - return call; + return generatePartialCall(mask, fn); } /// Generate and return `Function[Tuple[args...], ret]` type -std::shared_ptr TypecheckVisitor::getFuncTypeBase(size_t nargs) { - auto baseType = ctx->instantiate(ctx->forceFind("Function")->type)->getRecord(); - unify(baseType->generics[0].type, ctx->instantiateTuple(nargs)->getRecord()); - return baseType; +std::shared_ptr TypecheckVisitor::getFuncTypeBase(size_t nargs) { + auto baseType = instantiateType(getStdLibType("Function")); + unify(extractClassGeneric(baseType->getClass()), + instantiateType(generateTuple(nargs, false))); + return std::static_pointer_cast(baseType); } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/import.cpp b/codon/parser/visitors/typecheck/import.cpp new file mode 100644 index 00000000..7ac745fe --- /dev/null +++ b/codon/parser/visitors/typecheck/import.cpp @@ -0,0 +1,403 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#include +#include +#include +#include + +#include "codon/parser/ast.h" +#include "codon/parser/common.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" +#include "codon/parser/visitors/scoping/scoping.h" +#include "codon/parser/visitors/typecheck/typecheck.h" + +using fmt::format; +using namespace codon::error; +using namespace codon::matcher; + +namespace codon::ast { + +/// Import and parse a new module into its own context. +/// Also handle special imports ( see @c transformSpecialImport ). +/// To simulate Python's dynamic import logic and import stuff only once, +/// each import statement is guarded as follows: +/// if not _import_N_done: +/// _import_N() +/// _import_N_done = True +/// See @c transformNewImport and below for more details. +void TypecheckVisitor::visit(ImportStmt *stmt) { + seqassert(!ctx->inClass(), "imports within a class"); + if ((resultStmt = transformSpecialImport(stmt))) + return; + + // Fetch the import + auto components = getImportPath(stmt->getFrom(), stmt->getDots()); + auto path = combine2(components, "/"); + auto file = getImportFile(getArgv(), path, ctx->getFilename(), false, + getRootModulePath(), getPluginImportPaths()); + if (!file) { + std::string s(stmt->getDots(), '.'); + for (auto &c : components) { + if (c == "..") { + continue; + } else if (!s.empty() && s.back() != '.') { + s += "." + c; + } else { + s += c; + } + } + bool allDot = true; + for (auto cp : s) + if (cp != '.') { + allDot = false; + break; + } + if (allDot && match(stmt->getWhat(), M())) + s = cast(stmt->getWhat())->getValue(); + E(Error::IMPORT_NO_MODULE, stmt, s); + } + + // If the file has not been seen before, load it into cache + bool handled = true; + if (!in(ctx->cache->imports, file->path)) { + resultStmt = transformNewImport(*file); + if (!resultStmt) + handled = false; // we need an import + } + + const auto &import = getImport(file->path); + std::string importVar = import->importVar; + if (!import->loadedAtToplevel) + handled = false; + + // Construct `if _import_done.__invert__(): (_import(); _import_done = True)`. + // Do not do this during the standard library loading (we assume that standard library + // imports are "clean" and do not need guards). Note that the importVar is empty if + // the import has been loaded during the standard library loading. + if (!handled) { + resultStmt = + N(N(N(fmt::format("{}_call.0", importVar)))); + LOG_TYPECHECK("[import] loading {}", importVar); + } + + // Import requested identifiers from the import's scope to the current scope + if (!stmt->getWhat()) { + // Case: import foo + auto name = stmt->as.empty() ? path : stmt->getAs(); + auto e = ctx->forceFind(importVar); + ctx->add(name, e); + } else if (cast(stmt->getWhat()) && + cast(stmt->getWhat())->getValue() == "*") { + // Case: from foo import * + seqassert(stmt->getAs().empty(), "renamed star-import"); + // Just copy all symbols from import's context here. + for (auto &[i, ival] : *(import->ctx)) { + if ((!startswith(i, "_") || (ctx->isStdlibLoading && startswith(i, "__")))) { + // Ignore all identifiers that start with `_` but not those that start with + // `__` while the standard library is being loaded + auto c = ival.front(); + if (c->isConditional() && i.find('.') == std::string::npos) + c = import->ctx->find(i); + // Imports should ignore noShadow property + ctx->add(i, c); + } + } + } else { + // Case 3: from foo import bar + auto i = cast(stmt->getWhat()); + seqassert(i, "not a valid import what expression"); + auto c = import->ctx->find(i->getValue()); + // Make sure that we are importing an existing global symbol + if (!c) + E(Error::IMPORT_NO_NAME, i, i->getValue(), file->module); + if (c->isConditional()) + c = import->ctx->find(i->getValue()); + // Imports should ignore noShadow property + ctx->add(stmt->getAs().empty() ? i->getValue() : stmt->getAs(), c); + } + resultStmt = transform(!resultStmt ? N() : resultStmt); // erase it +} + +/// Transform special `from C` and `from python` imports. +/// See @c transformCImport, @c transformCDLLImport and @c transformPythonImport +Stmt *TypecheckVisitor::transformSpecialImport(ImportStmt *stmt) { + if (auto fi = cast(stmt->getFrom())) { + if (fi->getValue() == "C") { + auto wi = cast(stmt->getWhat()); + if (wi && !stmt->isCVar()) { + // C function imports + return transformCImport(wi->getValue(), stmt->getArgs(), stmt->getReturnType(), + stmt->getAs()); + } else if (wi) { + // C variable imports + return transformCVarImport(wi->getValue(), stmt->getReturnType(), + stmt->getAs()); + } else if (auto de = cast(stmt->getWhat())) { + // dylib C imports + return transformCDLLImport(de->getExpr(), de->getMember(), stmt->getArgs(), + stmt->getReturnType(), stmt->getAs(), + !stmt->isCVar()); + } + } else if (fi->getValue() == "python" && stmt->getWhat()) { + // Python imports + return transformPythonImport(stmt->getWhat(), stmt->getArgs(), + stmt->getReturnType(), stmt->getAs()); + } + } + return nullptr; +} + +/// Transform Dot(Dot(a, b), c...) into "{a, b, c, ...}". +/// Useful for getting import paths. +std::vector TypecheckVisitor::getImportPath(Expr *from, size_t dots) { + std::vector components; // Path components + if (from) { + for (; cast(from); from = cast(from)->getExpr()) + components.push_back(cast(from)->getMember()); + seqassert(cast(from), "invalid import statement"); + components.push_back(cast(from)->getValue()); + } + + // Handle dots (i.e., `..` in `from ..m import x`) + for (size_t i = 1; i < dots; i++) + components.emplace_back(".."); + std::reverse(components.begin(), components.end()); + return components; +} + +/// Transform a C function import. +/// @example +/// `from C import foo(int) -> float as f` -> +/// ```@.c +/// def foo(a1: int) -> float: +/// pass +/// f = foo # if altName is provided``` +/// No return type implies void return type. *args is treated as C VAR_ARGS. +Stmt *TypecheckVisitor::transformCImport(const std::string &name, + const std::vector &args, Expr *ret, + const std::string &altName) { + std::vector fnArgs; + bool hasVarArgs = false; + for (size_t ai = 0; ai < args.size(); ai++) { + seqassert(args[ai].getName().empty(), "unexpected argument name"); + seqassert(!args[ai].getDefault(), "unexpected default argument"); + seqassert(args[ai].getType(), "missing type"); + if (cast(args[ai].getType()) && ai + 1 == args.size()) { + // C VAR_ARGS support + hasVarArgs = true; + fnArgs.emplace_back("*args", nullptr, nullptr); + } else { + fnArgs.emplace_back(args[ai].getName().empty() ? format("a{}", ai) + : args[ai].getName(), + clone(args[ai].getType()), nullptr); + } + } + ctx->generateCanonicalName(name); // avoid canonicalName == name + Stmt *f = + N(name, ret ? clone(ret) : N("NoneType"), fnArgs, nullptr); + f->setAttribute(Attr::C); + if (hasVarArgs) + f->setAttribute(Attr::CVarArg); + f = transform(f); // Already in the preamble + if (!altName.empty()) { + auto v = ctx->find(altName); + auto val = ctx->forceFind(name); + ctx->add(altName, val); + ctx->remove(name); + } + return f; +} + +/// Transform a C variable import. +/// @example +/// `from C import foo: int as f` -> +/// ```f: int = "foo"``` +Stmt *TypecheckVisitor::transformCVarImport(const std::string &name, Expr *type, + const std::string &altName) { + auto canonical = ctx->generateCanonicalName(name); + auto typ = transformType(clone(type)); + auto val = ctx->addVar( + altName.empty() ? name : altName, canonical, + std::make_shared(extractClassType(typ)->shared_from_this())); + val->time = getTime(); + auto s = N(N(canonical), nullptr, typ); + s->lhs->setAttribute(Attr::ExprExternVar); + s->lhs->setType(val->type); + s->lhs->setDone(); + s->setDone(); + return s; +} + +/// Transform a dynamic C import. +/// @example +/// `from C import lib.foo(int) -> float as f` -> +/// `f = _dlsym(lib, "foo", Fn=Function[[int], float]); f` +/// No return type implies void return type. +Stmt *TypecheckVisitor::transformCDLLImport(Expr *dylib, const std::string &name, + const std::vector &args, Expr *ret, + const std::string &altName, + bool isFunction) { + Expr *type = nullptr; + if (isFunction) { + std::vector fnArgs{N(), ret ? clone(ret) : N("NoneType")}; + for (const auto &a : args) { + seqassert(a.getName().empty(), "unexpected argument name"); + seqassert(!a.getDefault(), "unexpected default argument"); + seqassert(a.getType(), "missing type"); + cast(fnArgs[0])->items.emplace_back(clone(a.getType())); + } + type = N(N("Function"), N(fnArgs)); + } else { + type = clone(ret); + } + + Expr *c = clone(dylib); + return transform(N( + N(altName.empty() ? name : altName), + N(N("_dlsym"), + std::vector{ + CallArg(c), CallArg(N(name)), {"Fn", type}}))); +} + +/// Transform a Python module and function imports. +/// @example +/// `from python import module as f` -> `f = pyobj._import("module")` +/// `from python import lib.foo(int) -> float as f` -> +/// ```def f(a0: int) -> float: +/// f = pyobj._import("lib")._getattr("foo") +/// return float.__from_py__(f(a0))``` +/// If a return type is nullptr, the function just returns f (raw pyobj). +Stmt *TypecheckVisitor::transformPythonImport(Expr *what, + const std::vector &args, Expr *ret, + const std::string &altName) { + // Get a module name (e.g., os.path) + auto components = getImportPath(what); + + if (!ret && args.empty()) { + // Simple import: `from python import foo.bar` -> `bar = pyobj._import("foo.bar")` + return transform( + N(N(altName.empty() ? components.back() : altName), + N(N(N("pyobj"), "_import"), + N(combine2(components, "."))))); + } + + // Python function import: + // `from python import foo.bar(int) -> float` -> + // ```def bar(a1: int) -> float: + // f = pyobj._import("foo")._getattr("bar") + // return float.__from_py__(f(a1))``` + + // f = pyobj._import("foo")._getattr("bar") + auto call = N( + N("f"), + N( + N(N(N(N("pyobj"), "_import"), + N(combine2(components, ".", 0, + int(components.size()) - 1))), + "_getattr"), + N(components.back()))); + // f(a1, ...) + std::vector params; + std::vector callArgs; + for (int i = 0; i < args.size(); i++) { + params.emplace_back(format("a{}", i), clone(args[i].getType()), nullptr); + callArgs.emplace_back(N(format("a{}", i))); + } + // `return ret.__from_py__(f(a1, ...))` + auto retType = (ret && !cast(ret)) ? clone(ret) : N("NoneType"); + auto retExpr = N(N(clone(retType), "__from_py__"), + N(N(N("f"), callArgs), "p")); + auto retStmt = N(retExpr); + // Create a function + return transform(N(altName.empty() ? components.back() : altName, + retType, params, N(call, retStmt))); +} + +/// Import a new file into its own context and wrap its top-level statements into a +/// function to support Python-like runtime import loading. +/// @example +/// ```_import_[I]_done = False +/// def _import_[I](): +/// global [imported global variables]... +/// __name__ = [I] +/// [imported top-level statements]``` +Stmt *TypecheckVisitor::transformNewImport(const ImportFile &file) { + // Use a clean context to parse a new file + auto moduleID = file.module; + std::replace(moduleID.begin(), moduleID.end(), '.', '_'); + auto ictx = std::make_shared(ctx->cache, file.path); + ictx->isStdlibLoading = ctx->isStdlibLoading; + ictx->moduleName = file; + auto import = + ctx->cache->imports.insert({file.path, {file.module, file.path, ictx}}).first; + import->second.loadedAtToplevel = + getImport(ctx->moduleName.path)->loadedAtToplevel && + (ctx->isStdlibLoading || (ctx->isGlobal() && ctx->scope.size() == 1)); + auto importVar = import->second.importVar = + getTemporaryVar(format("import_{}", moduleID)); + LOG_TYPECHECK("[import] initializing {} ({})", importVar, + import->second.loadedAtToplevel); + + // __name__ = [import name] + Stmt *n = nullptr; + if (file.module != "internal.core") { + // str is not defined when loading internal.core; __name__ is not needed anyway + n = N(N("__name__"), N(ictx->moduleName.module)); + ctx->addBlock(); + preamble->push_back(transform( + N(N(importVar), + N(N("Import.__new__"), N(false), + N(file.path), N(file.module)), + N("Import")))); + auto val = ctx->forceFind(importVar); + ctx->popBlock(); + val->scope = {0}; + val->baseName = ""; + val->moduleName = MODULE_MAIN; + getImport(STDLIB_IMPORT)->ctx->addToplevel(importVar, val); + registerGlobal(val->getName()); + } + auto nodeOrErr = parseFile(ctx->cache, file.path); + if (!nodeOrErr) + throw exc::ParserException(nodeOrErr.takeError()); + n = N(n, *nodeOrErr); + auto tv = TypecheckVisitor(ictx, preamble); + if (auto err = ScopingVisitor::apply(ctx->cache, n)) + throw exc::ParserException(std::move(err)); + + if (!ctx->cache->errors.empty()) + throw exc::ParserException(ctx->cache->errors); + // Add comment to the top of import for easier dump inspection + auto comment = N(format("import: {} at {}", file.module, file.path)); + auto suite = N(comment, n); + + if (ctx->isStdlibLoading) { + // When loading the standard library, imports are not wrapped. + // We assume that the standard library has no recursive imports and that all + // statements are executed before the user-provided code. + return tv.transform(suite); + } else { + // Generate import identifier + auto stmts = N(); + auto ret = N(); + ret->setAttribute(Attr::Internal); // do not trigger toplevel ReturnStmt error + stmts->addStmt(N(N(N(importVar), "loaded"), ret)); + stmts->addStmt(N( + N(N("Import._set_loaded"), + N(N("__ptr__"), N(importVar))))); + stmts->addStmt(suite); + + // Wrap all imported top-level statements into a function. + auto fnName = fmt::format("{}_call", importVar); + Stmt *fn = + N(fnName, N("NoneType"), std::vector{}, stmts); + fn = tv.transform(fn); + tv.realize(ictx->forceFind(fnName)->getType()); + preamble->push_back(fn); + // LOG_USER("[import] done importing {}", file.module); + } + return nullptr; +} + +} // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 03aebfec..b306e457 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -7,10 +7,12 @@ #include #include +#include "codon/cir/attribute.h" #include "codon/cir/types/types.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/visitors/scoping/scoping.h" +#include "codon/parser/visitors/translate/translate.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -28,36 +30,36 @@ using namespace types; /// @param a Type (by reference) /// @param b Type /// @return a -TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b) { - if (!a) - return a = b; - seqassert(b, "rhs is nullptr"); - types::Type::Unification undo; - if (a->unify(b.get(), &undo) >= 0) { - return a; - } else { - undo.undo(); +Type *TypecheckVisitor::unify(Type *a, Type *b) { + seqassert(a, "lhs is nullptr"); + if (!((*a) << b)) { + types::Type::Unification undo; + a->unify(b, &undo); + E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString()); + return nullptr; } - a->unify(b.get(), &undo); - E(Error::TYPE_UNIFY, getSrcInfo(), a->prettyString(), b->prettyString()); - return nullptr; + return a; } -/// Infer all types within a StmtPtr. Implements the LTS-DI typechecking. +/// Infer all types within a Stmt *. Implements the LTS-DI typechecking. /// @param isToplevel set if typechecking the program toplevel. -StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { +Stmt *TypecheckVisitor::inferTypes(Stmt *result, bool isToplevel) { if (!result) return nullptr; - for (ctx->getRealizationBase()->iteration = 1;; - ctx->getRealizationBase()->iteration++) { - LOG_TYPECHECK("[iter] {} :: {}", ctx->getRealizationBase()->name, - ctx->getRealizationBase()->iteration); - if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER) { - error(result, "cannot typecheck '{}' in reasonable time", - ctx->getRealizationBase()->name.empty() - ? "toplevel" - : ctx->cache->rev(ctx->getRealizationBase()->name)); + for (ctx->getBase()->iteration = 1;; ctx->getBase()->iteration++) { + LOG_TYPECHECK("[iter] {} :: {}", ctx->getBase()->name, ctx->getBase()->iteration); + if (ctx->getBase()->iteration >= MAX_TYPECHECK_ITER) { + // log("-> {}", result->toString(2)); + ParserErrors errors; + errors.addError({ErrorMessage{fmt::format("cannot typecheck '{}' in reasonable time", + ctx->getBase()->name.empty() + ? "toplevel" + : getUnmangledName(ctx->getBase()->name)), + result->getSrcInfo()}}); + for (auto &error : findTypecheckErrors(result)) + errors.addError(error); + throw exc::ParserException(errors); } // Keep iterating until: @@ -69,20 +71,21 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { ctx->changedNodes = 0; auto returnEarly = ctx->returnEarly; ctx->returnEarly = false; - TypecheckVisitor(ctx).transform(result); + auto tv = TypecheckVisitor(ctx, preamble); + result = tv.transform(result); std::swap(ctx->changedNodes, changedNodes); std::swap(ctx->returnEarly, returnEarly); ctx->typecheckLevel--; - if (ctx->getRealizationBase()->iteration == 1 && isToplevel) { + if (ctx->getBase()->iteration == 1 && isToplevel) { // Realize all @force_realize functions for (auto &f : ctx->cache->functions) { - auto &attr = f.second.ast->attributes; + auto ast = f.second.ast; if (f.second.type && f.second.realizations.empty() && - (attr.has(Attr::ForceRealize) || attr.has(Attr::Export) || - (attr.has(Attr::C) && !attr.has(Attr::CVarArg)))) { + (ast->hasAttribute(Attr::ForceRealize) || ast->hasAttribute(Attr::Export) || + (ast->hasAttribute(Attr::C) && !ast->hasAttribute(Attr::CVarArg)))) { seqassert(f.second.type->canRealize(), "cannot realize {}", f.first); - realize(ctx->instantiate(f.second.type)->getFunc()); + realize(instantiateType(f.second.getType())); seqassert(!f.second.realizations.empty(), "cannot realize {}", f.first); } } @@ -90,12 +93,13 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { if (result->isDone()) { // Special union case: if union cannot be inferred return type is Union[NoneType] - if (auto tr = ctx->getRealizationBase()->returnType) { + if (auto tr = ctx->getBase()->returnType) { if (auto tu = tr->getUnion()) { if (!tu->isSealed()) { if (tu->pendingTypes[0]->getLink() && tu->pendingTypes[0]->getLink()->kind == LinkType::Unbound) { - tu->addType(ctx->forceFind("NoneType")->type); + auto r = tu->addType(getStdLibType("NoneType")); + seqassert(r, "cannot add type to union {}", tu->debugString(2)); tu->seal(); } } @@ -110,26 +114,43 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { // their default values and then run another round to see if anything changed. bool anotherRound = false; // Special case: return type might have default as well (e.g., Union) - if (ctx->getRealizationBase()->returnType) - ctx->getRealizationBase()->pendingDefaults.insert( - ctx->getRealizationBase()->returnType); - for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) { - if (auto tu = unbound->getUnion()) { - // Seal all dynamic unions after the iteration is over - if (!tu->isSealed()) { - tu->seal(); - anotherRound = true; + if (auto t = ctx->getBase()->returnType) { + ctx->getBase()->pendingDefaults[0].insert(t); + } + // First unify "explicit" generics (whose default type is explicit), + // then "implicit" ones (whose default type is compiler generated, + // e.g. compiler-generated variable placeholders with default NoneType) + for (auto &[level, unbounds] : ctx->getBase()->pendingDefaults) { + if (!unbounds.empty()) { + for (const auto &unbound : unbounds) { + if (auto tu = unbound->getUnion()) { + // Seal all dynamic unions after the iteration is over + if (!tu->isSealed()) { + tu->seal(); + anotherRound = true; + } + } else if (auto u = unbound->getLink()) { + types::Type::Unification undo; + if (u->defaultType) { + if (u->defaultType->getClass()) { // type[...] + if (u->unify(extractClassType(u->defaultType.get()), &undo) >= 0) { + anotherRound = true; + } + } else { // generic + if (u->unify(u->defaultType.get(), &undo) >= 0) { + anotherRound = true; + } + } + } + } } - } else if (auto u = unbound->getLink()) { - types::Type::Unification undo; - if (u->defaultType && u->unify(u->defaultType.get(), &undo) >= 0) - anotherRound = true; + unbounds.clear(); + if (anotherRound) + break; } } - ctx->getRealizationBase()->pendingDefaults.clear(); if (anotherRound) continue; - // Nothing helps. Return nullptr. return nullptr; } @@ -141,66 +162,62 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { /// Realize a type and create IR type stub. If type is a function type, also realize the /// underlying function and generate IR function stub. /// @return realized type or nullptr if the type cannot be realized -types::TypePtr TypecheckVisitor::realize(types::TypePtr typ) { +types::Type *TypecheckVisitor::realize(types::Type *typ) { if (!typ || !typ->canRealize()) { return nullptr; } - if (typ->getStatic()) { - // Nothing to realize here - return typ; - } - try { if (auto f = typ->getFunc()) { - if (auto ret = realizeFunc(f.get())) { + // Cache::CTimer t(ctx->cache, f->realizedName()); + if (auto ret = realizeFunc(f)) { // Realize Function[..] type as well - realizeType(ret->getClass().get()); - return unify(ret, typ); // Needed for return type unification + auto t = std::make_shared(ret->getClass()); + realizeType(t.get()); + // Needed for return type unification + unify(f->getRetType(), extractClassGeneric(ret, 1)); + return ret; } } else if (auto c = typ->getClass()) { - auto t = realizeType(c.get()); - if (auto p = typ->getPartial()) { - // Ensure that the partial type is preserved - t = std::make_shared(t->getRecord(), p->func, p->known); - } - if (t) { - return unify(t, typ); - } + auto t = realizeType(c); + return t; } - } catch (exc::ParserException &e) { - if (e.errorCode == Error::MAX_REALIZATION) + } catch (exc::ParserException &exc) { + seqassert(!exc.getErrors().empty(), "empty error trace"); + auto &bt = exc.getErrors().back(); + if (bt.front().getErrorCode() == Error::MAX_REALIZATION) throw; if (auto f = typ->getFunc()) { - if (f->ast->attributes.has(Attr::HiddenFromUser)) { - e.locations.back() = getSrcInfo(); + if (f->ast->hasAttribute(Attr::HiddenFromUser)) { + bt.back().setSrcInfo(getSrcInfo()); } else { std::vector args; - for (size_t i = 0, ai = 0, gi = 0; i < f->ast->args.size(); i++) { - auto an = f->ast->args[i].name; - auto ns = trimStars(an); - args.push_back(fmt::format("{}{}: {}", std::string(ns, '*'), - ctx->cache->rev(an), - f->ast->args[i].status == Param::Generic - ? f->funcGenerics[gi++].type->prettyString() - : f->getArgTypes()[ai++]->prettyString())); + for (size_t i = 0, ai = 0, gi = 0; i < f->ast->size(); i++) { + auto [ns, n] = (*f->ast)[i].getNameWithStars(); + args.push_back(fmt::format( + "{}{}: {}", std::string(ns, '*'), getUnmangledName(n), + (*f->ast)[i].isGeneric() ? extractFuncGeneric(f, gi++)->prettyString() + : extractFuncArgType(f, ai++)->prettyString())); } auto name = f->ast->name; std::string name_args; - if (startswith(name, "._import_")) { - name = name.substr(9); - auto p = name.rfind('_'); - if (p != std::string::npos) - name = name.substr(0, p); - name = ""; + if (startswith(name, "%_import_")) { + for (auto &[_, i] : ctx->cache->imports) + if (i.importVar + "_call.0:0" == name) { + name = i.name; + break; + } + name = fmt::format("", name); } else { - name = ctx->cache->rev(f->ast->name); + name = getUnmangledName(f->ast->name); name_args = fmt::format("({})", fmt::join(args, ", ")); } - e.trackRealize(fmt::format("{}{}", name, name_args), getSrcInfo()); + bt.addMessage(fmt::format("during the realization of {}{}", name, name_args), + getSrcInfo()); } } else { - e.trackRealize(typ->prettyString(), getSrcInfo()); + bt.addMessage(fmt::format("during the realization of {}", typ->prettyString()), + getSrcInfo()); } throw; } @@ -209,140 +226,199 @@ types::TypePtr TypecheckVisitor::realize(types::TypePtr typ) { /// Realize a type and create IR type stub. /// @return realized type or nullptr if the type cannot be realized -types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) { +types::Type *TypecheckVisitor::realizeType(types::ClassType *type) { if (!type || !type->canRealize()) return nullptr; - - if (auto tr = type->getRecord()) - tr->flatten(); - // Check if the type fields are all initialized // (sometimes that's not the case: e.g., `class X: x: List[X]`) - for (auto field : getClassFields(type)) { - if (!field.type) - return nullptr; + + // generalize generics to ensure that they do not get unified later! + if (type->is("unrealized_type")) + type->generics[0].type = extractClassGeneric(type)->generalize(0); + + if (type->is("__NTuple__")) { + auto n = std::max(int64_t(0), getIntLiteral(type)); + auto tt = extractClassGeneric(type, 1)->getClass(); + std::vector generics; + auto t = instantiateType(generateTuple(n * tt->generics.size())); + for (size_t i = 0, j = 0; i < n; i++) + for (const auto &ttg : tt->generics) { + unify(t->generics[j].getType(), ttg.getType()); + generics.push_back(t->generics[j]); + j++; + } + type->name = TYPE_TUPLE; + type->niceName = t->niceName; + type->generics = generics; + type->_rn = ""; } // Check if the type was already realized - if (auto r = - in(ctx->cache->classes[type->name].realizations, type->realizedTypeName())) { + auto rn = type->ClassType::realizedName(); + auto cls = getClass(type); + if (auto r = in(cls->realizations, rn)) { return (*r)->type->getClass(); } auto realized = type->getClass(); - if (type->getFunc()) { - // Just realize the function stub - realized = std::make_shared(realized, type->getFunc()->args); - } - - // Realize generics - for (auto &e : realized->generics) { - if (!realize(e.type)) + auto fields = getClassFields(realized); + if (!cls->ast) + return nullptr; // not yet done! + auto fTypes = getClassFieldTypes(realized); + for (auto &field : fTypes) { + if (!field) return nullptr; } - LOG_REALIZE("[realize] ty {} -> {}", realized->name, realized->realizedTypeName()); + if (auto s = type->getStatic()) + realized = + s->getNonStaticType()->getClass(); // do not cache static but its root type! + + // Realize generics + if (!type->is("unrealized_type")) + for (auto &e : realized->generics) { + if (!realize(e.getType())) + return nullptr; + if (e.type->getFunc() && !e.type->getFunc()->getRetType()->canRealize()) + return nullptr; + } // Realizations should always be visible, so add them to the toplevel - ctx->addToplevel(realized->realizedTypeName(), - std::make_shared(TypecheckItem::Type, realized)); - auto realization = - ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()] = - std::make_shared(); - realization->type = realized; + rn = type->ClassType::realizedName(); + auto rt = std::static_pointer_cast(realized->generalize(0)); + auto val = std::make_shared(rn, "", ctx->getModule(), rt); + if (!val->type->is(TYPE_TYPE)) + val->type = instantiateTypeVar(realized); + ctx->addAlwaysVisible(val, true); + auto realization = getClass(realized)->realizations[rn] = + std::make_shared(); + realization->type = rt; realization->id = ctx->cache->classRealizationCnt++; - // Realize tuple arguments - if (auto tr = realized->getRecord()) { - for (auto &a : tr->args) - realize(a); - } - // Create LLVM stub - auto lt = makeIRType(realized.get()); + auto lt = makeIRType(realized); // Realize fields std::vector typeArgs; // needed for IR std::vector names; // needed for IR std::map memberInfo; // needed for IR - if (realized->is(TYPE_TUPLE)) - realized->getRecord()->flatten(); - int i = 0; - for (auto &field : getClassFields(realized.get())) { - auto ftyp = ctx->instantiate(field.type, realized); - // HACK: repeated tuples have no generics so this is needed to fix the instantiation - // above - if (realized->is(TYPE_TUPLE)) - unify(ftyp, realized->getRecord()->args[i]); - - if (!realize(ftyp)) - E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), field.name, - ftyp->prettyString()); - realization->fields.emplace_back(field.name, ftyp); - names.emplace_back(field.name); - typeArgs.emplace_back(makeIRType(ftyp->getClass().get())); - memberInfo[field.name] = field.type->getSrcInfo(); - i++; + for (size_t i = 0; i < fTypes.size(); i++) { + if (!realize(fTypes[i].get())) { + // realize(fTypes[i].get()); + E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), fields[i].name, + realized->prettyString()); + } + // LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, fTypes[i]); + realization->fields.emplace_back(fields[i].name, fTypes[i]); + names.emplace_back(fields[i].name); + typeArgs.emplace_back(makeIRType(fTypes[i]->getClass())); + memberInfo[fields[i].name] = fTypes[i]->getSrcInfo(); } // Set IR attributes - if (auto *cls = ir::cast(lt)) - if (!names.empty()) { + if (!names.empty()) { + if (auto *cls = cast(lt)) { cls->getContents()->realize(typeArgs, names); cls->setAttribute(std::make_unique(memberInfo)); cls->getContents()->setAttribute( std::make_unique(memberInfo)); } - - // Fix for partial types - if (auto p = type->getPartial()) { - auto pt = std::make_shared(realized->getRecord(), p->func, p->known); - ctx->addToplevel(pt->realizedName(), - std::make_shared(TypecheckItem::Type, pt)); - ctx->cache->classes[pt->name].realizations[pt->realizedName()] = - ctx->cache->classes[realized->name].realizations[realized->realizedTypeName()]; } - return realized; + return rt.get(); } -types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { - auto &realizations = ctx->cache->functions[type->ast->name].realizations; +types::Type *TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) { + auto module = type->ast->getAttribute(Attr::Module)->value; + auto &realizations = getFunction(type->getFuncName())->realizations; + auto imp = getImport(module); if (auto r = in(realizations, type->realizedName())) { if (!force) { - return (*r)->type; + return (*r)->getType(); } } + auto oldCtx = this->ctx; + this->ctx = imp->ctx; if (ctx->getRealizationDepth() > MAX_REALIZATION_DEPTH) { - E(Error::MAX_REALIZATION, getSrcInfo(), ctx->cache->rev(type->ast->name)); + E(Error::MAX_REALIZATION, getSrcInfo(), getUnmangledName(type->getFuncName())); } - LOG_REALIZE("[realize] fn {} -> {} : base {} ; depth = {}", type->ast->name, - type->realizedName(), ctx->getRealizationStackName(), - ctx->getRealizationDepth()); - getLogger().level++; - ctx->addBlock(); - ctx->typecheckLevel++; + bool isImport = isImportFn(type->getFuncName()); + if (!isImport) { + getLogger().level++; + ctx->addBlock(); + ctx->typecheckLevel++; + ctx->bases.push_back({type->getFuncName(), type->getFunc()->shared_from_this(), + type->getRetType()->shared_from_this()}); + for (size_t t = ctx->bases.size() - 1; t-- > 0;) { + if (startswith(ctx->getBaseName(), ctx->bases[t].name)) { + ctx->getBase()->parent = t; + break; + } + } + // LOG("[realize] F {} -> {} : base {} ; depth = {} ; ctx-base: {}; ret = {}; " + // "parent = {}", + // type->getFuncName(), type->realizedName(), ctx->getRealizationStackName(), + // ctx->getRealizationDepth(), ctx->getBaseName(), + // ctx->getBase()->returnType->debugString(2), + // ctx->bases[ctx->getBase()->parent].name); + } - // Find function parents - ctx->realizationBases.push_back( - {type->ast->name, type->getFunc(), type->getRetType()}); + // Types might change after realization, fix it + for (auto &t : *type) + realizeType(t.getType()->getClass()); // Clone the generic AST that is to be realized - auto ast = generateSpecialAst(type); - addFunctionGenerics(type); + auto ast = clean_clone(type->ast); + if (auto s = generateSpecialAst(type)) + ast->suite = s; + addClassGenerics(type, true); + ctx->getBase()->func = ast; // Internal functions have no AST that can be realized - bool hasAst = ast->suite && !ast->attributes.has(Attr::Internal); + bool hasAst = ast->getSuite() && !ast->hasAttribute(Attr::Internal); // Add function arguments - for (size_t i = 0, j = 0; hasAst && i < ast->args.size(); i++) - if (ast->args[i].status == Param::Normal) { - std::string varName = ast->args[i].name; - trimStars(varName); - ctx->add(TypecheckItem::Var, varName, - std::make_shared(type->getArgTypes()[j++])); + if (auto b = ast->getAttribute(Attr::Bindings)) + for (auto &[c, t] : b->captures) { + if (t == BindingsAttribute::CaptureType::Global) { + auto cp = ctx->find(c); + if (!cp) + E(Error::ID_NOT_FOUND, getSrcInfo(), c); + if (!cp->isGlobal()) + E(Error::FN_GLOBAL_NOT_FOUND, getSrcInfo(), "global", c); + } + } + // Add self [recursive] reference! TODO: maybe remove later when doing contexts? + auto pc = ast->getAttribute(Attr::ParentClass); + if (!pc || pc->value.empty()) { + // Check if we already exist? + bool exists = false; + auto val = ctx->find(getUnmangledName(ast->getName())); + if (val && val->getType()->getFunc()) { + auto fn = getFunction(val->getType()); + exists = fn->rootName == getFunction(type)->rootName; + } + if (!exists) { + ctx->addFunc(getUnmangledName(ast->getName()), ast->getName(), + ctx->forceFind(ast->getName())->type); + } + } + for (size_t i = 0, j = 0; hasAst && i < ast->size(); i++) { + if ((*ast)[i].isValue()) { + auto [_, varName] = (*ast)[i].getNameWithStars(); + TypePtr at = extractFuncArgType(type, j++)->shared_from_this(); + bool isStatic = ast && getStaticGeneric((*ast)[i].getType()); + if (!isStatic && at && at->getStatic()) + at = at->getStatic()->getNonStaticType()->shared_from_this(); + if (at->is("TypeWrap")) { + ctx->addType(getUnmangledName(varName), varName, + instantiateTypeVar(extractClassGeneric(at.get()))); + } else { + ctx->addVar(getUnmangledName(varName), varName, std::make_shared(at)); + } } + } // Populate realization table in advance to support recursive realizations auto key = type->realizedName(); // note: the key might change later @@ -350,307 +426,146 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) if (auto i = in(realizations, key)) oldIR = (*i)->ir; auto r = realizations[key] = std::make_shared(); - r->type = type->getFunc(); + r->type = std::static_pointer_cast(type->shared_from_this()); r->ir = oldIR; + if (auto b = ast->getAttribute(Attr::Bindings)) + for (auto &[c, _] : b->captures) { + auto h = ctx->find(c); + r->captures.push_back(h ? h->canonicalName : ""); + } // Realizations should always be visible, so add them to the toplevel - ctx->addToplevel( - key, std::make_shared(TypecheckItem::Func, type->getFunc())); + auto val = std::make_shared(key, "", ctx->getModule(), + type->shared_from_this()); + ctx->addAlwaysVisible(val, true); + ctx->getBase()->suite = ast->getSuite(); if (hasAst) { auto oldBlockLevel = ctx->blockLevel; ctx->blockLevel = 0; - auto ret = inferTypes(ast->suite); + auto ret = inferTypes(ctx->getBase()->suite); ctx->blockLevel = oldBlockLevel; if (!ret) { realizations.erase(key); - ctx->realizationBases.pop_back(); - ctx->popBlock(); - ctx->typecheckLevel--; - getLogger().level--; - if (!startswith(ast->name, "._lambda")) { + ParserErrors errors; + if (!startswith(ast->name, "%_lambda")) { // Lambda typecheck failures are "ignored" as they are treated as statements, // not functions. // TODO: generalize this further. - LOG_REALIZE("[error] {}", ast->suite->toString(2)); - error("cannot typecheck the program"); + errors = findTypecheckErrors(ctx->getBase()->suite); + } + if (!isImport) { + ctx->bases.pop_back(); + ctx->popBlock(); + ctx->typecheckLevel--; + getLogger().level--; + } + if (!errors.empty()) { + throw exc::ParserException(errors); } + this->ctx = oldCtx; return nullptr; // inference must be delayed + } else { + ctx->getBase()->suite = ret; } - // Use NoneType as the return type when the return type is not specified and // function has no return statement - if (!ast->ret && type->getRetType()->getUnbound()) - unify(type->getRetType(), ctx->getType("NoneType")); + if (!ast->ret && isUnbound(type->getRetType())) { + unify(type->getRetType(), getStdLibType("NoneType")); + } } // Realize the return type auto ret = realize(type->getRetType()); - seqassert(ret, "cannot realize return type '{}'", type->getRetType()); + if (ast->hasAttribute(Attr::RealizeWithoutSelf) && + !extractFuncArgType(type)->canRealize()) { // For RealizeWithoutSelf + realizations.erase(key); + ctx->bases.pop_back(); + ctx->popBlock(); + ctx->typecheckLevel--; + getLogger().level--; + return nullptr; + } + seqassert(ret, "cannot realize return type '{}'", *(type->getRetType())); + + // LOG("[realize] F {} -> {} -> {}", type->getFuncName(), type->debugString(2), + // type->realizedName()); std::vector args; - for (auto &i : ast->args) { - std::string varName = i.name; - trimStars(varName); - args.emplace_back(Param{varName, nullptr, nullptr, i.status}); + for (auto &i : *ast) { + auto [_, varName] = i.getNameWithStars(); + args.emplace_back(varName, nullptr, nullptr, i.status); } - r->ast = N(ast->getSrcInfo(), r->type->realizedName(), nullptr, args, - ast->suite); - r->ast->attributes = ast->attributes; - - if (!in(ctx->cache->pendingRealizations, - make_pair(type->ast->name, type->realizedName()))) { + r->ast = + N(r->type->realizedName(), nullptr, args, ctx->getBase()->suite); + r->ast->setSrcInfo(ast->getSrcInfo()); + r->ast->cloneAttributesFrom(ast); + + auto newKey = type->realizedName(); + if (newKey != key) { + LOG("!! oldKey={}, newKey={}", key, newKey); + } + if (!in(ctx->cache->pendingRealizations, make_pair(type->getFuncName(), newKey))) { if (!r->ir) r->ir = makeIRFunction(r); - realizations[type->realizedName()] = r; + realizations[newKey] = r; } else { - realizations[key] = realizations[type->realizedName()]; + realizations[key] = realizations[newKey]; } if (force) - realizations[type->realizedName()]->ast = r->ast; - ctx->addToplevel(type->realizedName(), std::make_shared( - TypecheckItem::Func, type->getFunc())); - ctx->realizationBases.pop_back(); - ctx->popBlock(); - ctx->typecheckLevel--; - getLogger().level--; - - return type->getFunc(); -} - -/// Generate ASTs for all __internal__ functions that deal with vtable generation. -/// Intended to be called once the typechecking is done. -/// TODO: add JIT compatibility. -StmtPtr TypecheckVisitor::prepareVTables() { - auto rep = "__internal__.class_populate_vtables:0"; // see internal.codon - auto &initFn = ctx->cache->functions[rep]; - auto suite = N(); - for (auto &[_, cls] : ctx->cache->classes) { - for (auto &[r, real] : cls.realizations) { - size_t vtSz = 0; - for (auto &[base, vtable] : real->vtables) { - if (!vtable.ir) - vtSz += vtable.table.size(); - } - if (!vtSz) - continue; - // __internal__.class_set_rtti_vtable(real.ID, size, real.type) - suite->stmts.push_back(N( - N(N("__internal__.class_set_rtti_vtable:0"), - N(real->id), N(vtSz + 2), NT(r)))); - // LOG("[poly] {} -> {}", r, real->id); - vtSz = 0; - for (auto &[base, vtable] : real->vtables) { - if (!vtable.ir) { - for (auto &[k, v] : vtable.table) { - auto &[fn, id] = v; - std::vector ids; - for (auto &t : fn->getArgTypes()) - ids.push_back(NT(t->realizedName())); - // p[real.ID].__setitem__(f.ID, Function[](f).__raw__()) - LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, fn); - suite->stmts.push_back(N(N( - N("__internal__.class_set_rtti_vtable_fn:0"), - N(real->id), N(vtSz + id), - N(N( - N( - NT( - NT("Function"), - std::vector{ - NT(NT(TYPE_TUPLE), ids), - NT(fn->getRetType()->realizedName())}), - N(fn->realizedName())), - "__raw__")), - NT(r)))); - } - vtSz += vtable.table.size(); - } - } - } - } - initFn.ast->suite = suite; - auto typ = initFn.realizations.begin()->second->type; - LOG_REALIZE("[poly] {} : {}", typ, suite->toString(2)); - typ->ast = initFn.ast.get(); - realizeFunc(typ.get(), true); - - auto &initDist = ctx->cache->functions["__internal__.class_base_derived_dist:0"]; - // def class_base_derived_dist(B, D): - // return Tuple[].__elemsize__ - auto oldAst = initDist.ast; - for (auto &[_, real] : initDist.realizations) { - auto t = real->type; - auto baseTyp = t->funcGenerics[0].type->getClass(); - auto derivedTyp = t->funcGenerics[1].type->getClass(); - - const auto &fields = getClassFields(derivedTyp.get()); - auto types = std::vector{}; - auto found = false; - for (auto &f : fields) { - if (f.baseClass == baseTyp->name) { - found = true; - break; - } else { - auto ft = realize(ctx->instantiate(f.type, derivedTyp)); - types.push_back(NT(ft->realizedName())); - } - } - seqassert(found || getClassFields(baseTyp.get()).empty(), - "cannot find distance between {} and {}", derivedTyp->name, - baseTyp->name); - StmtPtr suite = N( - N(NT(NT(TYPE_TUPLE), types), "__elemsize__")); - LOG_REALIZE("[poly] {} : {}", t, *suite); - initDist.ast->suite = suite; - t->ast = initDist.ast.get(); - realizeFunc(t.get(), true); + realizations[newKey]->ast = r->ast; + r->type = std::static_pointer_cast(type->generalize(0)); + val = std::make_shared(newKey, "", ctx->getModule(), r->type); + ctx->addAlwaysVisible(val, true); + if (!isImport) { + ctx->bases.pop_back(); + ctx->popBlock(); + ctx->typecheckLevel--; + getLogger().level--; } - initDist.ast = oldAst; + this->ctx = oldCtx; - return nullptr; -} - -/// Generate thunks in all derived classes for a given virtual function (must be fully -/// realizable) and the corresponding base class. -/// @return unique thunk ID. -size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType *fp) { - seqassert(cp->canRealize() && fp->canRealize() && fp->getRetType()->canRealize(), - "{} not realized", fp->debugString(1)); - - // TODO: ugly, ugly; surely needs refactoring - - // Function signature for storing thunks - auto sig = [](types::FuncType *fp) { - std::vector gs; - for (auto &a : fp->getArgTypes()) - gs.push_back(a->realizedName()); - gs.push_back("|"); - for (auto &a : fp->funcGenerics) - if (!a.name.empty()) - gs.push_back(a.type->realizedName()); - return join(gs, ","); - }; - - // Set up the base class information - auto baseCls = cp->name; - auto fnName = ctx->cache->rev(fp->ast->name); - auto key = make_pair(fnName, sig(fp)); - auto &vt = ctx->cache->classes[baseCls] - .realizations[cp->realizedName()] - ->vtables[cp->realizedName()]; - - // Add or extract thunk ID - size_t vid = 0; - if (auto i = in(vt.table, key)) { - vid = i->second; - } else { - vid = vt.table.size() + 1; - vt.table[key] = {fp->getFunc(), vid}; - } - - // Iterate through all derived classes and instantiate the corresponding thunk - for (auto &[clsName, cls] : ctx->cache->classes) { - bool inMro = false; - for (auto &m : cls.mro) - if (m->type && m->type->getClass() && m->type->getClass()->name == baseCls) { - inMro = true; - break; - } - if (clsName != baseCls && inMro) { - for (auto &[_, real] : cls.realizations) { - auto &vtable = real->vtables[baseCls]; - - auto ct = - ctx->instantiate(ctx->forceFind(clsName)->type, cp->getClass())->getClass(); - std::vector args = fp->getArgTypes(); - args[0] = ct; - auto m = findBestMethod(ct, fnName, args); - if (!m) { - // Print a nice error message - std::vector a; - for (auto &t : args) - a.emplace_back(fmt::format("{}", t->prettyString())); - std::string argsNice = fmt::format("({})", fmt::join(a, ", ")); - E(Error::DOT_NO_ATTR_ARGS, getSrcInfo(), ct->prettyString(), fnName, - argsNice); - } - - std::vector ns; - for (auto &a : args) - ns.push_back(a->realizedName()); - - // Thunk name: _thunk... - auto thunkName = - format("_thunk.{}.{}.{}", baseCls, m->ast->name, fmt::join(ns, ".")); - if (in(ctx->cache->functions, thunkName)) - continue; - - // Thunk contents: - // def _thunk...(self, ): - // return ( - // __internal__.class_base_to_derived(self, , ), - // ) - std::vector fnArgs; - fnArgs.emplace_back(fp->ast->args[0].name, N(cp->realizedName()), - nullptr); - for (size_t i = 1; i < args.size(); i++) - fnArgs.emplace_back(fp->ast->args[i].name, N(args[i]->realizedName()), - nullptr); - std::vector callArgs; - callArgs.emplace_back( - N(N("__internal__.class_base_to_derived:0"), - N(fp->ast->args[0].name), N(cp->realizedName()), - N(real->type->realizedName()))); - for (size_t i = 1; i < args.size(); i++) - callArgs.emplace_back(N(fp->ast->args[i].name)); - auto thunkAst = N( - thunkName, nullptr, fnArgs, - N(N(N(N(m->ast->name), callArgs))), - Attr({"std.internal.attributes.inline", Attr::ForceRealize})); - auto &thunkFn = ctx->cache->functions[thunkAst->name]; - thunkFn.ast = std::static_pointer_cast(thunkAst->clone()); - - transform(thunkAst); - prependStmts->push_back(thunkAst); - auto ti = ctx->instantiate(thunkFn.type)->getFunc(); - auto tm = realizeFunc(ti.get(), true); - seqassert(tm, "bad thunk {}", thunkFn.type); - vtable.table[key] = {tm->getFunc(), vid}; - } - } - } - return vid; + return r->getType(); } /// Make IR node for a realized type. ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { // Realize if not, and return cached value if it exists - auto realizedName = t->realizedTypeName(); - if (!in(ctx->cache->classes[t->name].realizations, realizedName)) - realize(t->getClass()); - if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) + auto realizedName = t->ClassType::realizedName(); + auto cls = ctx->cache->getClass(t); + if (!in(cls->realizations, realizedName)) { + t = realize(t->getClass())->getClass(); + realizedName = t->ClassType::realizedName(); + cls = ctx->cache->getClass(t); + } + if (auto l = cls->realizations[realizedName]->ir) { + if (cls->rtti) + cast(l)->setPolymorphic(); return l; + } - auto forceFindIRType = [&](const TypePtr &tt) { + auto forceFindIRType = [&](Type *tt) { auto t = tt->getClass(); - seqassert(t && in(ctx->cache->classes[t->name].realizations, t->realizedTypeName()), - "{} not realized", tt); - auto l = ctx->cache->classes[t->name].realizations[t->realizedTypeName()]->ir; - seqassert(l, "no LLVM type for {}", t); + auto rn = t->ClassType::realizedName(); + auto cls = ctx->cache->getClass(t); + seqassert(t && in(cls->realizations, rn), "{} not realized", *tt); + auto l = cls->realizations[rn]->ir; + seqassert(l, "no LLVM type for {}", *tt); return l; }; // Prepare generics and statics std::vector types; - std::vector statics; - for (auto &m : t->generics) { - if (auto s = m.type->getStatic()) { - seqassert(s->expr->staticValue.evaluated, "static not realized"); - statics.push_back(&(s->expr->staticValue)); - } else { - types.push_back(forceFindIRType(m.type)); + std::vector statics; + if (t->is("unrealized_type")) + types.push_back(nullptr); + else + for (auto &m : t->generics) { + if (auto s = m.type->getStatic()) + statics.push_back(s); + else + types.push_back(forceFindIRType(m.getType())); } - } // Get the IR type auto *module = ctx->cache->module; @@ -675,24 +590,25 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { } else if (t->name == "str") { handle = module->getStringType(); } else if (t->name == "Int" || t->name == "UInt") { - handle = module->Nr(statics[0]->getInt(), t->name == "Int"); + handle = + module->Nr(getIntLiteral(statics[0]), t->name == "Int"); } else if (t->name == "Ptr") { - seqassert(types.size() == 1 && statics.empty(), "bad generics/statics"); + seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetPointerType(types[0]); } else if (t->name == "Generator") { - seqassert(types.size() == 1 && statics.empty(), "bad generics/statics"); + seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetGeneratorType(types[0]); } else if (t->name == TYPE_OPTIONAL) { - seqassert(types.size() == 1 && statics.empty(), "bad generics/statics"); + seqassert(types.size() == 1, "bad generics/statics"); handle = module->unsafeGetOptionalType(types[0]); } else if (t->name == "NoneType") { seqassert(types.empty() && statics.empty(), "bad generics/statics"); auto record = - ir::cast(module->unsafeGetMemberedType(realizedName)); + cast(module->unsafeGetMemberedType(realizedName)); record->realize({}, {}); handle = record; } else if (t->name == "Union") { - seqassert(!types.empty() && statics.empty(), "bad union"); + seqassert(!types.empty(), "bad union"); auto unionTypes = t->getUnion()->getRealizationTypes(); std::vector unionVec; unionVec.reserve(unionTypes.size()); @@ -701,169 +617,127 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { handle = module->unsafeGetUnionType(unionVec); } else if (t->name == "Function") { types.clear(); - for (auto &m : t->generics[0].type->getRecord()->args) - types.push_back(forceFindIRType(m)); - auto ret = forceFindIRType(t->generics[1].type); + for (auto &m : extractClassGeneric(t)->getClass()->generics) + types.push_back(forceFindIRType(m.getType())); + auto ret = forceFindIRType(extractClassGeneric(t, 1)); handle = module->unsafeGetFuncType(realizedName, ret, types); } else if (t->name == "std.experimental.simd.Vec") { - seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics"); - handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]); - } else if (auto tr = t->getRecord()) { - seqassert(tr->getRepeats() >= 0, "repeats not resolved: '{}'", tr->debugString(2)); - tr->flatten(); - std::vector typeArgs; - std::vector names; - std::map memberInfo; - for (int ai = 0; ai < tr->args.size(); ai++) { - auto n = t->name == TYPE_TUPLE ? format("item{}", ai + 1) - : ctx->cache->classes[t->name].fields[ai].name; - names.emplace_back(n); - typeArgs.emplace_back(forceFindIRType(tr->args[ai])); - memberInfo[n] = t->name == TYPE_TUPLE - ? tr->getSrcInfo() - : ctx->cache->classes[t->name].fields[ai].type->getSrcInfo(); - } - auto record = - ir::cast(module->unsafeGetMemberedType(realizedName)); - record->realize(typeArgs, names); - handle = record; - handle->setAttribute(std::make_unique(std::move(memberInfo))); + seqassert(types.size() == 2 && !statics.empty(), "bad generics/statics"); + handle = module->unsafeGetVectorType(getIntLiteral(statics[0]), types[0]); } else { // Type arguments will be populated afterwards to avoid infinite loop with recursive // reference types (e.g., `class X: x: Optional[X]`) - handle = module->unsafeGetMemberedType(realizedName, true); - if (ctx->cache->classes[t->name].rtti) { - // LOG("RTTI: {}", t->name); - ir::cast(handle)->setPolymorphic(); + if (t->isRecord()) { + std::vector typeArgs; // needed for IR + std::vector names; // needed for IR + std::map memberInfo; // needed for IR + + seqassert(!t->is("__NTuple__"), "ntuple not inlined"); + auto ft = getClassFieldTypes(t->getClass()); + const auto &fields = cls->fields; + for (size_t i = 0; i < ft.size(); i++) { + if (!realize(ft[i].get())) { + E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), fields[i].name, + t->prettyString()); + } + names.emplace_back(fields[i].name); + typeArgs.emplace_back(makeIRType(ft[i]->getClass())); + memberInfo[fields[i].name] = ft[i]->getSrcInfo(); + } + auto record = + cast(module->unsafeGetMemberedType(realizedName)); + record->realize(typeArgs, names); + handle = record; + handle->setAttribute( + std::make_unique(std::move(memberInfo))); + } else { + handle = module->unsafeGetMemberedType(realizedName, !t->isRecord()); + if (cls->rtti) + cast(handle)->setPolymorphic(); } } handle->setSrcInfo(t->getSrcInfo()); handle->setAstType( std::const_pointer_cast(t->shared_from_this())); - return ctx->cache->classes[t->name].realizations[realizedName]->ir = handle; + return cls->realizations[realizedName]->ir = handle; } /// Make IR node for a realized function. ir::Func *TypecheckVisitor::makeIRFunction( const std::shared_ptr &r) { ir::Func *fn = nullptr; + auto irm = ctx->cache->module; // Create and store a function IR node and a realized AST for IR passes - if (r->ast->attributes.has(Attr::Internal)) { + if (r->ast->hasAttribute(Attr::Internal)) { // e.g., __new__, Ptr.__new__, etc. - fn = ctx->cache->module->Nr(r->type->ast->name); - } else if (r->ast->attributes.has(Attr::LLVM)) { - fn = ctx->cache->module->Nr(r->type->realizedName()); - } else if (r->ast->attributes.has(Attr::C)) { - fn = ctx->cache->module->Nr(r->type->realizedName()); + fn = irm->Nr(r->type->ast->name); + } else if (r->ast->hasAttribute(Attr::LLVM)) { + fn = irm->Nr(r->type->realizedName()); + } else if (r->ast->hasAttribute(Attr::C)) { + fn = irm->Nr(r->type->realizedName()); } else { - fn = ctx->cache->module->Nr(r->type->realizedName()); + fn = irm->Nr(r->type->realizedName()); } fn->setUnmangledName(ctx->cache->reverseIdentifierLookup[r->type->ast->name]); auto parent = r->type->funcParent; - if (!r->ast->attributes.parentClass.empty() && - !r->ast->attributes.has(Attr::Method)) { - // Hack for non-generic methods - parent = ctx->find(r->ast->attributes.parentClass)->type; + if (auto aa = r->ast->getAttribute(Attr::ParentClass)) { + if (!aa->value.empty() && !r->ast->hasAttribute(Attr::Method)) { + // Hack for non-generic methods + parent = ctx->find(aa->value)->type; + } } - if (parent && parent->canRealize()) { - realize(parent); - fn->setParentType(makeIRType(parent->getClass().get())); + if (parent && parent->isInstantiated() && parent->canRealize()) { + parent = extractClassType(parent.get())->shared_from_this(); + realize(parent.get()); + fn->setParentType(makeIRType(parent->getClass())); } fn->setGlobal(); // Mark this realization as pending (i.e., realized but not translated) ctx->cache->pendingRealizations.insert({r->type->ast->name, r->type->realizedName()}); - seqassert(!r->type || r->ast->args.size() == r->type->getArgTypes().size() + - r->type->funcGenerics.size(), + seqassert(!r->type || + r->ast->size() == r->type->size() + r->type->funcGenerics.size(), "type/AST argument mismatch"); // Populate the IR node std::vector names; std::vector types; - for (size_t i = 0, j = 0; i < r->ast->args.size(); i++) { - if (r->ast->args[i].status == Param::Normal) { - if (!r->type->getArgTypes()[j]->getFunc()) { - types.push_back(makeIRType(r->type->getArgTypes()[j]->getClass().get())); - names.push_back(ctx->cache->reverseIdentifierLookup[r->ast->args[i].name]); + for (size_t i = 0, j = 0; i < r->ast->size(); i++) { + if ((*r->ast)[i].isValue()) { + if (!extractFuncArgType(r->getType(), j)->getFunc()) { + types.push_back(makeIRType(extractFuncArgType(r->getType(), j)->getClass())); + names.push_back(ctx->cache->reverseIdentifierLookup[(*r->ast)[i].getName()]); } j++; } } - if (r->ast->hasAttr(Attr::CVarArg)) { + if (r->ast->hasAttribute(Attr::CVarArg)) { types.pop_back(); names.pop_back(); } - auto irType = ctx->cache->module->unsafeGetFuncType( - r->type->realizedName(), makeIRType(r->type->getRetType()->getClass().get()), - types, r->ast->hasAttr(Attr::CVarArg)); - irType->setAstType(r->type->getFunc()); + auto irType = irm->unsafeGetFuncType(r->type->realizedName(), + makeIRType(r->type->getRetType()->getClass()), + types, r->ast->hasAttribute(Attr::CVarArg)); + irType->setAstType(r->type->shared_from_this()); fn->realize(irType, names); return fn; } -/// Generate ASTs for dynamically generated functions. -std::shared_ptr -TypecheckVisitor::generateSpecialAst(types::FuncType *type) { - // Clone the generic AST that is to be realized - auto ast = std::dynamic_pointer_cast( - clone(ctx->cache->functions[type->ast->name].ast)); - - if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__iter__:0") && - type->getArgTypes()[0]->getHeterogenousTuple()) { - // Special case: do not realize auto-generated heterogenous __iter__ - E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); - } else if (ast->hasAttr("autogenerated") && endswith(ast->name, ".__getitem__:0") && - type->getArgTypes()[0]->getHeterogenousTuple()) { - // Special case: do not realize auto-generated heterogenous __getitem__ - E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); - } else if (startswith(ast->name, "Function.__call_internal__")) { - // Special case: Function.__call_internal__ - /// TODO: move to IR one day - std::vector items; - items.push_back(nullptr); - std::vector ll; - std::vector lla; - auto &as = type->getArgTypes()[1]->getRecord()->args; - auto ag = ast->args[1].name; - trimStars(ag); - for (int i = 0; i < as.size(); i++) { - ll.push_back(format("%{} = extractvalue {{}} %args, {}", i, i)); - items.push_back(N(N(ag))); - } - items.push_back(N(N("TR"))); - for (int i = 0; i < as.size(); i++) { - items.push_back(N(N(N(ag), N(i)))); - lla.push_back(format("{{}} %{}", i)); - } - items.push_back(N(N("TR"))); - ll.push_back(format("%{} = call {{}} %self({})", as.size(), combine2(lla))); - ll.push_back(format("ret {{}} %{}", as.size())); - items[0] = N(N(combine2(ll, "\n"))); - ast->suite = N(items); - } else if (startswith(ast->name, "Union.__new__:0")) { - auto unionType = type->funcParent->getUnion(); - seqassert(unionType, "expected union, got {}", type->funcParent); - - StmtPtr suite = N(N( - N("__internal__.new_union:0"), N(type->ast->args[0].name), - N(unionType->realizedTypeName()))); - ast->suite = suite; - } else if (startswith(ast->name, "__internal__.get_union_tag:0")) { - // def __internal__.get_union_tag(union: Union, tag: Static[int]): - // return __internal__.union_get_data(union, T0) - auto szt = type->funcGenerics[0].type->getStatic(); - auto tag = szt->evaluate().getInt(); - auto unionType = type->getArgTypes()[0]->getUnion(); - auto unionTypes = unionType->getRealizationTypes(); - if (tag < 0 || tag >= unionTypes.size()) - E(Error::CUSTOM, getSrcInfo(), "bad union tag"); - auto selfVar = ast->args[0].name; - auto suite = N(N( - N(N("__internal__.union_get_data:0"), N(selfVar), - NT(unionTypes[tag]->realizedName())))); - ast->suite = suite; - } - return ast; +ir::Func *TypecheckVisitor::realizeIRFunc(types::FuncType *fn, + const std::vector &generics) { + // TODO: used by cytonization. Probably needs refactoring. + auto fnType = instantiateType(fn); + types::Type::Unification u; + for (size_t i = 0; i < generics.size(); i++) + fnType->getFunc()->funcGenerics[i].type->unify(generics[i].get(), &u); + if (!realize(fnType.get())) + return nullptr; + + auto pr = ctx->cache->pendingRealizations; // copy it as it might be modified + for (auto &fn : pr) + TranslateVisitor(ctx->cache->codegenCtx) + .translateStmts(clone(getFunction(fn.first)->ast)); + return getFunction(fn->ast->getName())->realizations[fnType->realizedName()]->ir; } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 40de9fca..0e0e28e4 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -5,7 +5,8 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -13,19 +14,40 @@ using namespace codon::error; namespace codon::ast { using namespace types; +using namespace matcher; + +/// Ensure that `break` is in a loop. +/// Transform if a loop break variable is available +/// (e.g., a break within loop-else block). +/// @example +/// `break` -> `no_break = False; break` -/// Nothing to typecheck; just call setDone void TypecheckVisitor::visit(BreakStmt *stmt) { - stmt->setDone(); - if (!ctx->staticLoops.back().empty()) { - auto a = N(N(ctx->staticLoops.back()), N(false)); - a->setUpdate(); - resultStmt = transform(N(a, stmt->clone())); + if (!ctx->getBase()->getLoop()) + E(Error::EXPECTED_LOOP, stmt, "break"); + ctx->getBase()->getLoop()->flat = false; + if (!ctx->getBase()->getLoop()->breakVar.empty()) { + resultStmt = + N(transform(N( + N(ctx->getBase()->getLoop()->breakVar), + N(false), nullptr, AssignStmt::UpdateMode::Update)), + N()); + } else { + stmt->setDone(); + if (!ctx->staticLoops.back().empty()) { + auto a = N(N(ctx->staticLoops.back()), N(false)); + a->setUpdate(); + resultStmt = transform(N(a, stmt)); + } } } -/// Nothing to typecheck; just call setDone +/// Ensure that `continue` is in a loop void TypecheckVisitor::visit(ContinueStmt *stmt) { + if (!ctx->getBase()->getLoop()) + E(Error::EXPECTED_LOOP, stmt, "continue"); + ctx->getBase()->getLoop()->flat = false; + stmt->setDone(); if (!ctx->staticLoops.back().empty()) { resultStmt = N(); @@ -33,125 +55,174 @@ void TypecheckVisitor::visit(ContinueStmt *stmt) { } } -/// Typecheck while statements. +/// Transform a while loop. +/// @example +/// `while cond: ...` -> `while cond.__bool__(): ...` +/// `while cond: ... else: ...` -> ```no_break = True +/// while cond.__bool__(): +/// ... +/// if no_break: ...``` void TypecheckVisitor::visit(WhileStmt *stmt) { + // Check for while-else clause + std::string breakVar; + if (stmt->getElse() && stmt->getElse()->firstInBlock()) { + // no_break = True + breakVar = getTemporaryVar("no_break"); + prependStmts->push_back( + transform(N(N(breakVar), N(true)))); + } + ctx->staticLoops.push_back(stmt->gotoVar.empty() ? "" : stmt->gotoVar); - transform(stmt->cond); + ctx->getBase()->loops.emplace_back(breakVar); + stmt->cond = transform(stmt->getCond()); + if (stmt->getCond()->getClassType() && !stmt->getCond()->getType()->is("bool")) + stmt->cond = transform(N(N(stmt->getCond(), "__bool__"))); + ctx->blockLevel++; - transform(stmt->suite); + stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite())); ctx->blockLevel--; ctx->staticLoops.pop_back(); - if (stmt->cond->isDone() && stmt->suite->isDone()) + // Complete while-else clause + if (stmt->getElse() && stmt->getElse()->firstInBlock()) { + auto es = stmt->getElse(); + stmt->elseSuite = nullptr; + resultStmt = transform(N(stmt, N(N(breakVar), es))); + } + ctx->getBase()->loops.pop_back(); + + if (stmt->getCond()->isDone() && stmt->getSuite()->isDone()) stmt->setDone(); } /// Typecheck for statements. Wrap the iterator expression with `__iter__` if needed. /// See @c transformHeterogenousTupleFor for iterating heterogenous tuples. void TypecheckVisitor::visit(ForStmt *stmt) { - transform(stmt->decorator); - transform(stmt->iter); + stmt->decorator = transformForDecorator(stmt->getDecorator()); + + std::string breakVar; + // Needs in-advance transformation to prevent name clashes with the iterator variable + stmt->iter = transform(stmt->getIter()); + + // Check for for-else clause + Stmt *assign = nullptr; + if (stmt->getElse() && stmt->getElse()->firstInBlock()) { + breakVar = getTemporaryVar("no_break"); + assign = transform(N(N(breakVar), N(true))); + } // Extract the iterator type of the for - auto iterType = stmt->iter->getType()->getClass(); + auto iterType = extractClassType(stmt->getIter()); if (!iterType) return; // wait until the iterator is known - if ((resultStmt = transformStaticForLoop(stmt))) + auto [delay, staticLoop] = transformStaticForLoop(stmt); + if (delay) return; - - bool maybeHeterogenous = - iterType->name == TYPE_TUPLE || startswith(iterType->name, TYPE_KWTUPLE); - if (maybeHeterogenous && !iterType->canRealize()) { - return; // wait until the tuple is fully realizable - } else if (maybeHeterogenous && iterType->getHeterogenousTuple()) { - // Case: iterating a heterogenous tuple - resultStmt = transformHeterogenousTupleFor(stmt); + if (staticLoop) { + resultStmt = staticLoop; return; } + // Replace for (i, j) in ... { ... } with for tmp in ...: { i, j = tmp ; ... } + if (!cast(stmt->getVar())) { + auto var = N(ctx->cache->getTemporaryVar("for")); + auto ns = unpackAssignment(stmt->getVar(), var); + stmt->suite = N(ns, stmt->getSuite()); + stmt->var = var; + } + // Case: iterating a non-generator. Wrap with `__iter__` - if (iterType->name != "Generator" && !stmt->wrapped) { - stmt->iter = transform(N(N(stmt->iter, "__iter__"))); - iterType = stmt->iter->getType()->getClass(); + if (iterType->name != "Generator" && !stmt->isWrapped()) { + stmt->iter = transform(N(N(stmt->getIter(), "__iter__"))); + iterType = extractClassType(stmt->getIter()); stmt->wrapped = true; } - auto var = stmt->var->getId(); - seqassert(var, "corrupt for variable: {}", stmt->var); - - // Handle dominated for bindings - auto changed = in(ctx->cache->replacements, var->value); - while (auto s = in(ctx->cache->replacements, var->value)) - var->value = s->first, changed = s; - if (changed && changed->second) { - auto u = - N(N(format("{}.__used__", var->value)), N(true)); - u->setUpdate(); - stmt->suite = N(u, stmt->suite); + ctx->getBase()->loops.emplace_back(breakVar); + auto var = cast(stmt->getVar()); + seqassert(var, "corrupt for variable: {}", *(stmt->getVar())); + + if (!var->hasAttribute(Attr::ExprDominated) && + !var->hasAttribute(Attr::ExprDominatedUsed)) { + auto val = ctx->addVar(var->getValue(), ctx->generateCanonicalName(var->getValue()), + instantiateUnbound()); + val->time = getTime(); + } else if (var->hasAttribute(Attr::ExprDominatedUsed)) { + var->eraseAttribute(Attr::ExprDominatedUsed); + var->setAttribute(Attr::ExprDominated); + stmt->suite = N( + N(N(format("{}{}", var->getValue(), VAR_USED_SUFFIX)), + N(true), nullptr, AssignStmt::UpdateMode::Update), + stmt->getSuite()); } - if (changed) - var->setAttr(ExprAttr::Dominated); + stmt->var = transform(stmt->getVar()); // Unify iterator variable and the iterator type - auto val = ctx->find(var->value); - if (!changed) - val = ctx->add(TypecheckItem::Var, var->value, - ctx->getUnbound(stmt->var->getSrcInfo())); if (iterType && iterType->name != "Generator") - E(Error::EXPECTED_GENERATOR, stmt->iter); - unify(stmt->var->type, - iterType ? unify(val->type, iterType->generics[0].type) : val->type); + E(Error::EXPECTED_GENERATOR, stmt->getIter()); + if (iterType) + unify(stmt->getVar()->getType(), extractClassGeneric(iterType)); ctx->staticLoops.emplace_back(); ctx->blockLevel++; - transform(stmt->suite); + stmt->suite = SuiteStmt::wrap(transform(stmt->getSuite())); ctx->blockLevel--; ctx->staticLoops.pop_back(); - if (stmt->iter->isDone() && stmt->suite->isDone()) + if (ctx->getBase()->getLoop()->flat) + stmt->flat = true; + + // Complete for-else clause + if (stmt->getElse() && stmt->getElse()->firstInBlock()) { + auto es = stmt->getElse(); + stmt->elseSuite = nullptr; + resultStmt = + transform(N(assign, stmt, N(N(breakVar), es))); + stmt->elseSuite = nullptr; + } + + ctx->getBase()->loops.pop_back(); + + if (stmt->getIter()->isDone() && stmt->getSuite()->isDone()) stmt->setDone(); } -/// Handle heterogeneous tuple iteration. +/// Transform and check for OpenMP decorator. /// @example -/// `for i in tuple_expr: ` -> -/// ```tuple = tuple_expr -/// for cnt in range(): -/// if cnt == 0: -/// i = t[0]; -/// if cnt == 1: -/// i = t[1]; ...``` -/// A separate suite is generated for each tuple member. -StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) { - auto block = N(); - // `tuple = ` - auto tupleVar = ctx->cache->getTemporaryVar("tuple"); - block->stmts.push_back(N(N(tupleVar), stmt->iter)); - - auto tupleArgs = stmt->iter->getType()->getClass()->getHeterogenousTuple()->args; - auto cntVar = ctx->cache->getTemporaryVar("idx"); - std::vector forBlock; - for (size_t ai = 0; ai < tupleArgs.size(); ai++) { - // `if cnt == ai: (var = tuple[ai]; )` - forBlock.push_back(N( - N(N(cntVar), "==", N(ai)), - N(N(clone(stmt->var), - N(N(tupleVar), N(ai))), - clone(stmt->suite)))); +/// `@par(num_threads=2, openmp="schedule(static)")` -> +/// `for_par(num_threads=2, schedule="static")` +Expr *TypecheckVisitor::transformForDecorator(Expr *decorator) { + if (!decorator) + return nullptr; + Expr *callee = decorator; + if (auto c = cast(callee)) + callee = c->getExpr(); + auto ci = cast(transform(callee)); + if (!ci || !startswith(ci->getValue(), "std.openmp.for_par.0")) { + E(Error::LOOP_DECORATOR, decorator); } - // `for cnt in range(tuple_size): ...` - block->stmts.push_back( - N(N(cntVar), - N(N("std.internal.types.range.range"), - N(tupleArgs.size())), - N(forBlock))); - - ctx->blockLevel++; - transform(block); - ctx->blockLevel--; - return block; + std::vector args; + std::string openmp; + std::vector omp; + if (auto c = cast(decorator)) + for (auto &a : *c) { + if (a.getName() == "openmp" || + (a.getName().empty() && openmp.empty() && cast(a.getExpr()))) { + auto ompOrErr = + parseOpenMP(ctx->cache, cast(a.getExpr())->getValue(), + a.value->getSrcInfo()); + if (!ompOrErr) + throw exc::ParserException(ompOrErr.takeError()); + omp = *ompOrErr; + } else { + args.emplace_back(a.getName(), transform(a.getExpr())); + } + } + for (auto &a : omp) + args.emplace_back(a.getName(), transform(a.getExpr())); + return transform(N(transform(N("for_par")), args)); } /// Handle static for constructs. @@ -166,261 +237,120 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) { /// loop = False # also set to False on break /// If a loop is flat, while wrappers are removed. /// A separate suite is generated for each static iteration. -StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { - auto var = stmt->var->getId()->value; - if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId()) - return nullptr; - auto iter = stmt->iter->getCall()->expr->getId(); - auto loopVar = ctx->cache->getTemporaryVar("loop"); - - std::vector vars{var}; - auto suiteVec = stmt->suite->getSuite(); - auto oldSuite = suiteVec ? suiteVec->clone() : nullptr; - for (int validI = 0; suiteVec && validI < suiteVec->stmts.size(); validI++) { - if (auto a = suiteVec->stmts[validI]->getAssign()) - if (a->rhs && a->rhs->getIndex()) - if (a->rhs->getIndex()->expr->isId(var)) { - vars.push_back(a->lhs->getId()->value); - suiteVec->stmts[validI] = nullptr; - continue; +std::pair TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { + auto loopVar = getTemporaryVar("loop"); + auto suite = clean_clone(stmt->getSuite()); + auto [ok, delay, preamble, items] = transformStaticLoopCall( + stmt->getVar(), &suite, stmt->getIter(), [&](Stmt *assigns) { + Stmt *ret = nullptr; + if (!stmt->flat) { + auto brk = N(); + brk->setDone(); // Avoid transforming this one to continue + // var [: Static] := expr; suite... + auto loop = N(N(loopVar), + N(assigns, clone(suite), brk)); + loop->gotoVar = loopVar; + ret = loop; + } else { + ret = N(assigns, clone(stmt->getSuite())); } - break; - } - if (vars.size() > 1) - vars.erase(vars.begin()); - auto [ok, items] = transformStaticLoopCall(vars, stmt->iter, [&](StmtPtr assigns) { - StmtPtr ret = nullptr; - if (!stmt->flat) { - auto brk = N(); - brk->setDone(); // Avoid transforming this one to continue - // var [: Static] := expr; suite... - auto loop = N(N(loopVar), - N(assigns, clone(stmt->suite), brk)); - loop->gotoVar = loopVar; - ret = loop; - } else { - ret = N(assigns, clone(stmt->suite)); - } - return ret; - }); - if (!ok) { - if (oldSuite) - stmt->suite = oldSuite; - return nullptr; - } + return ret; + }); + if (!ok) + return {false, nullptr}; + if (delay) + return {true, nullptr}; // Close the loop auto block = N(); + block->addStmt(preamble); for (auto &i : items) - block->stmts.push_back(std::dynamic_pointer_cast(i)); - StmtPtr loop = nullptr; + block->addStmt(cast(i)); + Stmt *loop = nullptr; if (!stmt->flat) { ctx->blockLevel++; auto a = N(N(loopVar), N(false)); a->setUpdate(); - block->stmts.push_back(a); + block->addStmt(a); loop = transform(N(N(N(loopVar), N(true)), N(N(loopVar), block))); ctx->blockLevel--; } else { loop = transform(block); } - return loop; + return {false, loop}; } -std::pair>> -TypecheckVisitor::transformStaticLoopCall( - const std::vector &vars, ExprPtr iter, - std::function(StmtPtr)> wrap) { - if (!iter->getCall()) - return {false, {}}; - auto fn = iter->getCall()->expr->getId(); - if (!fn || vars.empty()) - return {false, {}}; - - auto stmt = N(N(vars[0]), nullptr, nullptr); - std::vector> block; - if (startswith(fn->value, "statictuple:0")) { - auto &args = iter->getCall()->args[0].value->getCall()->args; - if (vars.size() != 1) - error("expected one item"); - for (size_t i = 0; i < args.size(); i++) { - stmt->rhs = args[i].value; - if (stmt->rhs->isStatic()) { - stmt->type = NT( - N("Static"), - N(stmt->rhs->staticValue.type == StaticValue::INT ? "int" : "str")); - } else { - stmt->type = nullptr; - } - block.push_back(wrap(stmt->clone())); - } - } else if (fn && startswith(fn->value, "std.internal.types.range.staticrange:0")) { - if (vars.size() != 1) - error("expected one item"); - int st = - fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); - int ed = - fn->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt(); - int step = - fn->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt(); - if (abs(st - ed) / abs(step) > MAX_STATIC_ITER) - E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, abs(st - ed) / abs(step)); - for (int i = st; step > 0 ? i < ed : i > ed; i += step) { - stmt->rhs = N(i); - stmt->type = NT(N("Static"), N("int")); - block.push_back(wrap(stmt->clone())); - } - } else if (fn && startswith(fn->value, "std.internal.types.range.staticrange:1")) { - if (vars.size() != 1) - error("expected one item"); - int ed = - fn->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); - if (ed > MAX_STATIC_ITER) - E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed); - for (int i = 0; i < ed; i++) { - stmt->rhs = N(i); - stmt->type = NT(N("Static"), N("int")); - block.push_back(wrap(stmt->clone())); - } - } else if (fn && startswith(fn->value, "std.internal.static.fn_overloads")) { - if (vars.size() != 1) - error("expected one item"); - if (auto fna = ctx->getFunctionArgs(fn->type)) { - auto [generics, args] = *fna; - auto typ = generics[0]->getClass(); - auto name = ctx->getStaticString(generics[1]); - seqassert(name, "bad static string"); - if (auto n = in(ctx->cache->classes[typ->name].methods, *name)) { - auto &mt = ctx->cache->overloads[*n]; - for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { - auto &method = mt[mti]; - if (endswith(method.name, ":dispatch") || - !ctx->cache->functions[method.name].type) - continue; - if (method.age <= ctx->age) { - if (typ->getHeterogenousTuple()) { - auto &ast = ctx->cache->functions[method.name].ast; - if (ast->hasAttr("autogenerated") && - (endswith(ast->name, ".__iter__:0") || - endswith(ast->name, ".__getitem__:0"))) { - // ignore __getitem__ and other heterogenuous methods - continue; - } - } - stmt->rhs = N(method.name); - block.push_back(wrap(stmt->clone())); - } - } - } - } else { - error("bad call to fn_overloads"); - } - } else if (fn && startswith(fn->value, "std.internal.builtin.staticenumerate")) { - if (vars.size() != 2) - error("expected two items"); - if (auto fna = ctx->getFunctionArgs(fn->type)) { - auto [generics, args] = *fna; - if (auto typ = args[0]->getRecord()) { - for (size_t i = 0; i < typ->args.size(); i++) { - auto b = N( - {N(N(vars[0]), N(i), - NT(NT("Static"), NT("int"))), - N(N(vars[1]), - N(iter->getCall()->args[0].value->clone(), - N(i)))}); - block.push_back(wrap(b)); - } - } else { - error("staticenumerate needs a tuple"); - } - } else { - error("bad call to staticenumerate"); - } - } else if (fn && startswith(fn->value, "std.internal.internal.vars:0")) { - if (auto fna = ctx->getFunctionArgs(fn->type)) { - auto [generics, args] = *fna; - - auto withIdx = generics[0]->getStatic()->evaluate().getInt() != 0 ? 1 : 0; - if (!withIdx && vars.size() != 2) - error("expected two items"); - else if (withIdx && vars.size() != 3) - error("expected three items"); - auto typ = args[0]->getClass(); - size_t idx = 0; - for (auto &f : getClassFields(typ.get())) { - std::vector stmts; - if (withIdx) { - stmts.push_back( - N(N(vars[0]), N(idx), - NT(NT("Static"), NT("int")))); +std::tuple> +TypecheckVisitor::transformStaticLoopCall(Expr *varExpr, SuiteStmt **varSuite, + Expr *iter, + const std::function &wrap, + bool allowNonHeterogenous) { + if (!iter->getClassType()) + return {true, true, nullptr, {}}; + + std::vector vars{}; + if (auto ei = cast(varExpr)) { + vars.push_back(ei->getValue()); + } else { + Items *list = nullptr; + if (auto el = cast(varExpr)) + list = el; + else if (auto et = cast(varExpr)) + list = et; + if (list) { + for (const auto &it : *list) + if (auto ei = cast(it)) { + vars.push_back(ei->getValue()); + } else { + return {false, false, nullptr, {}}; } - stmts.push_back( - N(N(vars[withIdx]), N(f.name), - NT(NT("Static"), NT("str")))); - stmts.push_back( - N(N(vars[withIdx + 1]), - N(iter->getCall()->args[0].value->clone(), f.name))); - auto b = N(stmts); - block.push_back(wrap(b)); - idx++; - } } else { - error("bad call to vars"); + return {false, false, nullptr, {}}; } - } else if (fn && startswith(fn->value, "std.internal.static.vars_types:0")) { - if (auto fna = ctx->getFunctionArgs(fn->type)) { - auto [generics, args] = *fna; - - auto typ = realize(generics[0]->getClass()); - auto withIdx = generics[1]->getStatic()->evaluate().getInt() != 0 ? 1 : 0; - if (!withIdx && vars.size() != 1) - error("expected one item"); - else if (withIdx && vars.size() != 2) - error("expected two items"); - - seqassert(typ, "vars_types expects a realizable type, got '{}' instead", - generics[0]); - - if (auto utyp = typ->getUnion()) { - for (size_t i = 0; i < utyp->getRealizationTypes().size(); i++) { - std::vector stmts; - if (withIdx) { - stmts.push_back( - N(N(vars[0]), N(i), - NT(NT("Static"), NT("int")))); - } - stmts.push_back( - N(N(vars[1]), - N(utyp->getRealizationTypes()[i]->realizedName()))); - auto b = N(stmts); - block.push_back(wrap(b)); - } - } else { - size_t idx = 0; - for (auto &f : getClassFields(typ->getClass().get())) { - auto ta = realize(ctx->instantiate(f.type, typ->getClass())); - seqassert(ta, "cannot realize '{}'", f.type->debugString(1)); - std::vector stmts; - if (withIdx) { - stmts.push_back( - N(N(vars[0]), N(idx), - NT(NT("Static"), NT("int")))); - } - stmts.push_back( - N(N(vars[withIdx]), NT(ta->realizedName()))); - auto b = N(stmts); - block.push_back(wrap(b)); - idx++; - } - } + } + + Stmt *preamble = nullptr; + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + std::vector block; + if (fn && startswith(fn->getValue(), "statictuple")) { + block = populateStaticTupleLoop(iter, vars); + } else if (fn && + startswith(fn->getValue(), "std.internal.types.range.staticrange.0:1")) { + block = populateSimpleStaticRangeLoop(iter, vars); + } else if (fn && + startswith(fn->getValue(), "std.internal.types.range.staticrange.0")) { + block = populateStaticRangeLoop(iter, vars); + } else if (fn && startswith(fn->getValue(), "std.internal.static.fn_overloads.0")) { + block = populateStaticFnOverloadsLoop(iter, vars); + } else if (fn && + startswith(fn->getValue(), "std.internal.builtin.staticenumerate.0")) { + block = populateStaticEnumerateLoop(iter, vars); + } else if (fn && startswith(fn->getValue(), "std.internal.internal.vars.0")) { + block = populateStaticVarsLoop(iter, vars); + } else if (fn && startswith(fn->getValue(), "std.internal.static.vars_types.0")) { + block = populateStaticVarTypesLoop(iter, vars); + } else { + bool maybeHeterogenous = iter->getType()->is(TYPE_TUPLE); + if (maybeHeterogenous) { + if (!iter->getType()->canRealize()) + return {true, true, nullptr, {}}; // wait until the tuple is fully realizable + if (!iter->getClassType()->getHeterogenousTuple() && !allowNonHeterogenous) + return {false, false, nullptr, {}}; + block = populateStaticHeterogenousTupleLoop(iter, vars); + preamble = block.back(); + block.pop_back(); } else { - error("bad call to vars"); + return {false, false, nullptr, {}}; } - } else { - return {false, {}}; } - return {true, block}; + std::vector wrapBlock; + wrapBlock.reserve(block.size()); + for (auto b : block) { + wrapBlock.push_back(wrap(b)); + } + return {true, false, preamble, wrapBlock}; } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index 44e551c4..81283b33 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -6,7 +6,7 @@ #include "codon/parser/ast.h" #include "codon/parser/cache.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/simplify.h" +#include "codon/parser/match.h" #include "codon/parser/visitors/typecheck/typecheck.h" using fmt::format; @@ -15,37 +15,41 @@ using namespace codon::error; namespace codon::ast { using namespace types; +using namespace matcher; /// Replace unary operators with the appropriate magic calls. /// Also evaluate static expressions. See @c evaluateStaticUnary for details. void TypecheckVisitor::visit(UnaryExpr *expr) { - transform(expr->expr); + expr->expr = transform(expr->expr); - static std::unordered_map> - staticOps = {{StaticValue::INT, {"-", "+", "!", "~"}}, - {StaticValue::STRING, {"@"}}}; + static std::unordered_map> staticOps = { + {1, {"-", "+", "!", "~"}}, {2, {"@"}}, {3, {"!"}}}; // Handle static expressions - if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) { - resultExpr = evaluateStaticUnary(expr); + if (auto s = expr->getExpr()->getType()->isStaticType()) { + if (in(staticOps[s], expr->getOp())) { + resultExpr = evaluateStaticUnary(expr); + return; + } + } else if (isUnbound(expr->getExpr())) { return; } - if (expr->op == "!") { + if (expr->getOp() == "!") { // `not expr` -> `expr.__bool__().__invert__()` resultExpr = transform(N(N( - N(N(clone(expr->expr), "__bool__")), "__invert__"))); + N(N(expr->getExpr(), "__bool__")), "__invert__"))); } else { std::string magic; - if (expr->op == "~") + if (expr->getOp() == "~") magic = "invert"; - else if (expr->op == "+") + else if (expr->getOp() == "+") magic = "pos"; - else if (expr->op == "-") + else if (expr->getOp() == "-") magic = "neg"; else - seqassert(false, "invalid unary operator '{}'", expr->op); + seqassert(false, "invalid unary operator '{}'", expr->getOp()); resultExpr = - transform(N(N(clone(expr->expr), format("__{}__", magic)))); + transform(N(N(expr->getExpr(), format("__{}__", magic)))); } } @@ -54,31 +58,81 @@ void TypecheckVisitor::visit(UnaryExpr *expr) { /// @c transformBinaryInplaceMagic for details. /// Also evaluate static expressions. See @c evaluateStaticBinary for details. void TypecheckVisitor::visit(BinaryExpr *expr) { - // Transform lexpr and rexpr. Ignore Nones for now - if (!(startswith(expr->op, "is") && expr->lexpr->getNone())) - transform(expr->lexpr); - if (!(startswith(expr->op, "is") && expr->rexpr->getNone())) - transform(expr->rexpr); - - static std::unordered_map> - staticOps = {{StaticValue::INT, - {"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", - "%", "&", "|", "^", ">>", "<<"}}, - {StaticValue::STRING, {"==", "!=", "+"}}}; - if (expr->lexpr->isStatic() && expr->rexpr->isStatic() && - expr->lexpr->staticValue.type == expr->rexpr->staticValue.type && - in(staticOps[expr->rexpr->staticValue.type], expr->op)) { - // Handle static expressions - resultExpr = evaluateStaticBinary(expr); - } else if (auto e = transformBinarySimple(expr)) { + expr->lexpr = transform(expr->getLhs(), true); + + // Static short-circuit + if (expr->getLhs()->getType()->isStaticType() && expr->op == "&&") { + if (auto tb = expr->getLhs()->getType()->getBoolStatic()) { + if (!tb->value) { + resultExpr = transform(N(false)); + return; + } + } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) { + if (ts->value.empty()) { + resultExpr = transform(N(false)); + return; + } + } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) { + if (!ti->value) { + resultExpr = transform(N(false)); + return; + } + } else { + expr->getType()->getUnbound()->isStatic = 3; + return; + } + } else if (expr->getLhs()->getType()->isStaticType() && expr->op == "||") { + if (auto tb = expr->getLhs()->getType()->getBoolStatic()) { + if (tb->value) { + resultExpr = transform(N(true)); + return; + } + } else if (auto ts = expr->getLhs()->getType()->getStrStatic()) { + if (!ts->value.empty()) { + resultExpr = transform(N(true)); + return; + } + } else if (auto ti = expr->getLhs()->getType()->getIntStatic()) { + if (ti->value) { + resultExpr = transform(N(true)); + return; + } + } else { + expr->getType()->getUnbound()->isStatic = 3; + return; + } + } + + expr->rexpr = transform(expr->getRhs(), true); + + static std::unordered_map> staticOps = { + {1, + {"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//", "%", "&", + "|", "^", ">>", "<<"}}, + {2, {"==", "!=", "+"}}, + {3, {"<", "<=", ">", ">=", "==", "!=", "&&", "||"}}}; + if (expr->getLhs()->getType()->isStaticType() && + expr->getRhs()->getType()->isStaticType()) { + auto l = expr->getLhs()->getType()->isStaticType(); + auto r = expr->getRhs()->getType()->isStaticType(); + bool isStatic = l == r && in(staticOps[l], expr->getOp()); + if (!isStatic && ((l == 1 && r == 3) || (r == 1 && l == 3)) && + in(staticOps[1], expr->getOp())) + isStatic = true; + if (isStatic) { + resultExpr = evaluateStaticBinary(expr); + return; + } + } + + if (auto e = transformBinarySimple(expr)) { // Case: simple binary expressions resultExpr = e; - } else if (expr->lexpr->getType()->getUnbound() || - (expr->op != "is" && expr->rexpr->getType()->getUnbound())) { + } else if (expr->getLhs()->getType()->getUnbound() || + (expr->getOp() != "is" && expr->getRhs()->getType()->getUnbound())) { // Case: types are unknown, so continue later - unify(expr->type, ctx->getUnbound()); return; - } else if (expr->op == "is") { + } else if (expr->getOp() == "is") { // Case: is operator resultExpr = transformBinaryIs(expr); } else { @@ -88,40 +142,68 @@ void TypecheckVisitor::visit(BinaryExpr *expr) { } else if (auto em = transformBinaryMagic(expr)) { // Case: normal magic methods resultExpr = em; - } else if (expr->lexpr->getType()->is(TYPE_OPTIONAL)) { + } else if (expr->getLhs()->getType()->is(TYPE_OPTIONAL)) { // Special case: handle optionals if everything else fails. // Assumes that optionals have no relevant magics (except for __eq__) resultExpr = - transform(N(N(N(FN_UNWRAP), expr->lexpr), - expr->op, expr->rexpr, expr->inPlace)); + transform(N(N(N(FN_UNWRAP), expr->getLhs()), + expr->getOp(), expr->getRhs(), expr->isInPlace())); } else { // Nothing found: report an error - E(Error::OP_NO_MAGIC, expr, expr->op, expr->lexpr->type->prettyString(), - expr->rexpr->type->prettyString()); + E(Error::OP_NO_MAGIC, expr, expr->getOp(), + expr->getLhs()->getType()->prettyString(), + expr->getRhs()->getType()->prettyString()); } } } +/// Transform chain binary expression. +/// @example +/// `a <= b <= c` -> `(a <= (chain := b)) and (chain <= c)` +/// The assignment above ensures that all expressions are executed only once. +void TypecheckVisitor::visit(ChainBinaryExpr *expr) { + seqassert(expr->exprs.size() >= 2, "not enough expressions in ChainBinaryExpr"); + std::vector items; + std::string prev; + for (int i = 1; i < expr->exprs.size(); i++) { + auto l = prev.empty() ? clone(expr->exprs[i - 1].second) : N(prev); + prev = ctx->generateCanonicalName("chain"); + auto r = + (i + 1 == expr->exprs.size()) + ? clone(expr->exprs[i].second) + : N(N(N(prev), clone(expr->exprs[i].second)), + N(prev)); + items.emplace_back(N(l, expr->exprs[i].first, r)); + } + + Expr *final = items.back(); + for (auto i = items.size() - 1; i-- > 0;) + final = N(items[i], "&&", final); + resultExpr = transform(final); +} + /// Helper function that locates the pipe ellipsis within a collection of (possibly /// nested) CallExprs. /// @return List of CallExprs and their locations within the parent CallExpr /// needed to access the ellipsis. /// @example `foo(bar(1, baz(...)))` returns `[{0, baz}, {1, bar}, {0, foo}]` -std::vector> findEllipsis(ExprPtr expr) { - auto call = expr->getCall(); +std::vector> TypecheckVisitor::findEllipsis(Expr *expr) { + auto call = cast(expr); if (!call) return {}; - for (size_t ai = 0; ai < call->args.size(); ai++) { - if (auto el = call->args[ai].value->getEllipsis()) { - if (el->mode == EllipsisExpr::PIPE) + size_t ai = 0; + for (auto &a : *call) { + if (auto el = cast(a)) { + if (el->isPipe()) return {{ai, expr}}; - } else if (call->args[ai].value->getCall()) { - auto v = findEllipsis(call->args[ai].value); + } else if (cast(a)) { + auto v = findEllipsis(a); if (!v.empty()) { v.emplace_back(ai, expr); return v; } } + ai++; } return {}; } @@ -137,10 +219,10 @@ void TypecheckVisitor::visit(PipeExpr *expr) { bool hasGenerator = false; // Return T if t is of type `Generator[T]`; otherwise just `type(t)` - auto getIterableType = [&](TypePtr t) { + auto getIterableType = [&](Type *t) { if (t->is("Generator")) { hasGenerator = true; - return t->getClass()->generics[0].type; + return extractClassGeneric(t); } return t; }; @@ -151,55 +233,57 @@ void TypecheckVisitor::visit(PipeExpr *expr) { expr->inTypes.clear(); // Process the pipeline head - auto inType = transform(expr->items[0].expr)->type; // input type to the next stage - expr->inTypes.push_back(inType); + expr->front().expr = transform(expr->front().expr); + auto inType = expr->front().expr->getType(); // input type to the next stage + expr->inTypes.push_back(inType->shared_from_this()); inType = getIterableType(inType); - auto done = expr->items[0].expr->isDone(); - for (size_t pi = 1; pi < expr->items.size(); pi++) { - int inTypePos = -1; // ellipsis position - ExprPtr *ec = &(expr->items[pi].expr); // a pointer so that we can replace it - while (auto se = (*ec)->getStmtExpr()) // handle StmtExpr (e.g., in partial calls) + auto done = expr->front().expr->isDone(); + for (size_t pi = 1; pi < expr->size(); pi++) { + int inTypePos = -1; // ellipsis position + Expr **ec = &((*expr)[pi].expr); // a pointer so that we can replace it + while (auto se = cast(*ec)) // handle StmtExpr (e.g., in partial calls) ec = &(se->expr); - if (auto call = (*ec)->getCall()) { + if (auto call = cast(*ec)) { // Case: a call. Find the position of the pipe ellipsis within it - for (size_t ia = 0; inTypePos == -1 && ia < call->args.size(); ia++) - if (call->args[ia].value->getEllipsis()) { + for (size_t ia = 0; inTypePos == -1 && ia < call->size(); ia++) + if (cast((*call)[ia].value)) inTypePos = int(ia); - } // No ellipses found? Prepend it as the first argument if (inTypePos == -1) { - call->args.insert(call->args.begin(), - {"", N(EllipsisExpr::PARTIAL)}); + call->items.insert(call->items.begin(), + {"", N(EllipsisExpr::PARTIAL)}); inTypePos = 0; } } else { // Case: not a call. Convert it to a call with a single ellipsis - expr->items[pi].expr = - N(expr->items[pi].expr, N(EllipsisExpr::PARTIAL)); - ec = &expr->items[pi].expr; + (*expr)[pi].expr = + N((*expr)[pi].expr, N(EllipsisExpr::PARTIAL)); + ec = &(*expr)[pi].expr; inTypePos = 0; } // Set the ellipsis type - auto el = (*ec)->getCall()->args[inTypePos].value->getEllipsis(); + auto el = cast((*cast(*ec))[inTypePos].value); el->mode = EllipsisExpr::PIPE; // Don't unify unbound inType yet (it might become a generator that needs to be // extracted) + if (!el->getType()) + el->setType(instantiateUnbound()); if (inType && !inType->getUnbound()) - unify(el->type, inType); + unify(el->getType(), inType); // Transform the call. Because a transformation might wrap the ellipsis in layers, // make sure to extract these layers and move them to the pipeline. // Example: `foo(...)` that is transformed to `foo(unwrap(...))` will become // `unwrap(...) |> foo(...)` - transform(*ec); + *ec = transform(*ec); auto layers = findEllipsis(*ec); seqassert(!layers.empty(), "can't find the ellipsis"); if (layers.size() > 1) { // Prepend layers for (auto &[pos, prepend] : layers) { - prepend->getCall()->args[pos].value = N(EllipsisExpr::PIPE); + (*cast(prepend))[pos].value = N(EllipsisExpr::PIPE); expr->items.insert(expr->items.begin() + pi++, {"|>", prepend}); } // Rewind the loop (yes, the current expression will get transformed again) @@ -209,19 +293,19 @@ void TypecheckVisitor::visit(PipeExpr *expr) { continue; } - if ((*ec)->type) - unify(expr->items[pi].expr->type, (*ec)->type); - expr->items[pi].expr = *ec; - inType = expr->items[pi].expr->getType(); + if ((*ec)->getType()) + unify((*expr)[pi].expr->getType(), (*ec)->getType()); + (*expr)[pi].expr = *ec; + inType = (*expr)[pi].expr->getType(); if (!realize(inType)) done = false; - expr->inTypes.push_back(inType); + expr->inTypes.push_back(inType->shared_from_this()); // Do not extract the generator in the last stage of a pipeline if (pi + 1 < expr->items.size()) inType = getIterableType(inType); } - unify(expr->type, (hasGenerator ? ctx->getType("NoneType") : inType)); + unify(expr->getType(), (hasGenerator ? getStdLibType("NoneType") : inType)); if (done) expr->setDone(); } @@ -233,36 +317,62 @@ void TypecheckVisitor::visit(PipeExpr *expr) { /// `foo[idx]` -> `foo.__getitem__(idx)` /// expr.itemN or a sub-tuple if index is static (see transformStaticTupleIndex()), void TypecheckVisitor::visit(IndexExpr *expr) { - // Handle `Static[T]` constructs - if (expr->expr->isId("Static")) { - auto typ = ctx->getUnbound(); + std::string staticType; + if (match(expr, + M(M("Static"), M(MOr("int", "str", "bool"))))) { + // Special case: static types. + auto typ = instantiateUnbound(); typ->isStatic = getStaticGeneric(expr); - unify(expr->type, typ); + unify(expr->getType(), typ); expr->setDone(); return; + } else if (match(expr->expr, M("Static"))) { + E(Error::BAD_STATIC_TYPE, expr->getIndex()); + } + if (match(expr->expr, M("tuple"))) + cast(expr->expr)->setValue(TYPE_TUPLE); + expr->expr = transform(expr->expr, true); + + // IndexExpr[i1, ..., iN] is internally represented as + // IndexExpr[TupleExpr[i1, ..., iN]] for N > 1 + std::vector items; + bool isTuple = false; + if (auto t = cast(expr->getIndex())) { + items = t->items; + isTuple = true; + } else { + items.push_back(expr->getIndex()); + } + for (auto &i : items) { + if (cast(i) && isTypeExpr(expr->getExpr())) { + // Special case: `A[[A, B], C]` -> `A[Tuple[A, B], C]` (e.g., in + // `Function[...]`) + i = N(N(TYPE_TUPLE), cast(i)->items); + } + i = transform(i, true); + } + if (isTypeExpr(expr->getExpr())) { + resultExpr = transform(N(expr->getExpr(), items)); + return; } - transform(expr->expr); - seqassert(!expr->expr->isType(), "index not converted to instantiate"); - auto cls = expr->expr->getType()->getClass(); + expr->index = (!isTuple && items.size() == 1) ? items[0] : N(items); + auto cls = expr->getExpr()->getClassType(); if (!cls) { // Wait until the type becomes known - unify(expr->type, ctx->getUnbound()); return; } // Case: static tuple access - auto [isTuple, tupleExpr] = transformStaticTupleIndex(cls, expr->expr, expr->index); - if (isTuple) { - if (!tupleExpr) { - unify(expr->type, ctx->getUnbound()); - } else { + auto [isStaticTuple, tupleExpr] = + transformStaticTupleIndex(cls, expr->getExpr(), expr->getIndex()); + if (isStaticTuple) { + if (tupleExpr) resultExpr = tupleExpr; - } } else { // Case: normal __getitem__ - resultExpr = - transform(N(N(expr->expr, "__getitem__"), expr->index)); + resultExpr = transform( + N(N(expr->getExpr(), "__getitem__"), expr->getIndex())); } } @@ -270,89 +380,89 @@ void TypecheckVisitor::visit(IndexExpr *expr) { /// @example /// Instantiate(foo, [bar]) -> Id("foo[bar]") void TypecheckVisitor::visit(InstantiateExpr *expr) { - transformType(expr->typeExpr); - - std::shared_ptr repeats = nullptr; - if (expr->typeExpr->isId(TYPE_TUPLE) && !expr->typeParams.empty()) { - transform(expr->typeParams[0]); - if (expr->typeParams[0]->staticValue.type == StaticValue::INT) { - repeats = Type::makeStatic(ctx->cache, expr->typeParams[0]); - } - } + expr->expr = transformType(expr->getExpr()); TypePtr typ = nullptr; - size_t typeParamsSize = expr->typeParams.size() - (repeats != nullptr); - if (expr->typeExpr->isId(TYPE_TUPLE)) { - typ = ctx->instantiateTuple(typeParamsSize); + size_t typeParamsSize = expr->size(); + if (extractType(expr->expr)->is(TYPE_TUPLE)) { + if (!expr->empty()) { + expr->items.front() = transform(expr->front()); + if (expr->front()->getType()->isStaticType() == 1) { + auto et = N( + N("Tuple"), + std::vector(expr->items.begin() + 1, expr->items.end())); + resultExpr = transform(N(N("__NTuple__"), + std::vector{(*expr)[0], et})); + return; + } + } + auto t = generateTuple(typeParamsSize); + typ = instantiateType(t); } else { - typ = ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType()); + typ = instantiateType(expr->getExpr()->getSrcInfo(), extractType(expr->getExpr())); } - seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr); + seqassert(typ->getClass(), "unknown type: {}", *(expr->getExpr())); auto &generics = typ->getClass()->generics; bool isUnion = typ->getUnion() != nullptr; if (!isUnion && typeParamsSize != generics.size()) - E(Error::GENERICS_MISMATCH, expr, ctx->cache->rev(typ->getClass()->name), + E(Error::GENERICS_MISMATCH, expr, getUnmangledName(typ->getClass()->name), generics.size(), typeParamsSize); - if (expr->typeExpr->isId(TYPE_CALLABLE)) { + if (isId(expr->getExpr(), TYPE_CALLABLE)) { // Case: Callable[...] trait instantiation - std::vector types; // Callable error checking. - for (auto &typeParam : expr->typeParams) { - transformType(typeParam); - if (typeParam->type->isStaticType()) + std::vector types; + for (auto &typeParam : *expr) { + typeParam = transformType(typeParam); + if (typeParam->getType()->isStaticType()) E(Error::INST_CALLABLE_STATIC, typeParam); - types.push_back(typeParam->type); + types.push_back(extractType(typeParam)->shared_from_this()); } - auto typ = ctx->getUnbound(); + auto typ = instantiateUnbound(); // Set up the Callable trait typ->getLink()->trait = std::make_shared(ctx->cache, types); - unify(expr->type, typ); - } else if (expr->typeExpr->isId(TYPE_TYPEVAR)) { + unify(expr->getType(), instantiateTypeVar(typ.get())); + } else if (isId(expr->getExpr(), TYPE_TYPEVAR)) { // Case: TypeVar[...] trait instantiation - transformType(expr->typeParams[0]); - auto typ = ctx->getUnbound(); - typ->getLink()->trait = std::make_shared(expr->typeParams[0]->type); - unify(expr->type, typ); + (*expr)[0] = transformType((*expr)[0]); + auto typ = instantiateUnbound(); + typ->getLink()->trait = + std::make_shared(extractType(expr->front())->shared_from_this()); + unify(expr->getType(), typ); } else { - for (size_t i = (repeats != nullptr); i < expr->typeParams.size(); i++) { - transform(expr->typeParams[i]); - TypePtr t = nullptr; - if (expr->typeParams[i]->isStatic()) { - t = Type::makeStatic(ctx->cache, expr->typeParams[i]); + for (size_t i = 0; i < expr->size(); i++) { + (*expr)[i] = transformType((*expr)[i]); + auto t = instantiateType((*expr)[i]->getSrcInfo(), extractType((*expr)[i])); + if (isUnion || (*expr)[i]->getType()->isStaticType() != + generics[i].getType()->isStaticType()) { + if (cast((*expr)[i])) // `None` -> `NoneType` + (*expr)[i] = transformType((*expr)[i]); + if (!isTypeExpr((*expr)[i])) + E(Error::EXPECTED_TYPE, (*expr)[i], "type"); + } + if (isUnion) { + if (!typ->getUnion()->addType(t.get())) + E(error::Error::UNION_TOO_BIG, (*expr)[i], + typ->getUnion()->pendingTypes.size()); } else { - if (expr->typeParams[i]->getNone()) // `None` -> `NoneType` - transformType(expr->typeParams[i]); - if (expr->typeParams[i]->type->getClass() && !expr->typeParams[i]->isType()) - E(Error::EXPECTED_TYPE, expr->typeParams[i], "type"); - t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(), - expr->typeParams[i]->getType()); + unify(t.get(), generics[i].getType()); } - if (isUnion) - typ->getUnion()->addType(t); - else - unify(t, generics[i - (repeats != nullptr)].type); - } - if (repeats) { - typ->getRecord()->repeats = repeats; } if (isUnion) { typ->getUnion()->seal(); } - unify(expr->type, typ); - } - expr->markType(); - // If the type is realizable, use the realized name instead of instantiation - // (e.g. use Id("Ptr[byte]") instead of Instantiate(Ptr, {byte})) - if (realize(expr->type)) { - resultExpr = N(expr->type->realizedName()); - resultExpr->setType(expr->type); - resultExpr->setDone(); - if (expr->typeExpr->isType()) - resultExpr->markType(); + unify(expr->getType(), instantiateTypeVar(typ.get())); + // If the type is realizable, use the realized name instead of instantiation + // (e.g. use Id("Ptr[byte]") instead of Instantiate(Ptr, {byte})) + if (realize(expr->getType())) { + auto t = extractType(expr); + resultExpr = N(t->realizedName()); + resultExpr->setType(expr->getType()->shared_from_this()); + resultExpr->setDone(); + } } } @@ -360,55 +470,68 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) { /// @example /// `start::step` -> `Slice(start, Optional.__new__(), step)` void TypecheckVisitor::visit(SliceExpr *expr) { - ExprPtr none = N(N(TYPE_OPTIONAL, "__new__")); - resultExpr = transform(N( - N(TYPE_SLICE), expr->start ? expr->start : clone(none), - expr->stop ? expr->stop : clone(none), expr->step ? expr->step : clone(none))); + Expr *none = N(N(N(TYPE_OPTIONAL), "__new__")); + resultExpr = transform(N(N(getStdLibType("Slice")->name), + expr->getStart() ? expr->getStart() : clone(none), + expr->getStop() ? expr->getStop() : clone(none), + expr->getStep() ? expr->getStep() : clone(none))); } /// Evaluate a static unary expression and return the resulting static expression. /// If the expression cannot be evaluated yet, return nullptr. /// Supported operators: (strings) not (ints) not, -, + -ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { +Expr *TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { // Case: static strings - if (expr->expr->staticValue.type == StaticValue::STRING) { - if (expr->op == "!") { - if (expr->expr->staticValue.evaluated) { - bool value = expr->expr->staticValue.getString().empty(); + if (expr->getExpr()->getType()->isStaticType() == 2) { + if (expr->getOp() == "!") { + if (expr->getExpr()->getType()->canRealize()) { + bool value = getStrLiteral(expr->getExpr()->getType()).empty(); LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); - return transform(N(value)); + return transform(N(value)); } else { // Cannot be evaluated yet: just set the type - unify(expr->type, ctx->getType("bool")); - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::INT; + expr->getType()->getUnbound()->isStatic = 1; + } + } + return nullptr; + } + + // Case: static bools + if (expr->getExpr()->getType()->isStaticType() == 3) { + if (expr->getOp() == "!") { + if (expr->getExpr()->getType()->canRealize()) { + bool value = getBoolLiteral(expr->getExpr()->getType()); + LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); + return transform(N(!value)); + } else { + // Cannot be evaluated yet: just set the type + expr->getType()->getUnbound()->isStatic = 3; } } return nullptr; } // Case: static integers - if (expr->op == "-" || expr->op == "+" || expr->op == "!" || expr->op == "~") { - if (expr->expr->staticValue.evaluated) { - int64_t value = expr->expr->staticValue.getInt(); - if (expr->op == "+") + if (expr->getOp() == "-" || expr->getOp() == "+" || expr->getOp() == "!" || + expr->getOp() == "~") { + if (expr->getExpr()->getType()->canRealize()) { + int64_t value = getIntLiteral(expr->getExpr()->getType()); + if (expr->getOp() == "+") ; - else if (expr->op == "-") + else if (expr->getOp() == "-") value = -value; - else if (expr->op == "~") + else if (expr->getOp() == "~") value = ~value; else value = !bool(value); LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); - if (expr->op == "!") - return transform(N(bool(value))); + if (expr->getOp() == "!") + return transform(N(value)); else return transform(N(value)); } else { // Cannot be evaluated yet: just set the type - unify(expr->type, ctx->getType("int")); - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::INT; + expr->getType()->getUnbound()->isStatic = expr->getOp() == "!" ? 3 : 1; } } @@ -416,13 +539,15 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { } /// Division and modulus implementations. -std::pair divMod(const std::shared_ptr &ctx, int a, int b) { - if (!b) +std::pair divMod(const std::shared_ptr &ctx, int64_t a, + int64_t b) { + if (!b) { E(Error::STATIC_DIV_ZERO, ctx->getSrcInfo()); - if (ctx->cache->pythonCompat) { + return {0, 0}; + } else if (ctx->cache->pythonCompat) { // Use Python implementation. - int d = a / b; - int m = a - d * b; + int64_t d = a / b; + int64_t m = a - d * b; if (m && ((b ^ m) < 0)) { m += b; d -= 1; @@ -438,93 +563,98 @@ std::pair divMod(const std::shared_ptr &ctx, int a, int b /// If the expression cannot be evaluated yet, return nullptr. /// Supported operators: (strings) +, ==, != /// (ints) <, <=, >, >=, ==, !=, and, or, +, -, *, //, %, ^, |, & -ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { +Expr *TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { // Case: static strings - if (expr->rexpr->staticValue.type == StaticValue::STRING) { - if (expr->op == "+") { + if (expr->getRhs()->getType()->isStaticType() == 2) { + if (expr->getOp() == "+") { // `"a" + "b"` -> `"ab"` - if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) { - auto value = - expr->lexpr->staticValue.getString() + expr->rexpr->staticValue.getString(); + if (expr->getLhs()->getType()->getStrStatic() && + expr->getRhs()->getType()->getStrStatic()) { + auto value = getStrLiteral(expr->getLhs()->getType()) + + getStrLiteral(expr->getRhs()->getType()); LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); return transform(N(value)); } else { // Cannot be evaluated yet: just set the type - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::STRING; - unify(expr->type, ctx->getType("str")); + expr->getType()->getUnbound()->isStatic = 2; } } else { // `"a" == "b"` -> `False` (also handles `!=`) - if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) { - bool eq = expr->lexpr->staticValue.getString() == - expr->rexpr->staticValue.getString(); - bool value = expr->op == "==" ? eq : !eq; + if (expr->getLhs()->getType()->getStrStatic() && + expr->getRhs()->getType()->getStrStatic()) { + bool eq = getStrLiteral(expr->getLhs()->getType()) == + getStrLiteral(expr->getRhs()->getType()); + bool value = expr->getOp() == "==" ? eq : !eq; LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); return transform(N(value)); } else { // Cannot be evaluated yet: just set the type - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::INT; - unify(expr->type, ctx->getType("bool")); + expr->getType()->getUnbound()->isStatic = 3; } } return nullptr; } // Case: static integers - if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) { - int64_t lvalue = expr->lexpr->staticValue.getInt(); - int64_t rvalue = expr->rexpr->staticValue.getInt(); - if (expr->op == "<") + if (expr->getLhs()->getType()->getStatic() && + expr->getRhs()->getType()->getStatic()) { + int64_t lvalue = expr->getLhs()->getType()->getIntStatic() + ? getIntLiteral(expr->getLhs()->getType()) + : getBoolLiteral(expr->getLhs()->getType()); + int64_t rvalue = expr->getRhs()->getType()->getIntStatic() + ? getIntLiteral(expr->getRhs()->getType()) + : getBoolLiteral(expr->getRhs()->getType()); + if (expr->getOp() == "<") lvalue = lvalue < rvalue; - else if (expr->op == "<=") + else if (expr->getOp() == "<=") lvalue = lvalue <= rvalue; - else if (expr->op == ">") + else if (expr->getOp() == ">") lvalue = lvalue > rvalue; - else if (expr->op == ">=") + else if (expr->getOp() == ">=") lvalue = lvalue >= rvalue; - else if (expr->op == "==") + else if (expr->getOp() == "==") lvalue = lvalue == rvalue; - else if (expr->op == "!=") + else if (expr->getOp() == "!=") lvalue = lvalue != rvalue; - else if (expr->op == "&&") + else if (expr->getOp() == "&&") lvalue = lvalue && rvalue; - else if (expr->op == "||") + else if (expr->getOp() == "||") lvalue = lvalue || rvalue; - else if (expr->op == "+") + else if (expr->getOp() == "+") lvalue = lvalue + rvalue; - else if (expr->op == "-") + else if (expr->getOp() == "-") lvalue = lvalue - rvalue; - else if (expr->op == "*") + else if (expr->getOp() == "*") lvalue = lvalue * rvalue; - else if (expr->op == "^") + else if (expr->getOp() == "^") lvalue = lvalue ^ rvalue; - else if (expr->op == "&") + else if (expr->getOp() == "&") lvalue = lvalue & rvalue; - else if (expr->op == "|") + else if (expr->getOp() == "|") lvalue = lvalue | rvalue; - else if (expr->op == ">>") + else if (expr->getOp() == ">>") lvalue = lvalue >> rvalue; - else if (expr->op == "<<") + else if (expr->getOp() == "<<") lvalue = lvalue << rvalue; - else if (expr->op == "//") + else if (expr->getOp() == "//") lvalue = divMod(ctx, lvalue, rvalue).first; - else if (expr->op == "%") + else if (expr->getOp() == "%") lvalue = divMod(ctx, lvalue, rvalue).second; else - seqassert(false, "unknown static operator {}", expr->op); + seqassert(false, "unknown static operator {}", expr->getOp()); LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), lvalue); if (in(std::set{"==", "!=", "<", "<=", ">", ">=", "&&", "||"}, - expr->op)) - return transform(N(bool(lvalue))); + expr->getOp())) + return transform(N(lvalue)); else return transform(N(lvalue)); } else { // Cannot be evaluated yet: just set the type - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::INT; - unify(expr->type, ctx->getType("int")); + if (in(std::set{"==", "!=", "<", "<=", ">", ">=", "&&", "||"}, + expr->getOp())) + expr->getType()->getUnbound()->isStatic = 3; + else + expr->getType()->getUnbound()->isStatic = 1; } return nullptr; @@ -537,96 +667,101 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { /// `a in b` -> `a.__contains__(b)` /// `a not in b` -> `not (a in b)` /// `a is not b` -> `not (a is b)` -ExprPtr TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) { +Expr *TypecheckVisitor::transformBinarySimple(BinaryExpr *expr) { // Case: simple transformations - if (expr->op == "&&") { - return transform(N(expr->lexpr, - N(N(expr->rexpr, "__bool__")), + if (expr->getOp() == "&&") { + return transform(N(expr->getLhs(), + N(N(expr->getRhs(), "__bool__")), N(false))); - } else if (expr->op == "||") { - return transform(N(expr->lexpr, N(true), - N(N(expr->rexpr, "__bool__")))); - } else if (expr->op == "not in") { - return transform(N( - N(N(N(expr->rexpr, "__contains__"), expr->lexpr), - "__invert__"))); - } else if (expr->op == "in") { - return transform(N(N(expr->rexpr, "__contains__"), expr->lexpr)); - } else if (expr->op == "is") { - if (expr->lexpr->getNone() && expr->rexpr->getNone()) + } else if (expr->getOp() == "||") { + return transform(N(expr->getLhs(), N(true), + N(N(expr->getRhs(), "__bool__")))); + } else if (expr->getOp() == "not in") { + return transform(N(N( + N(N(expr->getRhs(), "__contains__"), expr->getLhs()), + "__invert__"))); + } else if (expr->getOp() == "in") { + return transform( + N(N(expr->getRhs(), "__contains__"), expr->getLhs())); + } else if (expr->getOp() == "is") { + if (cast(expr->getLhs()) && cast(expr->getRhs())) return transform(N(true)); - else if (expr->lexpr->getNone()) - return transform(N(expr->rexpr, "is", expr->lexpr)); - } else if (expr->op == "is not") { - return transform(N("!", N(expr->lexpr, "is", expr->rexpr))); + else if (cast(expr->getLhs())) + return transform(N(expr->getRhs(), "is", expr->getLhs())); + } else if (expr->getOp() == "is not") { + return transform( + N("!", N(expr->getLhs(), "is", expr->getRhs()))); } return nullptr; } /// Transform a binary `is` expression by checking for type equality. Handle special `is /// None` cаses as well. See inside for details. -ExprPtr TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { +Expr *TypecheckVisitor::transformBinaryIs(BinaryExpr *expr) { seqassert(expr->op == "is", "not an is binary expression"); // Case: `is None` expressions - if (expr->rexpr->getNone()) { - if (expr->lexpr->getType()->is("NoneType")) + if (cast(expr->getRhs())) { + if (extractClassType(expr->getLhs())->is("NoneType")) return transform(N(true)); - if (!expr->lexpr->getType()->is(TYPE_OPTIONAL)) { + if (!extractClassType(expr->getLhs())->is(TYPE_OPTIONAL)) { // lhs is not optional: `return False` return transform(N(false)); } else { // Special case: Optional[Optional[... Optional[NoneType]]...] == NoneType - auto g = expr->lexpr->getType()->getClass(); - for (; g->generics[0].type->is("Optional"); g = g->generics[0].type->getClass()) + auto g = extractClassType(expr->getLhs()); + for (; extractClassGeneric(g)->is("Optional"); + g = extractClassGeneric(g)->getClass()) ; - if (!g->generics[0].type->getClass()) { - if (!expr->isStatic()) - expr->staticValue.type = StaticValue::INT; - unify(expr->type, ctx->getType("bool")); + if (!extractClassGeneric(g)->getClass()) { + auto typ = instantiateUnbound(); + typ->isStatic = 3; + unify(expr->getType(), typ); return nullptr; } - if (g->generics[0].type->is("NoneType")) + if (extractClassGeneric(g)->is("NoneType")) return transform(N(true)); // lhs is optional: `return lhs.__has__().__invert__()` - return transform(N( - N(N(N(expr->lexpr, "__has__")), "__invert__"))); + if (expr->getType()->getUnbound() && expr->getType()->isStaticType()) + expr->getType()->getUnbound()->isStatic = 0; + return transform(N(N( + N(N(expr->getLhs(), "__has__")), "__invert__"))); } } // Check the type equality (operand types and __raw__ pointers must match). - auto lc = realize(expr->lexpr->getType()); - auto rc = realize(expr->rexpr->getType()); + auto lc = realize(expr->getLhs()->getType()); + auto rc = realize(expr->getRhs()->getType()); if (!lc || !rc) { // Types not known: return early - unify(expr->type, ctx->getType("bool")); + unify(expr->getType(), getStdLibType("bool")); return nullptr; } - if (expr->lexpr->isType() && expr->rexpr->isType()) + if (isTypeExpr(expr->getLhs()) && isTypeExpr(expr->getRhs())) return transform(N(lc->realizedName() == rc->realizedName())); - if (!lc->getRecord() && !rc->getRecord()) { + if (!lc->getClass()->isRecord() && !rc->getClass()->isRecord()) { // Both reference types: `return lhs.__raw__() == rhs.__raw__()` return transform( - N(N(N(expr->lexpr, "__raw__")), - "==", N(N(expr->rexpr, "__raw__")))); + N(N(N(expr->getLhs(), "__raw__")), + "==", N(N(expr->getRhs(), "__raw__")))); } - if (lc->getClass()->is(TYPE_OPTIONAL)) { + if (lc->is(TYPE_OPTIONAL)) { // lhs is optional: `return lhs.__is_optional__(rhs)` return transform( - N(N(expr->lexpr, "__is_optional__"), expr->rexpr)); + N(N(expr->getLhs(), "__is_optional__"), expr->getRhs())); } - if (rc->getClass()->is(TYPE_OPTIONAL)) { + if (rc->is(TYPE_OPTIONAL)) { // rhs is optional: `return rhs.__is_optional__(lhs)` return transform( - N(N(expr->rexpr, "__is_optional__"), expr->lexpr)); + N(N(expr->getRhs(), "__is_optional__"), expr->getLhs())); } if (lc->realizedName() != rc->realizedName()) { // tuple names do not match: `return False` return transform(N(false)); } // Same tuple types: `return lhs == rhs` - return transform(N(expr->lexpr, "==", expr->rexpr)); + return transform(N(expr->getLhs(), "==", expr->getRhs())); } /// Return a binary magic opcode for the provided operator. @@ -654,76 +789,79 @@ std::pair TypecheckVisitor::getMagic(const std::string /// @example /// `a op= b` -> `a.__iopmagic__(b)` /// @param isAtomic if set, use atomic magics if available. -ExprPtr TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isAtomic) { - auto [magic, _] = getMagic(expr->op); - auto lt = expr->lexpr->getType()->getClass(); - auto rt = expr->rexpr->getType()->getClass(); - seqassert(lt && rt, "lhs and rhs types not known"); +Expr *TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isAtomic) { + auto [magic, _] = getMagic(expr->getOp()); + auto lt = expr->getLhs()->getClassType(); + seqassert(lt, "lhs type not known"); - FuncTypePtr method = nullptr; + FuncType *method = nullptr; // Atomic operations: check if `lhs.__atomic_op__(Ptr[lhs], rhs)` exists if (isAtomic) { - auto ptr = ctx->instantiateGeneric(ctx->getType("Ptr"), {lt}); - if ((method = findBestMethod(lt, format("__atomic_{}__", magic), {ptr, rt}))) { - expr->lexpr = N(N("__ptr__"), expr->lexpr); + auto ptr = instantiateType(getStdLibType("Ptr"), std::vector{lt}); + if ((method = findBestMethod(lt, format("__atomic_{}__", magic), + {ptr.get(), expr->getRhs()->getType()}))) { + expr->lexpr = N(N("__ptr__"), expr->getLhs()); } } // In-place operations: check if `lhs.__iop__(lhs, rhs)` exists - if (!method && expr->inPlace) { - method = findBestMethod(lt, format("__i{}__", magic), {expr->lexpr, expr->rexpr}); + if (!method && expr->isInPlace()) { + method = findBestMethod(lt, format("__i{}__", magic), + std::vector{expr->getLhs(), expr->getRhs()}); } if (method) return transform( - N(N(method->ast->name), expr->lexpr, expr->rexpr)); + N(N(method->getFuncName()), expr->getLhs(), expr->getRhs())); return nullptr; } /// Transform a magic binary expression. /// @example /// `a op b` -> `a.__opmagic__(b)` -ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) { - auto [magic, rightMagic] = getMagic(expr->op); - auto lt = expr->lexpr->getType()->getClass(); - auto rt = expr->rexpr->getType()->getClass(); - seqassert(lt && rt, "lhs and rhs types not known"); +Expr *TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) { + auto [magic, rightMagic] = getMagic(expr->getOp()); + auto lt = expr->getLhs()->getType(); + auto rt = expr->getRhs()->getType(); if (!lt->is("pyobj") && rt->is("pyobj")) { // Special case: `obj op pyobj` -> `rhs.__rmagic__(lhs)` on lhs // Assumes that pyobj implements all left and right magics - auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r"); + auto l = getTemporaryVar("l"); + auto r = getTemporaryVar("r"); return transform( - N(N(N(l), expr->lexpr), - N(N(r), expr->rexpr), + N(N(N(l), expr->getLhs()), + N(N(r), expr->getRhs()), N(N(N(r), format("__{}__", rightMagic)), N(l)))); } if (lt->getUnion()) { // Special case: `union op obj` -> `union.__magic__(rhs)` - return transform( - N(N(expr->lexpr, format("__{}__", magic)), expr->rexpr)); + return transform(N(N(expr->getLhs(), format("__{}__", magic)), + expr->getRhs())); } // Normal operations: check if `lhs.__magic__(lhs, rhs)` exists if (auto method = - findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr})) { + findBestMethod(lt->getClass(), format("__{}__", magic), + std::vector{expr->getLhs(), expr->getRhs()})) { // Normal case: `__magic__(lhs, rhs)` return transform( - N(N(method->ast->name), expr->lexpr, expr->rexpr)); + N(N(method->getFuncName()), expr->getLhs(), expr->getRhs())); } // Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists - if (auto method = findBestMethod(rt, format("__{}__", rightMagic), - {expr->rexpr, expr->lexpr})) { - auto l = ctx->cache->getTemporaryVar("l"), r = ctx->cache->getTemporaryVar("r"); + if (auto method = + findBestMethod(rt->getClass(), format("__{}__", rightMagic), + std::vector{expr->getRhs(), expr->getLhs()})) { + auto l = getTemporaryVar("l"); + auto r = getTemporaryVar("r"); return transform(N( - N(N(l), expr->lexpr), - N(N(r), expr->rexpr), - N(N(method->ast->name), N(r), N(l)))); + N(N(l), expr->getLhs()), + N(N(r), expr->getRhs()), + N(N(method->getFuncName()), N(r), N(l)))); } - // 145 return nullptr; } @@ -732,21 +870,17 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) { /// (integer or slice). If so, statically extract the specified tuple item or a /// sub-tuple (if the index is a slice). /// Works only on normal tuples and partial functions. -std::pair -TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, - const ExprPtr &expr, const ExprPtr &index) { - bool isStaticString = - expr->isStatic() && expr->staticValue.type == StaticValue::STRING; - - if (isStaticString && !expr->staticValue.evaluated) { +std::pair +TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, Expr *expr, Expr *index) { + bool isStaticString = expr->getType()->isStaticType() == 2; + if (isStaticString && !expr->getType()->canRealize()) { return {true, nullptr}; } else if (!isStaticString) { - if (!tuple->getRecord()) + if (!tuple->isRecord()) return {false, nullptr}; - if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) && - !startswith(tuple->name, TYPE_PARTIAL)) { + if (!tuple->is(TYPE_TUPLE)) { if (tuple->is(TYPE_OPTIONAL)) { - if (auto newTuple = tuple->generics[0].type->getClass()) { + if (auto newTuple = extractClassGeneric(tuple)->getClass()) { return transformStaticTupleIndex( newTuple, transform(N(N(FN_UNWRAP), expr)), index); } else { @@ -758,23 +892,19 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, } // Extract the static integer value from expression - auto getInt = [&](int64_t *o, const ExprPtr &e) { + auto getInt = [&](int64_t *o, Expr *e) { if (!e) return true; - auto f = transform(clone(e)); - if (f->staticValue.type == StaticValue::INT) { - seqassert(f->staticValue.evaluated, "{} not evaluated", e); - *o = f->staticValue.getInt(); - return true; - } else if (auto ei = f->getInt()) { - *o = *(ei->intValue); + auto f = transform(clean_clone(e)); + if (auto s = f->getType()->getIntStatic()) { + *o = s->value; return true; } return false; }; - auto sz = int64_t(isStaticString ? expr->staticValue.getString().size() - : tuple->getRecord()->args.size()); + std::string str = isStaticString ? getStrLiteral(expr->getType()) : ""; + auto sz = int64_t(isStaticString ? str.size() : getClassFields(tuple).size()); int64_t start = 0, stop = sz, step = 1, multiple = 0; if (getInt(&start, index)) { // Case: `tuple[int]` @@ -782,16 +912,16 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, if (i < 0 || i >= stop) E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i); start = i; - } else if (auto slice = CAST(index, SliceExpr)) { + } else if (auto slice = cast(index->getOrigExpr())) { // Case: `tuple[int:int:int]` - if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) || - !getInt(&step, slice->step)) + if (!getInt(&start, slice->getStart()) || !getInt(&stop, slice->getStop()) || + !getInt(&step, slice->getStep())) return {false, nullptr}; // Adjust slice indices (Python slicing rules) - if (slice->step && !slice->start) + if (slice->getStep() && !slice->getStart()) start = step > 0 ? 0 : (sz - 1); - if (slice->step && !slice->stop) + if (slice->getStep() && !slice->getStop()) stop = step > 0 ? sz : -(sz + 1); sliceAdjustIndices(sz, &start, &stop, step); multiple = 1; @@ -800,7 +930,6 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, } if (isStaticString) { - auto str = expr->staticValue.getString(); if (!multiple) { return {true, transform(N(str.substr(start, 1)))}; } else { @@ -810,23 +939,22 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, return {true, transform(N(newStr))}; } } else { - auto classFields = getClassFields(tuple.get()); + auto classFields = getClassFields(tuple); if (!multiple) { return {true, transform(N(expr, classFields[start].name))}; } else { // Generate a sub-tuple - auto var = N(ctx->cache->getTemporaryVar("tup")); + auto var = N(getTemporaryVar("tup")); auto ass = N(var, expr); - std::vector te; + std::vector te; for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) { if (i < 0 || i >= sz) E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i); te.push_back(N(clone(var), classFields[i].name)); } - auto s = ctx->generateTuple(te.size()); - ExprPtr e = - transform(N(std::vector{ass}, - N(N(N(s), "__new__"), te))); + auto s = generateTuple(te.size()); + Expr *e = transform(N(std::vector{ass}, + N(N(TYPE_TUPLE), te))); return {true, e}; } } diff --git a/codon/parser/visitors/typecheck/special.cpp b/codon/parser/visitors/typecheck/special.cpp new file mode 100644 index 00000000..d59f3f7d --- /dev/null +++ b/codon/parser/visitors/typecheck/special.cpp @@ -0,0 +1,1222 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#include +#include +#include +#include +#include +#include + +#include "codon/cir/attribute.h" +#include "codon/cir/types/types.h" +#include "codon/parser/ast.h" +#include "codon/parser/common.h" +#include "codon/parser/visitors/scoping/scoping.h" +#include "codon/parser/visitors/typecheck/typecheck.h" + +using fmt::format; +using namespace codon::error; + +namespace codon::ast { + +using namespace types; + +/// Generate ASTs for all __internal__ functions that deal with vtable generation. +/// Intended to be called once the typechecking is done. +/// TODO: add JIT compatibility. + +void TypecheckVisitor::prepareVTables() { + auto fn = getFunction("__internal__.class_populate_vtables:0"); + fn->ast->suite = generateClassPopulateVTablesAST(); + auto typ = fn->realizations.begin()->second->getType(); + LOG_REALIZE("[poly] {} : {}", typ->debugString(2), fn->ast->suite->toString(2)); + typ->ast = fn->ast; + realizeFunc(typ, true); + + // def class_base_derived_dist(B, D): + // return Tuple[].__elemsize__ + fn = getFunction("__internal__.class_base_derived_dist:0"); + auto oldAst = fn->ast; + for (const auto &[_, real] : fn->realizations) { + fn->ast->suite = generateBaseDerivedDistAST(real->getType()); + LOG_REALIZE("[poly] {} : {}", real->type->debugString(2), *fn->ast->suite); + real->type->ast = fn->ast; + realizeFunc(real->type.get(), true); + } + fn->ast = oldAst; +} + +SuiteStmt *TypecheckVisitor::generateClassPopulateVTablesAST() { + auto suite = N(); + for (const auto &[_, cls] : ctx->cache->classes) { + for (const auto &[r, real] : cls.realizations) { + size_t vtSz = 0; + for (auto &[base, vtable] : real->vtables) { + if (!vtable.ir) + vtSz += vtable.table.size(); + } + if (!vtSz) + continue; + // __internal__.class_set_rtti_vtable(real.ID, size, real.type) + suite->addStmt(N( + N(N(N("__internal__"), "class_set_rtti_vtable"), + N(real->id), N(vtSz + 2), N(r)))); + // LOG("[poly] {} -> {}", r, real->id); + vtSz = 0; + for (const auto &[base, vtable] : real->vtables) { + if (!vtable.ir) { + for (const auto &[k, v] : vtable.table) { + auto &[fn, id] = v; + std::vector ids; + for (auto t : *fn) + ids.push_back(N(t.getType()->realizedName())); + // p[real.ID].__setitem__(f.ID, Function[](f).__raw__()) + LOG_REALIZE("[poly] vtable[{}][{}] = {}", real->id, vtSz + id, + fn->debugString(2)); + Expr *fnCall = N( + N( + N("Function"), + std::vector{N(N(TYPE_TUPLE), ids), + N(fn->getRetType()->realizedName())}), + N(fn->realizedName())); + suite->addStmt(N(N( + N(N("__internal__"), "class_set_rtti_vtable_fn"), + N(real->id), N(vtSz + id), + N(N(fnCall, "__raw__")), N(r)))); + } + vtSz += vtable.table.size(); + } + } + } + } + return suite; +} + +SuiteStmt *TypecheckVisitor::generateBaseDerivedDistAST(FuncType *f) { + auto baseTyp = extractFuncGeneric(f, 0)->getClass(); + size_t baseTypFields = 0; + for (auto &f : getClassFields(baseTyp)) { + if (f.baseClass == baseTyp->name) { + baseTypFields++; + } + } + + auto derivedTyp = extractFuncGeneric(f, 1)->getClass(); + auto fields = getClassFields(derivedTyp); + auto types = std::vector{}; + auto found = false; + for (auto &f : fields) { + if (f.baseClass == baseTyp->name) { + found = true; + break; + } else { + auto ft = realize(instantiateType(f.getType(), derivedTyp)); + types.push_back(N(ft->realizedName())); + } + } + seqassert(found || !baseTypFields, "cannot find distance between {} and {}", + derivedTyp->name, baseTyp->name); + Stmt *suite = N( + N(N(N(TYPE_TUPLE), types), "__elemsize__")); + return SuiteStmt::wrap(suite); +} + +FunctionStmt *TypecheckVisitor::generateThunkAST(FuncType *fp, ClassType *base, + ClassType *derived) { + auto ct = instantiateType(extractClassType(derived->name), base->getClass()); + std::vector args; + for (const auto &a : *fp) + args.push_back(a.getType()); + args[0] = ct.get(); + auto m = findBestMethod(ct->getClass(), getUnmangledName(fp->getFuncName()), args); + if (!m) { + // Print a nice error message + std::vector a; + for (auto &t : args) + a.emplace_back(fmt::format("{}", t->prettyString())); + std::string argsNice = fmt::format("({})", fmt::join(a, ", ")); + E(Error::DOT_NO_ATTR_ARGS, getSrcInfo(), ct->prettyString(), + getUnmangledName(fp->getFuncName()), argsNice); + } + + std::vector ns; + for (auto &a : args) + ns.push_back(a->realizedName()); + auto thunkName = + format("_thunk.{}.{}.{}", base->name, fp->getFuncName(), fmt::join(ns, ".")); + if (getFunction(thunkName + ":0")) + return nullptr; + + // Thunk contents: + // def _thunk...(self, ): + // return ( + // __internal__.class_base_to_derived(self, , ), + // ) + std::vector fnArgs; + fnArgs.emplace_back("self", N(base->realizedName()), nullptr); + for (size_t i = 1; i < args.size(); i++) + fnArgs.emplace_back(getUnmangledName((*fp->ast)[i].getName()), + N(args[i]->realizedName()), nullptr); + std::vector callArgs; + callArgs.emplace_back(N( + N(N("__internal__"), "class_base_to_derived"), N("self"), + N(base->realizedName()), N(derived->realizedName()))); + for (size_t i = 1; i < args.size(); i++) + callArgs.emplace_back(N(getUnmangledName((*fp->ast)[i].getName()))); + auto thunkAst = N( + thunkName, nullptr, fnArgs, + N(N(N(N(m->ast->name), callArgs)))); + thunkAst->setAttribute(Attr::Inline); + return cast(transform(thunkAst)); +} + +/// Generate thunks in all derived classes for a given virtual function (must be fully +/// realizable) and the corresponding base class. +/// @return unique thunk ID. +size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType *fp) { + seqassert(cp->canRealize() && fp->canRealize() && fp->getRetType()->canRealize(), + "{} not realized", fp->debugString(1)); + + // TODO: ugly, ugly; surely needs refactoring + + // Function signature for storing thunks + auto sig = [](types::FuncType *fp) { + std::vector gs; + for (auto a : *fp) + gs.emplace_back(a.getType()->realizedName()); + gs.emplace_back("|"); + for (auto &a : fp->funcGenerics) + if (!a.name.empty()) + gs.push_back(a.type->realizedName()); + return join(gs, ","); + }; + + // Set up the base class information + auto baseCls = cp->name; + auto fnName = getUnmangledName(fp->getFuncName()); + auto key = make_pair(fnName, sig(fp)); + auto &vt = getClassRealization(cp)->vtables[cp->realizedName()]; + + // Add or extract thunk ID + size_t vid = 0; + if (auto i = in(vt.table, key)) { + vid = i->second; + } else { + vid = vt.table.size() + 1; + vt.table[key] = {std::static_pointer_cast(fp->shared_from_this()), vid}; + } + + // Iterate through all derived classes and instantiate the corresponding thunk + for (const auto &[clsName, cls] : ctx->cache->classes) { + bool inMro = false; + for (auto &m : cls.mro) + if (m && m->is(baseCls)) { + inMro = true; + break; + } + if (clsName != baseCls && inMro) { + for (const auto &[_, real] : cls.realizations) { + if (auto thunkAst = generateThunkAST(fp, cp, real->getType())) { + auto thunkFn = getFunction(thunkAst->name); + auto ti = + std::static_pointer_cast(instantiateType(thunkFn->getType())); + auto tm = realizeFunc(ti.get(), true); + seqassert(tm, "bad thunk {}", thunkFn->type->debugString(2)); + real->vtables[baseCls].table[key] = { + std::static_pointer_cast(tm->shared_from_this()), vid}; + } + } + } + } + return vid; +} + +SuiteStmt *TypecheckVisitor::generateFunctionCallInternalAST(FuncType *type) { + // Special case: Function.__call_internal__ + /// TODO: move to IR one day + std::vector items; + items.push_back(nullptr); + std::vector ll; + std::vector lla; + seqassert(extractFuncArgType(type, 1)->is(TYPE_TUPLE), "bad function base: {}", + extractFuncArgType(type, 1)->debugString(2)); + auto as = extractFuncArgType(type, 1)->getClass()->generics.size(); + auto [_, ag] = (*type->ast)[1].getNameWithStars(); + for (int i = 0; i < as; i++) { + ll.push_back(format("%{} = extractvalue {{}} %args, {}", i, i)); + items.push_back(N(N(ag))); + } + items.push_back(N(N("TR"))); + for (int i = 0; i < as; i++) { + items.push_back(N(N(N(ag), N(i)))); + lla.push_back(format("{{}} %{}", i)); + } + items.push_back(N(N("TR"))); + ll.push_back(format("%{} = call {{}} %self({})", as, combine2(lla))); + ll.push_back(format("ret {{}} %{}", as)); + items[0] = N(N(combine2(ll, "\n"))); + return N(items); +} + +SuiteStmt *TypecheckVisitor::generateUnionNewAST(FuncType *type) { + auto unionType = type->funcParent->getUnion(); + seqassert(unionType, "expected union, got {}", *(type->funcParent)); + + Stmt *suite = N(N( + N(N("__internal__"), "new_union"), + N(type->ast->begin()->name), N(unionType->realizedName()))); + return SuiteStmt::wrap(suite); +} + +SuiteStmt *TypecheckVisitor::generateUnionTagAST(FuncType *type) { + // return __internal__.union_get_data(union, T0) + auto tag = getIntLiteral(extractFuncGeneric(type)); + auto unionType = extractFuncArgType(type)->getUnion(); + auto unionTypes = unionType->getRealizationTypes(); + if (tag < 0 || tag >= unionTypes.size()) + E(Error::CUSTOM, getSrcInfo(), "bad union tag"); + auto selfVar = type->ast->begin()->name; + auto suite = N(N( + N(N("__internal__.union_get_data:0"), N(selfVar), + N(unionTypes[tag]->realizedName())))); + return suite; +} + +SuiteStmt *TypecheckVisitor::generateNamedKeysAST(FuncType *type) { + auto n = getIntLiteral(extractFuncGeneric(type)); + if (n < 0 || n >= ctx->cache->generatedTupleNames.size()) + E(Error::CUSTOM, getSrcInfo(), "bad namedkeys index"); + std::vector s; + for (auto &k : ctx->cache->generatedTupleNames[n]) + s.push_back(N(k)); + auto suite = N(N(N(s))); + return suite; +} + +SuiteStmt *TypecheckVisitor::generateTupleMulAST(FuncType *type) { + auto n = std::max(int64_t(0), getIntLiteral(extractFuncGeneric(type))); + auto t = extractFuncArgType(type)->getClass(); + if (!t || !t->is(TYPE_TUPLE)) + return nullptr; + std::vector exprs; + for (size_t i = 0; i < n; i++) + for (size_t j = 0; j < t->generics.size(); j++) + exprs.push_back( + N(N(type->ast->front().getName()), N(j))); + auto suite = N(N(N(exprs))); + return suite; +} + +/// Generate ASTs for dynamically generated functions. +SuiteStmt *TypecheckVisitor::generateSpecialAst(types::FuncType *type) { + // Clone the generic AST that is to be realized + auto ast = type->ast; + if (ast->hasAttribute(Attr::AutoGenerated) && endswith(ast->name, ".__iter__") && + extractFuncArgType(type, 0)->getHeterogenousTuple()) { + // Special case: do not realize auto-generated heterogenous __iter__ + E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); + } else if (ast->hasAttribute(Attr::AutoGenerated) && + endswith(ast->name, ".__getitem__") && + extractFuncArgType(type, 0)->getHeterogenousTuple()) { + // Special case: do not realize auto-generated heterogenous __getitem__ + E(Error::EXPECTED_TYPE, getSrcInfo(), "iterable"); + } else if (startswith(ast->name, "Function.__call_internal__")) { + return generateFunctionCallInternalAST(type); + } else if (startswith(ast->name, "Union.__new__")) { + return generateUnionNewAST(type); + } else if (startswith(ast->name, "__internal__.get_union_tag:0")) { + return generateUnionTagAST(type); + } else if (startswith(ast->name, "__internal__.namedkeys")) { + return generateNamedKeysAST(type); + } else if (startswith(ast->name, "__magic__.mul:0")) { + return generateTupleMulAST(type); + } + return nullptr; +} + +/// Transform named tuples. +/// @example +/// `namedtuple("NT", ["a", ("b", int)])` -> ```@tuple +/// class NT[T1]: +/// a: T1 +/// b: int``` +Expr *TypecheckVisitor::transformNamedTuple(CallExpr *expr) { + // Ensure that namedtuple call is valid + auto name = getStrLiteral(extractFuncGeneric(expr->getExpr()->getType())); + if (expr->size() != 1) + E(Error::CALL_NAMEDTUPLE, expr); + + // Construct the class statement + std::vector generics, params; + auto orig = cast(expr->front().getExpr()->getOrigExpr()); + size_t ti = 1; + for (auto *i : *orig) { + if (auto s = cast(i)) { + generics.emplace_back(format("T{}", ti), N(TYPE_TYPE), nullptr, true); + params.emplace_back(s->getValue(), N(format("T{}", ti++)), nullptr); + continue; + } + auto t = cast(i); + if (t && t->size() == 2 && cast((*t)[0])) { + params.emplace_back(cast((*t)[0])->getValue(), transformType((*t)[1]), + nullptr); + continue; + } + E(Error::CALL_NAMEDTUPLE, i); + } + for (auto &g : generics) + params.push_back(g); + auto cls = N( + N(name, params, nullptr, std::vector{N("tuple")})); + if (auto err = ast::ScopingVisitor::apply(ctx->cache, cls)) + throw exc::ParserException(std::move(err)); + prependStmts->push_back(transform(cls)); + return transformType(N(name)); +} + +/// Transform partial calls (Python syntax). +/// @example +/// `partial(foo, 1, a=2)` -> `foo(1, a=2, ...)` +Expr *TypecheckVisitor::transformFunctoolsPartial(CallExpr *expr) { + if (expr->empty()) + E(Error::CALL_PARTIAL, getSrcInfo()); + std::vector args(expr->items.begin() + 1, expr->items.end()); + args.emplace_back("", N(EllipsisExpr::PARTIAL)); + return transform(N(expr->begin()->value, args)); +} + +/// Typecheck superf method. This method provides the access to the previous matching +/// overload. +/// @example +/// ```class cls: +/// def foo(): print('foo 1') +/// def foo(): +/// superf() # access the previous foo +/// print('foo 2') +/// cls.foo()``` +/// prints "foo 1" followed by "foo 2" +Expr *TypecheckVisitor::transformSuperF(CallExpr *expr) { + auto func = ctx->getBase()->type->getFunc(); + + // Find list of matching superf methods + std::vector supers; + if (!isDispatch(func)) { + if (auto a = func->ast->getAttribute(Attr::ParentClass)) { + auto c = getClass(a->value); + if (auto m = in(c->methods, getUnmangledName(func->getFuncName()))) { + for (auto &overload : getOverloads(*m)) { + if (isDispatch(overload)) + continue; + if (overload == func->getFuncName()) + break; + supers.emplace_back(getFunction(overload)->getType()); + } + } + std::reverse(supers.begin(), supers.end()); + } + } + if (supers.empty()) + E(Error::CALL_SUPERF, expr); + + seqassert(expr->size() == 1 && cast(expr->begin()->getExpr()), + "bad superf call"); + std::vector newArgs; + for (const auto &a : *cast(expr->begin()->getExpr())) + newArgs.emplace_back(a.getExpr()); + auto m = findMatchingMethods( + func->funcParent ? func->funcParent->getClass() : nullptr, supers, newArgs); + if (m.empty()) + E(Error::CALL_SUPERF, expr); + auto c = transform(N(N(m[0]->getFuncName()), newArgs)); + return c; +} + +/// Typecheck and transform super method. Replace it with the current self object cast +/// to the first inherited type. +/// TODO: only an empty super() is currently supported. +Expr *TypecheckVisitor::transformSuper() { + if (!ctx->getBase()->type) + E(Error::CALL_SUPER_PARENT, getSrcInfo()); + auto funcTyp = ctx->getBase()->type->getFunc(); + if (!funcTyp || !funcTyp->ast->hasAttribute(Attr::Method)) + E(Error::CALL_SUPER_PARENT, getSrcInfo()); + if (funcTyp->empty()) + E(Error::CALL_SUPER_PARENT, getSrcInfo()); + + ClassType *typ = extractFuncArgType(funcTyp)->getClass(); + auto cls = getClass(typ); + auto cands = cls->staticParentClasses; + if (cands.empty()) { + // Dynamic inheritance: use MRO + // TODO: maybe super() should be split into two separate functions... + const auto &vCands = cls->mro; + if (vCands.size() < 2) + E(Error::CALL_SUPER_PARENT, getSrcInfo()); + + auto superTyp = instantiateType(vCands[1].get(), typ); + auto self = N(funcTyp->ast->begin()->name); + self->setType(typ->shared_from_this()); + + auto typExpr = N(superTyp->getClass()->name); + typExpr->setType(instantiateTypeVar(superTyp->getClass())); + // LOG("-> {:c} : {:c} {:c}", typ, vCands[1], typExpr->type); + return transform(N(N(N("__internal__"), "class_super"), + self, typExpr, N(1))); + } + + auto name = cands.front(); // the first inherited type + auto superTyp = instantiateType(extractClassType(name), typ); + if (typ->isRecord()) { + // Case: tuple types. Return `tuple(obj.args...)` + std::vector members; + for (auto &field : getClassFields(superTyp->getClass())) + members.push_back( + N(N(funcTyp->ast->begin()->getName()), field.name)); + Expr *e = transform(N(members)); + auto ft = getClassFieldTypes(superTyp->getClass()); + for (size_t i = 0; i < ft.size(); i++) + unify(ft[i].get(), extractClassGeneric(e->getType(), i)); // see super_tuple test + e->setType(superTyp->shared_from_this()); + return e; + } else { + // Case: reference types. Return `__internal__.class_super(self, T)` + auto self = N(funcTyp->ast->begin()->name); + self->setType(typ->shared_from_this()); + return castToSuperClass(self, superTyp->getClass()); + } +} + +/// Typecheck __ptr__ method. This method creates a pointer to an object. Ensure that +/// the argument is a variable binding. +Expr *TypecheckVisitor::transformPtr(CallExpr *expr) { + auto id = cast(expr->begin()->getExpr()); + if (!id) { + // Case where id is guarded by a check + if (auto sexp = cast(expr->begin()->getExpr())) + id = cast(sexp->getExpr()); + } + auto val = id ? ctx->find(id->getValue(), getTime()) : nullptr; + if (!val || !val->isVar()) { + E(Error::CALL_PTR_VAR, expr->begin()->getExpr()); + } + + expr->begin()->value = transform(expr->begin()->getExpr()); + unify(expr->getType(), + instantiateType(getStdLibType("Ptr"), {expr->begin()->getExpr()->getType()})); + if (expr->begin()->getExpr()->isDone()) + expr->setDone(); + return nullptr; +} + +/// Typecheck __array__ method. This method creates a stack-allocated array via alloca. +Expr *TypecheckVisitor::transformArray(CallExpr *expr) { + auto arrTyp = expr->expr->getType()->getFunc(); + unify(expr->getType(), + instantiateType(getStdLibType("Array"), + {extractClassGeneric(arrTyp->getParentType())})); + if (realize(expr->getType())) + expr->setDone(); + return nullptr; +} + +/// Transform isinstance method to a static boolean expression. +/// Special cases: +/// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type +/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type +Expr *TypecheckVisitor::transformIsInstance(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + expr->begin()->value = transform(expr->begin()->getExpr()); + auto typ = expr->begin()->getExpr()->getClassType(); + if (!typ || !typ->canRealize()) + return nullptr; + + expr->begin()->value = transform(expr->begin()->getExpr()); // again to realize it + + typ = extractClassType(typ); + auto &typExpr = (*expr)[1].value; + if (auto c = cast(typExpr)) { + // Handle `isinstance(obj, (type1, type2, ...))` + if (typExpr->getOrigExpr() && cast(typExpr->getOrigExpr())) { + Expr *result = transform(N(false)); + for (auto *i : *cast(typExpr->getOrigExpr())) { + result = transform(N( + result, "||", + N(N("isinstance"), expr->begin()->getExpr(), i))); + } + return result; + } + } + + auto tei = cast(typExpr); + if (tei && tei->getValue() == "type") { + return transform(N(isTypeExpr(expr->begin()->value))); + } else if (tei && tei->getValue() == "type[Tuple]") { + return transform(N(typ->is(TYPE_TUPLE))); + } else if (tei && tei->getValue() == "type[ByVal]") { + return transform(N(typ->isRecord())); + } else if (tei && tei->getValue() == "type[ByRef]") { + return transform(N(!typ->isRecord())); + } else if (tei && tei->getValue() == "type[Union]") { + return transform(N(typ->getUnion() != nullptr)); + } else if (!extractType(typExpr)->getUnion() && typ->getUnion()) { + auto unionTypes = typ->getUnion()->getRealizationTypes(); + int tag = -1; + for (size_t ui = 0; ui < unionTypes.size(); ui++) { + if (extractType(typExpr)->unify(unionTypes[ui], nullptr) >= 0) { + tag = int(ui); + break; + } + } + if (tag == -1) + return transform(N(false)); + return transform(N( + N(N(N("__internal__"), "union_get_tag"), + expr->begin()->getExpr()), + "==", N(tag))); + } else if (typExpr->getType()->is("pyobj")) { + if (typ->is("pyobj")) { + return transform(N(N("std.internal.python._isinstance.0"), + expr->begin()->getExpr(), (*expr)[1].getExpr())); + } else { + return transform(N(false)); + } + } + + typExpr = transformType(typExpr); + auto targetType = extractType(typExpr); + // Check super types (i.e., statically inherited) as well + for (auto &tx : getSuperTypes(typ->getClass())) { + types::Type::Unification us; + auto s = tx->unify(targetType, &us); + us.undo(); + if (s >= 0) + return transform(N(true)); + } + return transform(N(false)); +} + +/// Transform staticlen method to a static integer expression. This method supports only +/// static strings and tuple types. +Expr *TypecheckVisitor::transformStaticLen(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 1; + + expr->begin()->value = transform(expr->begin()->getExpr()); + auto typ = extractType(expr->begin()->getExpr()); + + if (auto ss = typ->getStrStatic()) { + // Case: staticlen on static strings + return transform(N(ss->value.size())); + } + if (!typ->getClass()) + return nullptr; + if (typ->getUnion()) { + if (realize(typ)) + return transform(N(typ->getUnion()->getRealizationTypes().size())); + return nullptr; + } + if (!typ->getClass()->isRecord()) + E(Error::EXPECTED_TUPLE, expr->begin()->getExpr()); + return transform(N(getClassFields(typ->getClass()).size())); +} + +/// Transform hasattr method to a static boolean expression. +/// This method also supports additional argument types that are used to check +/// for a matching overload (not available in Python). +Expr *TypecheckVisitor::transformHasAttr(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + auto typ = extractClassType((*expr)[0].getExpr()); + if (!typ) + return nullptr; + + auto member = getStrLiteral(extractFuncGeneric(expr->getExpr()->getType())); + std::vector> args{{"", typ}}; + + if (auto tup = cast((*expr)[1].getExpr())) { + for (auto &a : *tup) { + a.value = transform(a.getExpr()); + if (!a.getExpr()->getClassType()) + return nullptr; + auto t = extractType(a); + args.emplace_back("", t->is("TypeWrap") ? extractClassGeneric(t) : t); + } + } + for (auto &[n, ne] : extractNamedTuple((*expr)[2].getExpr())) { + ne = transform(ne); + auto t = extractType(ne); + args.emplace_back(n, t->is("TypeWrap") ? extractClassGeneric(t) : t); + } + + if (typ->getUnion()) { + Expr *cond = nullptr; + auto unionTypes = typ->getUnion()->getRealizationTypes(); + int tag = -1; + for (size_t ui = 0; ui < unionTypes.size(); ui++) { + auto tu = realize(unionTypes[ui]); + if (!tu) + return nullptr; + auto te = N(tu->getClass()->realizedName()); + auto e = N( + N(N("isinstance"), (*expr)[0].getExpr(), te), "&&", + N(N("hasattr"), te, N(member))); + cond = !cond ? e : N(cond, "||", e); + } + if (!cond) + return transform(N(false)); + return transform(cond); + } + + bool exists = !findMethod(typ->getClass(), member).empty() || + findMember(typ->getClass(), member); + if (exists && args.size() > 1) { + exists &= findBestMethod(typ, member, args) != nullptr; + } + return transform(N(exists)); +} + +/// Transform getattr method to a DotExpr. +Expr *TypecheckVisitor::transformGetAttr(CallExpr *expr) { + auto name = getStrLiteral(extractFuncGeneric(expr->expr->getType())); + + // special handling for NamedTuple + if (expr->begin()->getExpr()->getType() && + expr->begin()->getExpr()->getType()->is("NamedTuple")) { + auto val = expr->begin()->getExpr()->getClassType(); + auto id = getIntLiteral(val); + seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id); + auto names = ctx->cache->generatedTupleNames[id]; + for (size_t i = 0; i < names.size(); i++) + if (names[i] == name) { + return transform( + N(N(expr->begin()->getExpr(), "args"), N(i))); + } + E(Error::DOT_NO_ATTR, expr, val->prettyString(), name); + } + return transform(N(expr->begin()->getExpr(), name)); +} + +/// Transform setattr method to a AssignMemberStmt. +Expr *TypecheckVisitor::transformSetAttr(CallExpr *expr) { + auto attr = getStrLiteral(extractFuncGeneric(expr->expr->getType())); + return transform( + N(N((*expr)[0].getExpr(), attr, (*expr)[1].getExpr()), + N(N("NoneType")))); +} + +/// Raise a compiler error. +Expr *TypecheckVisitor::transformCompileError(CallExpr *expr) { + auto msg = getStrLiteral(extractFuncGeneric(expr->expr->getType())); + E(Error::CUSTOM, expr, msg.c_str()); + return nullptr; +} + +/// Convert a class to a tuple. +Expr *TypecheckVisitor::transformTupleFn(CallExpr *expr) { + for (auto &a : *expr) + a.value = transform(a.getExpr()); + auto cls = extractClassType(expr->begin()->getExpr()->getType()); + if (!cls) + return nullptr; + + // tuple(ClassType) is a tuple type that corresponds to a class + if (isTypeExpr(expr->begin()->getExpr())) { + if (!realize(cls)) + return expr; + + std::vector items; + auto ft = getClassFieldTypes(cls); + for (size_t i = 0; i < ft.size(); i++) { + auto rt = realize(ft[i].get()); + seqassert(rt, "cannot realize '{}' in {}", getClass(cls)->fields[i].name, + cls->debugString(2)); + items.push_back(N(rt->realizedName())); + } + auto e = transform(N(N(TYPE_TUPLE), items)); + return e; + } + + std::vector args; + std::string var = getTemporaryVar("tup"); + for (auto &field : getClassFields(cls)) + args.emplace_back(N(N(var), field.name)); + + return transform(N(N(N(var), expr->begin()->getExpr()), + N(args))); +} + +/// Transform type function to a type IdExpr identifier. +Expr *TypecheckVisitor::transformTypeFn(CallExpr *expr) { + expr->begin()->value = transform(expr->begin()->getExpr()); + unify(expr->getType(), instantiateTypeVar(expr->begin()->getExpr()->getType())); + if (!realize(expr->getType())) + return nullptr; + + auto e = N(expr->getType()->realizedName()); + e->setType(expr->getType()->shared_from_this()); + e->setDone(); + return e; +} + +/// Transform __realized__ function to a fully realized type identifier. +Expr *TypecheckVisitor::transformRealizedFn(CallExpr *expr) { + auto fn = (*expr)[0].getExpr()->getType()->shared_from_this(); + auto pt = (*expr)[0].getExpr()->getType()->getPartial(); + if (!fn->getFunc() && pt && pt->isPartialEmpty()) + fn = instantiateType(pt->getPartialFunc()); + if (!fn->getFunc()) + E(Error::CALL_REALIZED_FN, (*expr)[0].getExpr()); + std::vector args; + if (auto tup = cast((*expr)[1].getExpr())) { + for (auto &a : *tup) { + a.value = transform(a.getExpr()); + if (!a.getExpr()->getClassType()) + return nullptr; + auto t = extractType(a); + args.emplace_back(t->is("TypeWrap") ? extractClassGeneric(t) : t); + } + } + for (size_t i = 0; i < std::min(args.size(), fn->getFunc()->size()); i++) + unify((*fn->getFunc())[i], args[i]); + if (auto f = realize(fn.get())) { + auto e = N(f->getFunc()->realizedName()); + e->setType(f->shared_from_this()); + e->setDone(); + return e; + } + return nullptr; +} + +/// Transform __static_print__ function to a fully realized type identifier. +Expr *TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { + for (auto &a : *cast(expr->begin()->getExpr())) { + fmt::print(stderr, "[static_print] {}: {} ({}){}\n", getSrcInfo(), + a.getExpr()->getType() ? a.getExpr()->getType()->debugString(2) : "-", + a.getExpr()->getType() ? a.getExpr()->getType()->realizedName() : "-", + a.getExpr()->getType()->getStatic() ? " [static]" : ""); + } + return nullptr; +} + +/// Transform __has_rtti__ to a static boolean that indicates RTTI status of a type. +Expr *TypecheckVisitor::transformHasRttiFn(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + auto t = extractFuncGeneric(expr->getExpr()->getType())->getClass(); + if (!t) + return nullptr; + return transform(N(getClass(t)->hasRTTI())); +} + +// Transform internal.static calls +Expr *TypecheckVisitor::transformStaticFnCanCall(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + auto typ = extractClassType((*expr)[0].getExpr()); + if (!typ) + return nullptr; + + auto inargs = unpackTupleTypes((*expr)[1].getExpr()); + auto kwargs = unpackTupleTypes((*expr)[2].getExpr()); + seqassert(inargs && kwargs, "bad call to fn_can_call"); + + std::vector callArgs; + for (auto &[v, t] : *inargs) { + callArgs.emplace_back(v, N()); // dummy expression + callArgs.back().getExpr()->setType(t->shared_from_this()); + } + for (auto &[v, t] : *kwargs) { + callArgs.emplace_back(v, N()); // dummy expression + callArgs.back().getExpr()->setType(t->shared_from_this()); + } + if (auto fn = typ->getFunc()) { + return transform(N(canCall(fn, callArgs) >= 0)); + } else if (auto pt = typ->getPartial()) { + return transform(N(canCall(pt->getPartialFunc(), callArgs, pt) >= 0)); + } else { + compilationWarning("cannot use fn_can_call on non-functions", getSrcInfo().file, + getSrcInfo().line, getSrcInfo().col); + return transform(N(false)); + } +} + +Expr *TypecheckVisitor::transformStaticFnArgHasType(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + auto fn = extractFunction(expr->begin()->getExpr()->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic(); + seqassert(idx, "expected a static integer"); + return transform(N(idx->value >= 0 && idx->value < fn->size() && + (*fn)[idx->value]->canRealize())); +} + +Expr *TypecheckVisitor::transformStaticFnArgGetType(CallExpr *expr) { + auto fn = extractFunction(expr->begin()->getExpr()->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic(); + seqassert(idx, "expected a static integer"); + if (idx->value < 0 || idx->value >= fn->size() || !(*fn)[idx->value]->canRealize()) + E(Error::CUSTOM, getSrcInfo(), "argument does not have type"); + return transform(N((*fn)[idx->value]->realizedName())); +} + +Expr *TypecheckVisitor::transformStaticFnArgs(CallExpr *expr) { + auto fn = extractFunction(expr->begin()->value->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + std::vector v; + v.reserve(fn->ast->size()); + for (const auto &a : *fn->ast) { + auto [_, n] = a.getNameWithStars(); + n = getUnmangledName(n); + v.push_back(N(n)); + } + return transform(N(v)); +} + +Expr *TypecheckVisitor::transformStaticFnHasDefault(CallExpr *expr) { + if (auto u = expr->getType()->getUnbound()) + u->isStatic = 3; + + auto fn = extractFunction(expr->begin()->getExpr()->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic(); + seqassert(idx, "expected a static integer"); + if (idx->value < 0 || idx->value >= fn->ast->size()) + E(Error::CUSTOM, getSrcInfo(), "argument out of bounds"); + return transform(N((*fn->ast)[idx->value].getDefault() != nullptr)); +} + +Expr *TypecheckVisitor::transformStaticFnGetDefault(CallExpr *expr) { + auto fn = extractFunction(expr->begin()->getExpr()->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + auto idx = extractFuncGeneric(expr->getExpr()->getType())->getIntStatic(); + seqassert(idx, "expected a static integer"); + if (idx->value < 0 || idx->value >= fn->ast->size()) + E(Error::CUSTOM, getSrcInfo(), "argument out of bounds"); + return transform((*fn->ast)[idx->value].getDefault()); +} + +Expr *TypecheckVisitor::transformStaticFnWrapCallArgs(CallExpr *expr) { + auto typ = expr->begin()->getExpr()->getClassType(); + if (!typ) + return nullptr; + + auto fn = extractFunction(expr->begin()->getExpr()->getType()); + if (!fn) + E(Error::CUSTOM, getSrcInfo(), "expected a function, got '{}'", + expr->begin()->getExpr()->getType()->prettyString()); + + std::vector callArgs; + if (auto tup = cast((*expr)[1].getExpr()->getOrigExpr())) { + for (auto *a : *tup) { + callArgs.emplace_back("", a); + } + } + if (auto kw = cast((*expr)[1].getExpr()->getOrigExpr())) { + auto kwCls = getClass(expr->getClassType()); + seqassert(kwCls, "cannot find {}", expr->getClassType()->name); + for (size_t i = 0; i < kw->size(); i++) { + callArgs.emplace_back(kwCls->fields[i].name, (*kw)[i].getExpr()); + } + } + auto tempCall = transform(N(N(fn->getFuncName()), callArgs)); + if (!tempCall->isDone()) + return nullptr; + + std::vector tupArgs; + for (auto &a : *cast(tempCall)) + tupArgs.push_back(a.getExpr()); + return transform(N(tupArgs)); +} + +Expr *TypecheckVisitor::transformStaticVars(CallExpr *expr) { + auto withIdx = getBoolLiteral(extractFuncGeneric(expr->getExpr()->getType())); + + types::ClassType *typ = nullptr; + std::vector tupleItems; + auto e = transform(expr->begin()->getExpr()); + if (!(typ = e->getClassType())) + return nullptr; + + size_t idx = 0; + for (auto &f : getClassFields(typ)) { + auto k = N(f.name); + auto v = N(expr->begin()->value, f.name); + if (withIdx) { + auto i = N(idx); + tupleItems.push_back(N(std::vector{i, k, v})); + } else { + tupleItems.push_back(N(std::vector{k, v})); + } + idx++; + } + return transform(N(tupleItems)); +} + +Expr *TypecheckVisitor::transformStaticTupleType(CallExpr *expr) { + auto funcTyp = expr->getExpr()->getType()->getFunc(); + auto t = extractFuncGeneric(funcTyp)->getClass(); + if (!t || !realize(t)) + return nullptr; + auto n = getIntLiteral(extractFuncGeneric(funcTyp, 1)); + types::TypePtr typ = nullptr; + auto f = getClassFields(t); + if (n < 0 || n >= f.size()) + E(Error::CUSTOM, getSrcInfo(), "invalid index"); + auto rt = realize(instantiateType(f[n].getType(), t)); + return transform(N(rt->realizedName())); +} + +std::vector +TypecheckVisitor::populateStaticTupleLoop(Expr *iter, + const std::vector &vars) { + std::vector block; + auto stmt = N(N(vars[0]), nullptr, nullptr); + auto call = cast(cast(iter)->front()); + if (vars.size() != 1) + E(Error::CUSTOM, getSrcInfo(), "expected one item"); + for (auto &a : *call) { + stmt->rhs = transform(clean_clone(a.value)); + if (auto st = stmt->rhs->getType()->getStatic()) { + stmt->type = N(N("Static"), N(st->name)); + } else { + stmt->type = nullptr; + } + block.push_back(clone(stmt)); + } + return block; +} + +std::vector +TypecheckVisitor::populateSimpleStaticRangeLoop(Expr *iter, + const std::vector &vars) { + if (vars.size() != 1) + E(Error::CUSTOM, getSrcInfo(), "expected one item"); + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + auto stmt = N(N(vars[0]), nullptr, nullptr); + std::vector block; + auto ed = getIntLiteral(extractFuncGeneric(fn->getType())); + if (ed > MAX_STATIC_ITER) + E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, ed); + for (int64_t i = 0; i < ed; i++) { + stmt->rhs = N(i); + stmt->type = N(N("Static"), N("int")); + block.push_back(clone(stmt)); + } + return block; +} + +std::vector +TypecheckVisitor::populateStaticRangeLoop(Expr *iter, + const std::vector &vars) { + if (vars.size() != 1) + E(Error::CUSTOM, getSrcInfo(), "expected one item"); + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + auto stmt = N(N(vars[0]), nullptr, nullptr); + std::vector block; + auto st = getIntLiteral(extractFuncGeneric(fn->getType(), 0)); + auto ed = getIntLiteral(extractFuncGeneric(fn->getType(), 1)); + auto step = getIntLiteral(extractFuncGeneric(fn->getType(), 2)); + if (std::abs(st - ed) / std::abs(step) > MAX_STATIC_ITER) + E(Error::STATIC_RANGE_BOUNDS, fn, MAX_STATIC_ITER, + std::abs(st - ed) / std::abs(step)); + for (int64_t i = st; step > 0 ? i < ed : i > ed; i += step) { + stmt->rhs = N(i); + stmt->type = N(N("Static"), N("int")); + block.push_back(clone(stmt)); + } + return block; +} + +std::vector +TypecheckVisitor::populateStaticFnOverloadsLoop(Expr *iter, + const std::vector &vars) { + if (vars.size() != 1) + E(Error::CUSTOM, getSrcInfo(), "expected one item"); + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + auto stmt = N(N(vars[0]), nullptr, nullptr); + std::vector block; + auto typ = extractFuncGeneric(fn->getType(), 0)->getClass(); + seqassert(extractFuncGeneric(fn->getType(), 1)->getStrStatic(), "bad static string"); + auto name = getStrLiteral(extractFuncGeneric(fn->getType(), 1)); + if (auto n = in(getClass(typ)->methods, name)) { + auto mt = getOverloads(*n); + for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { + auto &method = mt[mti]; + auto cfn = getFunction(method); + if (isDispatch(method) || !cfn->type) + continue; + if (typ->getHeterogenousTuple()) { + if (cfn->ast->hasAttribute(Attr::AutoGenerated) && + (endswith(cfn->ast->name, ".__iter__") || + endswith(cfn->ast->name, ".__getitem__"))) { + // ignore __getitem__ and other heterogenuous methods + continue; + } + } + stmt->rhs = N(method); + block.push_back(clone(stmt)); + } + } + return block; +} + +std::vector +TypecheckVisitor::populateStaticEnumerateLoop(Expr *iter, + const std::vector &vars) { + if (vars.size() != 2) + E(Error::CUSTOM, getSrcInfo(), "expected two items"); + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + auto stmt = N(N(vars[0]), nullptr, nullptr); + std::vector block; + auto typ = extractFuncArgType(fn->getType())->getClass(); + if (typ && typ->isRecord()) { + for (size_t i = 0; i < getClassFields(typ).size(); i++) { + auto b = N(std::vector{ + N(N(vars[0]), N(i), + N(N("Static"), N("int"))), + N( + N(vars[1]), + N(clone((*cast(iter))[0].value), N(i)))}); + block.push_back(b); + } + } else { + E(Error::CUSTOM, getSrcInfo(), "staticenumerate needs a tuple"); + } + return block; +} + +std::vector +TypecheckVisitor::populateStaticVarsLoop(Expr *iter, + const std::vector &vars) { + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + bool withIdx = getBoolLiteral(extractFuncGeneric(fn->getType())); + if (!withIdx && vars.size() != 2) + E(Error::CUSTOM, getSrcInfo(), "expected two items"); + else if (withIdx && vars.size() != 3) + E(Error::CUSTOM, getSrcInfo(), "expected three items"); + std::vector block; + auto typ = extractFuncArgType(fn->getType())->getClass(); + size_t idx = 0; + for (auto &f : getClassFields(typ)) { + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(idx), + N(N("Static"), N("int")))); + } + stmts.push_back(N(N(vars[withIdx]), N(f.name), + N(N("Static"), N("str")))); + stmts.push_back( + N(N(vars[withIdx + 1]), + N(clone((*cast(iter))[0].value), f.name))); + auto b = N(stmts); + block.push_back(b); + idx++; + } + return block; +} + +std::vector +TypecheckVisitor::populateStaticVarTypesLoop(Expr *iter, + const std::vector &vars) { + auto fn = + cast(iter) ? cast(cast(iter)->getExpr()) : nullptr; + auto typ = realize(extractFuncGeneric(fn->getType(), 0)->getClass()); + bool withIdx = getBoolLiteral(extractFuncGeneric(fn->getType(), 1)); + if (!withIdx && vars.size() != 1) + E(Error::CUSTOM, getSrcInfo(), "expected one item"); + else if (withIdx && vars.size() != 2) + E(Error::CUSTOM, getSrcInfo(), "expected two items"); + + seqassert(typ, "vars_types expects a realizable type, got '{}' instead", + *(extractFuncGeneric(fn->getType(), 0))); + std::vector block; + if (auto utyp = typ->getUnion()) { + for (size_t i = 0; i < utyp->getRealizationTypes().size(); i++) { + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(i), + N(N("Static"), N("int")))); + } + stmts.push_back( + N(N(vars[1]), + N(utyp->getRealizationTypes()[i]->realizedName()))); + auto b = N(stmts); + block.push_back(b); + } + } else { + size_t idx = 0; + for (auto &f : getClassFields(typ->getClass())) { + auto ta = realize(instantiateType(f.type.get(), typ->getClass())); + seqassert(ta, "cannot realize '{}'", f.type->debugString(1)); + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(idx), + N(N("Static"), N("int")))); + } + stmts.push_back( + N(N(vars[withIdx]), N(ta->realizedName()))); + auto b = N(stmts); + block.push_back(b); + idx++; + } + } + return block; +} + +std::vector TypecheckVisitor::populateStaticHeterogenousTupleLoop( + Expr *iter, const std::vector &vars) { + std::vector block; + std::string tupleVar; + Stmt *preamble = nullptr; + if (!cast(iter)) { + tupleVar = getTemporaryVar("tuple"); + preamble = N(N(tupleVar), iter); + } else { + tupleVar = cast(iter)->getValue(); + } + for (size_t i = 0; i < iter->getClassType()->generics.size(); i++) { + auto s = N(); + if (vars.size() > 1) { + for (size_t j = 0; j < vars.size(); j++) { + s->addStmt( + N(N(vars[j]), + N(N(N(tupleVar), N(i)), + N(j)))); + } + } else { + s->addStmt(N(N(vars[0]), + N(N(tupleVar), N(i)))); + } + block.push_back(s); + } + block.push_back(preamble); + return block; +} + +} // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index ce8fec06..3061f4cb 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -2,15 +2,19 @@ #include "typecheck.h" +#include #include #include #include +#include "codon/cir/pyextension.h" +#include "codon/cir/util/irtools.h" #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/simplify/ctx.h" +#include "codon/parser/match.h" +#include "codon/parser/peg/peg.h" +#include "codon/parser/visitors/scoping/scoping.h" #include "codon/parser/visitors/typecheck/ctx.h" -#include using fmt::format; using namespace codon::error; @@ -19,77 +23,241 @@ namespace codon::ast { using namespace types; -StmtPtr TypecheckVisitor::apply(Cache *cache, const StmtPtr &stmts) { - if (!cache->typeCtx) - cache->typeCtx = std::make_shared(cache); - TypecheckVisitor v(cache->typeCtx); - auto so = clone(stmts); - auto s = v.inferTypes(so, true); - if (!s) { - LOG_REALIZE("[error] {}", so->toString(2)); - v.error("cannot typecheck the program"); +/// Simplify an AST node. Load standard library if needed. +/// @param cache Pointer to the shared cache ( @c Cache ) +/// @param file Filename to be used for error reporting +/// @param barebones Use the bare-bones standard library for faster testing +/// @param defines User-defined static values (typically passed as `codon run -DX=Y`). +/// Each value is passed as a string. +Stmt *TypecheckVisitor::apply( + Cache *cache, Stmt *node, const std::string &file, + const std::unordered_map &defines, + const std::unordered_map &earlyDefines, bool barebones) { + auto preamble = std::make_shared>(); + seqassertn(cache->module, "cache's module is not set"); + + // Load standard library if it has not been loaded + if (!in(cache->imports, STDLIB_IMPORT)) + loadStdLibrary(cache, preamble, earlyDefines, barebones); + + // Set up the context and the cache + auto ctx = std::make_shared(cache, file); + cache->imports[file] = cache->imports[MAIN_IMPORT] = {MAIN_IMPORT, file, ctx}; + ctx->setFilename(file); + ctx->moduleName = {ImportFile::PACKAGE, file, MODULE_MAIN}; + + // Prepare the code + auto tv = TypecheckVisitor(ctx, preamble); + SuiteStmt *suite = tv.N(); + auto &stmts = suite->items; + stmts.push_back(tv.N(".toplevel", std::vector{}, nullptr, + std::vector{tv.N(Attr::Internal)})); + // Load compile-time defines (e.g., codon run -DFOO=1 ...) + for (auto &d : defines) { + stmts.push_back( + tv.N(tv.N(d.first), tv.N(d.second), + tv.N(tv.N("Static"), tv.N("int")))); + } + // Set up __name__ + stmts.push_back( + tv.N(tv.N("__name__"), tv.N(MODULE_MAIN))); + stmts.push_back(node); + + if (auto err = ScopingVisitor::apply(cache, suite)) + throw exc::ParserException(std::move(err)); + auto n = tv.inferTypes(suite, true); + if (!n) { + auto errors = tv.findTypecheckErrors(suite); + throw exc::ParserException(errors); + } + + suite = tv.N(); + suite->items.push_back(tv.N(*preamble)); + + // Add dominated assignment declarations + suite->items.insert(suite->items.end(), ctx->scope.back().stmts.begin(), + ctx->scope.back().stmts.end()); + suite->items.push_back(n); + + if (cast(n)) + tv.prepareVTables(); + + if (!ctx->cache->errors.empty()) + throw exc::ParserException(ctx->cache->errors); + + return suite; +} + +void TypecheckVisitor::loadStdLibrary( + Cache *cache, const std::shared_ptr> &preamble, + const std::unordered_map &earlyDefines, bool barebones) { + // Load the internal.__init__ + auto stdlib = std::make_shared(cache, STDLIB_IMPORT); + auto stdlibPath = + getImportFile(cache->argv0, STDLIB_INTERNAL_MODULE, "", true, cache->module0); + const std::string initFile = "__init__.codon"; + if (!stdlibPath || !endswith(stdlibPath->path, initFile)) + E(Error::COMPILER_NO_STDLIB); + + /// Use __init_test__ for faster testing (e.g., #%% name,barebones) + /// TODO: get rid of it one day... + if (barebones) { + stdlibPath->path = + stdlibPath->path.substr(0, stdlibPath->path.size() - initFile.size()) + + "__init_test__.codon"; + } + stdlib->setFilename(stdlibPath->path); + cache->imports[stdlibPath->path] = + cache->imports[STDLIB_IMPORT] = {STDLIB_IMPORT, stdlibPath->path, stdlib}; + + // Load the standard library + stdlib->isStdlibLoading = true; + stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"}; + stdlib->setFilename(stdlibPath->path); + + // 1. Core definitions + cache->classes[VAR_CLASS_TOPLEVEL] = Cache::Class(); + auto coreOrErr = + parseCode(stdlib->cache, stdlibPath->path, "from internal.core import *"); + if (!coreOrErr) + throw exc::ParserException(coreOrErr.takeError()); + auto *core = *coreOrErr; + if (auto err = ScopingVisitor::apply(stdlib->cache, core)) + throw exc::ParserException(std::move(err)); + auto tv = TypecheckVisitor(stdlib, preamble); + core = tv.inferTypes(core, true); + preamble->push_back(core); + + // 2. Load early compile-time defines (for standard library) + for (auto &d : earlyDefines) { + auto tv = TypecheckVisitor(stdlib, preamble); + auto s = + tv.N(tv.N(d.first), tv.N(d.second), + tv.N(tv.N("Static"), tv.N("int"))); + auto def = tv.transform(s); + preamble->push_back(def); + } + + // 3. Load stdlib + auto stdOrErr = parseFile(stdlib->cache, stdlibPath->path); + if (!stdOrErr) + throw exc::ParserException(stdOrErr.takeError()); + auto std = *stdOrErr; + if (auto err = ScopingVisitor::apply(stdlib->cache, std)) + throw exc::ParserException(std::move(err)); + tv = TypecheckVisitor(stdlib, preamble); + std = tv.inferTypes(std, true); + preamble->push_back(std); + stdlib->isStdlibLoading = false; +} + +/// Simplify an AST node. Assumes that the standard library is loaded. +Stmt *TypecheckVisitor::apply(const std::shared_ptr &ctx, Stmt *node, + const std::string &file) { + auto oldFilename = ctx->getFilename(); + ctx->setFilename(file); + auto preamble = std::make_shared>(); + auto tv = TypecheckVisitor(ctx, preamble); + auto n = tv.inferTypes(node, true); + ctx->setFilename(oldFilename); + if (!n) { + auto errors = tv.findTypecheckErrors(node); + throw exc::ParserException(errors); } - if (s->getSuite()) - v.prepareVTables(); - return s; + if (!ctx->cache->errors.empty()) + throw exc::ParserException(ctx->cache->errors); + + auto suite = ctx->cache->N(*preamble); + suite->addStmt(n); + return suite; } /**************************************************************************************/ TypecheckVisitor::TypecheckVisitor(std::shared_ptr ctx, - const std::shared_ptr> &stmts) - : ctx(std::move(ctx)) { - prependStmts = stmts ? stmts : std::make_shared>(); + const std::shared_ptr> &pre, + const std::shared_ptr> &stmts) + : resultExpr(nullptr), resultStmt(nullptr), ctx(std::move(ctx)) { + preamble = pre ? pre : std::make_shared>(); + prependStmts = stmts ? stmts : std::make_shared>(); } /**************************************************************************************/ +Expr *TypecheckVisitor::transform(Expr *expr) { return transform(expr, true); } + /// Transform an expression node. -ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { +Expr *TypecheckVisitor::transform(Expr *expr, bool allowTypes) { if (!expr) return nullptr; - auto typ = expr->type; - if (!expr->done) { - bool isIntStatic = expr->staticValue.type == StaticValue::INT; - TypecheckVisitor v(ctx, prependStmts); + // auto k = typeid(*expr).name(); + // Cache::CTimer t(ctx->cache, k); + + if (!expr->getType()) + expr->setType(instantiateUnbound()); + + if (!expr->isDone()) { + TypecheckVisitor v(ctx, preamble, prependStmts); v.setSrcInfo(expr->getSrcInfo()); - ctx->pushSrcInfo(expr->getSrcInfo()); + ctx->pushNode(expr); expr->accept(v); - ctx->popSrcInfo(); + ctx->popNode(); if (v.resultExpr) { - v.resultExpr->attributes |= expr->attributes; - v.resultExpr->origExpr = expr; + for (auto it = expr->attributes_begin(); it != expr->attributes_end(); ++it) { + const auto *attr = expr->getAttribute(*it); + if (!v.resultExpr->hasAttribute(*it)) + v.resultExpr->setAttribute(*it, attr->clone()); + } + v.resultExpr->setOrigExpr(expr); expr = v.resultExpr; + if (!expr->getType()) + expr->setType(instantiateUnbound()); } - seqassert(expr->type, "type not set for {}", expr); - if (!(isIntStatic && expr->type->is("bool"))) - unify(typ, expr->type); - if (expr->done) { + if (!allowTypes && expr && isTypeExpr(expr)) + E(Error::UNEXPECTED_TYPE, expr, "type"); + if (expr->isDone()) ctx->changedNodes++; - } } - realize(typ); - LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), expr, expr->isDone() ? "[done]" : ""); + if (expr) { + if (!expr->hasAttribute(Attr::ExprDoNotRealize)) + if (auto p = realize(expr->getType())) + unify(expr->getType(), p); + LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), *(expr), + expr->isDone() ? "[done]" : ""); + } return expr; } /// Transform a type expression node. +/// @param allowTypeOf Set if `type()` expressions are allowed. Usually disallowed in +/// class/function definitions. /// Special case: replace `None` with `NoneType` /// @throw @c ParserException if a node is not a type (use @c transform instead). -ExprPtr TypecheckVisitor::transformType(ExprPtr &expr) { - if (expr && expr->getNone()) { - expr = N(expr->getSrcInfo(), "NoneType"); - expr->markType(); +Expr *TypecheckVisitor::transformType(Expr *expr, bool allowTypeOf) { + auto oldTypeOf = ctx->allowTypeOf; + ctx->allowTypeOf = allowTypeOf; + if (cast(expr)) { + auto ne = N("NoneType"); + ne->setSrcInfo(expr->getSrcInfo()); + expr = ne; } - transform(expr); + expr = transform(expr); + ctx->allowTypeOf = oldTypeOf; if (expr) { - if (!expr->isType() && expr->isStatic()) { - expr->setType(Type::makeStatic(ctx->cache, expr)); - } else if (!expr->isType()) { - E(Error::EXPECTED_TYPE, expr, "type"); + if (expr->getType()->isStaticType()) { + ; + } else if (isTypeExpr(expr)) { + expr->setType(instantiateType(expr->getType())); + } else if (expr->getType()->getUnbound() && + !expr->getType()->getUnbound()->genericName.empty()) { + // generic! + expr->setType(instantiateType(expr->getType())); + } else if (expr->getType()->getUnbound() && expr->getType()->getUnbound()->trait) { + // generic (is type)! + expr->setType(instantiateType(expr->getType())); } else { - expr->setType(ctx->instantiate(expr->getType())); + E(Error::EXPECTED_TYPE, expr, "type"); } } return expr; @@ -100,18 +268,25 @@ void TypecheckVisitor::defaultVisit(Expr *e) { } /// Transform a statement node. -StmtPtr TypecheckVisitor::transform(StmtPtr &stmt) { - if (!stmt || stmt->done) +Stmt *TypecheckVisitor::transform(Stmt *stmt) { + if (!stmt || stmt->isDone()) return stmt; - TypecheckVisitor v(ctx); + TypecheckVisitor v(ctx, preamble); v.setSrcInfo(stmt->getSrcInfo()); - auto oldAge = ctx->age; - stmt->age = ctx->age = std::max(stmt->age, oldAge); - ctx->pushSrcInfo(stmt->getSrcInfo()); + if (!stmt->toString(-1).empty()) + LOG_TYPECHECK("> [{}] [{}:{}] {}", getSrcInfo(), ctx->getBaseName(), + ctx->getBase()->iteration, stmt->toString(-1)); + ctx->pushNode(stmt); + + int64_t time = 0; + if (auto a = stmt->getAttribute(Attr::ExprTime)) + time = a->value; + auto oldTime = ctx->time; + ctx->time = time; stmt->accept(v); - ctx->popSrcInfo(); - ctx->age = oldAge; + ctx->time = oldTime; + ctx->popNode(); if (v.resultStmt) stmt = v.resultStmt; if (!v.prependStmts->empty()) { @@ -119,12 +294,16 @@ StmtPtr TypecheckVisitor::transform(StmtPtr &stmt) { v.prependStmts->push_back(stmt); bool done = true; for (auto &s : *(v.prependStmts)) - done &= s->done; + done &= s->isDone(); stmt = N(*v.prependStmts); - stmt->done = done; + if (done) + stmt->setDone(); } - if (stmt->done) + if (stmt->isDone()) ctx->changedNodes++; + if (!stmt->toString(-1).empty()) + LOG_TYPECHECK("< [{}] [{}:{}] {}", getSrcInfo(), ctx->getBaseName(), + ctx->getBase()->iteration, stmt->toString(-1)); return stmt; } @@ -137,206 +316,204 @@ void TypecheckVisitor::defaultVisit(Stmt *s) { /// Typecheck statement expressions. void TypecheckVisitor::visit(StmtExpr *expr) { auto done = true; - for (auto &s : expr->stmts) { - transform(s); + for (auto &s : *expr) { + s = transform(s); done &= s->isDone(); } - transform(expr->expr); - unify(expr->type, expr->expr->type); - if (done && expr->expr->isDone()) + expr->expr = transform(expr->getExpr()); + unify(expr->getType(), expr->getExpr()->getType()); + if (done && expr->getExpr()->isDone()) expr->setDone(); } /// Typecheck a list of statements. void TypecheckVisitor::visit(SuiteStmt *stmt) { - std::vector stmts; // for filtering out nullptr statements + std::vector stmts; // for filtering out nullptr statements auto done = true; - for (auto &s : stmt->stmts) { + + std::vector prepend; + if (auto b = stmt->getAttribute(Attr::Bindings)) { + for (auto &[n, hasUsed] : b->bindings) { + prepend.push_back(N(N(n), nullptr)); + if (hasUsed) + prepend.push_back(N( + N(fmt::format("{}{}", n, VAR_USED_SUFFIX)), N(false))); + } + stmt->eraseAttribute(Attr::Bindings); + } + if (!prepend.empty()) + stmt->items.insert(stmt->items.begin(), prepend.begin(), prepend.end()); + for (auto *s : *stmt) { if (ctx->returnEarly) { // If returnEarly is set (e.g., in the function) ignore the rest break; } - if (transform(s)) { - stmts.push_back(s); - done &= stmts.back()->isDone(); + if ((s = transform(s))) { + if (!cast(s)) { + done &= s->isDone(); + stmts.push_back(s); + } else { + for (auto *ss : *cast(s)) { + done &= ss->isDone(); + stmts.push_back(ss); + } + } } } - stmt->stmts = stmts; + stmt->items = stmts; if (done) stmt->setDone(); } /// Typecheck expression statements. void TypecheckVisitor::visit(ExprStmt *stmt) { - transform(stmt->expr); - if (stmt->expr->isDone()) + stmt->expr = transform(stmt->getExpr()); + if (stmt->getExpr()->isDone()) stmt->setDone(); } +void TypecheckVisitor::visit(CustomStmt *stmt) { + if (stmt->getSuite()) { + auto fn = in(ctx->cache->customBlockStmts, stmt->getKeyword()); + seqassert(fn, "unknown keyword {}", stmt->getKeyword()); + resultStmt = (*fn).second(this, stmt); + } else { + auto fn = in(ctx->cache->customExprStmts, stmt->getKeyword()); + seqassert(fn, "unknown keyword {}", stmt->getKeyword()); + resultStmt = (*fn)(this, stmt); + } +} + void TypecheckVisitor::visit(CommentStmt *stmt) { stmt->setDone(); } /**************************************************************************************/ /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. -types::FuncTypePtr -TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &member, - const std::vector &args) { - std::vector callArgs; +types::FuncType * +TypecheckVisitor::findBestMethod(ClassType *typ, const std::string &member, + const std::vector &args) { + std::vector callArgs; for (auto &a : args) { - callArgs.push_back({"", std::make_shared()}); // dummy expression - callArgs.back().value->setType(a); + callArgs.emplace_back("", N()); // dummy expression + callArgs.back().value->setType(a->shared_from_this()); } - auto methods = ctx->findMethod(typ.get(), member, false); + auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. -types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, - const std::string &member, - const std::vector &args) { - std::vector callArgs; +types::FuncType *TypecheckVisitor::findBestMethod(ClassType *typ, + const std::string &member, + const std::vector &args) { + std::vector callArgs; for (auto &a : args) - callArgs.push_back({"", a}); - auto methods = ctx->findMethod(typ.get(), member, false); + callArgs.emplace_back("", a); + auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } /// Select the best method indicated of an object that matches the given argument /// types. See @c findMatchingMethods for details. -types::FuncTypePtr TypecheckVisitor::findBestMethod( - const ClassTypePtr &typ, const std::string &member, - const std::vector> &args) { - std::vector callArgs; +types::FuncType *TypecheckVisitor::findBestMethod( + ClassType *typ, const std::string &member, + const std::vector> &args) { + std::vector callArgs; for (auto &[n, a] : args) { - callArgs.push_back({n, std::make_shared()}); // dummy expression - callArgs.back().value->setType(a); + callArgs.emplace_back(n, N()); // dummy expression + callArgs.back().value->setType(a->shared_from_this()); } - auto methods = ctx->findMethod(typ.get(), member, false); + auto methods = findMethod(typ, member, false); auto m = findMatchingMethods(typ, methods, callArgs); return m.empty() ? nullptr : m[0]; } -// Search expression tree for a identifier -class IdSearchVisitor : public CallbackASTVisitor { - std::string what; - bool result; - -public: - IdSearchVisitor(std::string what) : what(std::move(what)), result(false) {} - bool transform(const std::shared_ptr &expr) override { - if (result) - return result; - IdSearchVisitor v(what); - if (expr) - expr->accept(v); - return v.result; - } - bool transform(const std::shared_ptr &stmt) override { - if (result) - return result; - IdSearchVisitor v(what); - if (stmt) - stmt->accept(v); - return v.result; - } - void visit(IdExpr *expr) override { - if (expr->value == what) - result = true; - } -}; - /// Check if a function can be called with the given arguments. /// See @c reorderNamedArgs for details. -int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, - const std::vector &args, - std::shared_ptr part) { - auto getPartialArg = [&](size_t pi) -> types::TypePtr { - if (pi < part->args.size()) - return part->args[pi]; - else - return nullptr; - }; +int TypecheckVisitor::canCall(types::FuncType *fn, const std::vector &args, + types::ClassType *part) { + std::vector partialArgs; + if (part && part->getPartial()) { + auto known = part->getPartialMask(); + auto knownArgTypes = extractClassGeneric(part, 1)->getClass(); + for (size_t i = 0, j = 0, k = 0; i < known.size(); i++) + if (known[i]) { + partialArgs.push_back(extractClassGeneric(knownArgTypes, k)); + k++; + } + } - std::vector> reordered; + std::vector> reordered; auto niGenerics = fn->ast->getNonInferrableGenerics(); - auto score = ctx->reorderNamedArgs( - fn.get(), args, + auto score = reorderNamedArgs( + fn, args, [&](int s, int k, const std::vector> &slots, bool _) { for (int si = 0, gi = 0, pi = 0; si < slots.size(); si++) { - if (fn->ast->args[si].status == Param::Generic) { + if ((*fn->ast)[si].isGeneric()) { if (slots[si].empty()) { // is this "real" type? - if (in(niGenerics, fn->ast->args[si].name) && - !fn->ast->args[si].defaultValue) { + if (in(niGenerics, (*fn->ast)[si].getName()) && + !(*fn->ast)[si].getDefault()) return -1; - } reordered.emplace_back(nullptr, 0); } else { seqassert(gi < fn->funcGenerics.size(), "bad fn"); - if (!fn->funcGenerics[gi].type->isStaticType() && - !args[slots[si][0]].value->isType()) + if (!extractFuncGeneric(fn, gi)->isStaticType() && + !isTypeExpr(args[slots[si][0]])) return -1; - reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]); + reordered.emplace_back(args[slots[si][0]].getExpr()->getType(), + slots[si][0]); } gi++; } else if (si == s || si == k || slots[si].size() != 1) { // Partials - if (slots[si].empty() && part && part->known[si]) { - reordered.emplace_back(getPartialArg(pi++), 0); + if (slots[si].empty() && part && part->getPartial() && + part->getPartialMask()[si]) { + reordered.emplace_back(partialArgs[pi++], 0); } else { // Ignore *args, *kwargs and default arguments reordered.emplace_back(nullptr, 0); } } else { - reordered.emplace_back(args[slots[si][0]].value->type, slots[si][0]); + reordered.emplace_back(args[slots[si][0]].getExpr()->getType(), + slots[si][0]); } } return 0; }, [](error::Error, const SrcInfo &, const std::string &) { return -1; }, - part ? part->known : std::vector{}); + part && part->getPartial() ? part->getPartialMask() : std::vector{}); int ai = 0, mai = 0, gi = 0, real_gi = 0; for (; score != -1 && ai < reordered.size(); ai++) { - auto expectTyp = fn->ast->args[ai].status == Param::Normal - ? fn->getArgTypes()[mai++] - : fn->funcGenerics[gi++].type; + auto expectTyp = (*fn->ast)[ai].isValue() ? extractFuncArgType(fn, mai++) + : extractFuncGeneric(fn, gi++); auto [argType, argTypeIdx] = reordered[ai]; if (!argType) continue; - real_gi += fn->ast->args[ai].status != Param::Normal; - if (fn->ast->args[ai].status != Param::Normal) { + real_gi += !(*fn->ast)[ai].isValue(); + if (!(*fn->ast)[ai].isValue()) { // Check if this is a good generic! if (expectTyp && expectTyp->isStaticType()) { - if (!args[argTypeIdx].value->isStatic()) { + if (!args[argTypeIdx].getExpr()->getType()->isStaticType()) { score = -1; break; } else { - argType = Type::makeStatic(ctx->cache, args[argTypeIdx].value); + argType = args[argTypeIdx].getExpr()->getType(); } } else { /// TODO: check if these are real types or if traits are satisfied continue; } } - try { - ExprPtr dummy = std::make_shared(""); - dummy->type = argType; - dummy->setDone(); - wrapExpr(dummy, expectTyp, fn); - types::Type::Unification undo; - if (dummy->type->unify(expectTyp.get(), &undo) >= 0) { - undo.undo(); - } else { - score = -1; - } - } catch (const exc::ParserException &) { - // Ignore failed wraps + + auto [_, newArgTyp, __] = canWrapExpr(argType, expectTyp, fn); + if (!newArgTyp) + newArgTyp = argType->shared_from_this(); + if (newArgTyp->unify(expectTyp, nullptr) < 0) score = -1; - } } if (score >= 0) score += (real_gi == fn->funcGenerics.size()); @@ -345,17 +522,16 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn, /// Select the best method among the provided methods given the list of arguments. /// See @c reorderNamedArgs for details. -std::vector -TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ, - const std::vector &methods, - const std::vector &args) { +std::vector TypecheckVisitor::findMatchingMethods( + types::ClassType *typ, const std::vector &methods, + const std::vector &args, types::ClassType *part) { // Pick the last method that accepts the given arguments. - std::vector results; + std::vector results; for (const auto &mi : methods) { if (!mi) continue; // avoid overloads that have not been seen yet - auto method = ctx->instantiate(mi, typ)->getFunc(); - int score = canCall(method, args); + auto method = instantiateType(mi, typ); + int score = canCall(method->getFunc(), args, part); if (score != -1) { results.push_back(mi); } @@ -372,156 +548,369 @@ TypecheckVisitor::findMatchingMethods(const types::ClassTypePtr &typ, /// expected `T`, got `Optional[T]` -> `unwrap(expr)` /// expected `Function`, got a function -> partialize function /// expected `T`, got `Union[T...]` -> `__internal__.get_union(expr, T)` -/// expected `Union[T...]`, got `T` -> `__internal__.new_union(expr, Union[T...])` -/// expected base class, got derived -> downcast to base class +/// expected `Union[T...]`, got `T` -> `__internal__.new_union(expr, +/// Union[T...])` expected base class, got derived -> downcast to base class /// @param allowUnwrap allow optional unwrapping. -bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType, - const FuncTypePtr &callee, bool allowUnwrap) { +bool TypecheckVisitor::wrapExpr(Expr **expr, Type *expectedType, FuncType *callee, + bool allowUnwrap) { + auto [canWrap, newArgTyp, fn] = canWrapExpr((*expr)->getType(), expectedType, callee, + allowUnwrap, cast(*expr)); + // TODO: get rid of this line one day! + if ((*expr)->getType()->isStaticType() && + (!expectedType || !expectedType->isStaticType())) + (*expr)->setType(getUnderlyingStaticType((*expr)->getType())->shared_from_this()); + if (canWrap && fn) { + *expr = transform(fn(*expr)); + } + return canWrap; +} + +std::tuple> +TypecheckVisitor::canWrapExpr(Type *exprType, Type *expectedType, FuncType *callee, + bool allowUnwrap, bool isEllipsis) { auto expectedClass = expectedType->getClass(); - auto exprClass = expr->getType()->getClass(); - auto doArgWrap = - !callee || !callee->ast->hasAttr("std.internal.attributes.no_argument_wrap"); + auto exprClass = exprType->getClass(); + auto doArgWrap = !callee || !callee->ast->hasAttribute( + "std.internal.attributes.no_argument_wrap.0:0"); if (!doArgWrap) - return true; - auto doTypeWrap = - !callee || !callee->ast->hasAttr("std.internal.attributes.no_type_wrap"); - if (callee && expr->isType()) { - auto c = expr->type->getClass(); + return {true, expectedType ? expectedType->shared_from_this() : nullptr, nullptr}; + + TypePtr type = nullptr; + std::function fn = nullptr; + + if (callee && exprType->is(TYPE_TYPE)) { + auto c = extractClassType(exprType); if (!c) - return false; - if (doTypeWrap) { - if (c->getRecord()) - expr = transform(N(expr, N(EllipsisExpr::PARTIAL))); - else - expr = transform(N( - N("__internal__.class_ctr:0"), - std::vector{{"T", expr}, - {"", N(EllipsisExpr::PARTIAL)}})); + return {false, nullptr, nullptr}; + if (!(expectedType && (expectedType->is(TYPE_TYPE)))) { + type = instantiateType(getStdLibType("TypeWrap"), std::vector{c}); + fn = [&](Expr *expr) -> Expr * { + return N(N("TypeWrap"), expr); + }; } + return {true, type, fn}; } std::unordered_set hints = {"Generator", "float", TYPE_OPTIONAL, "pyobj"}; + if (!expectedType || !expectedType->isStaticType()) { + if (auto c = exprType->isStaticType()) { + exprType = getUnderlyingStaticType(exprType); + exprClass = exprType->getClass(); + type = exprType->shared_from_this(); + } + } + if (!exprClass && expectedClass && in(hints, expectedClass->name)) { - return false; // argument type not yet known. - } else if (expectedClass && expectedClass->name == "Generator" && - exprClass->name != expectedClass->name && !expr->getEllipsis()) { + return {false, nullptr, nullptr}; // argument type not yet known. + } + + else if ((!expectedClass || !expectedClass->is("Ref")) && exprClass && + exprClass->is("Ref")) { + type = extractClassGeneric(exprClass)->shared_from_this(); + fn = [&](Expr *expr) -> Expr * { + return N(N("__internal__.ref_get"), expr); + }; + } + + else if (expectedClass && expectedClass->is("Generator") && + !exprClass->is(expectedClass->name) && !isEllipsis) { + if (findMethod(exprClass, "__iter__").empty()) + return {false, nullptr, nullptr}; // Note: do not do this in pipelines (TODO: why?) - expr = transform(N(N(expr, "__iter__"))); - } else if (expectedClass && expectedClass->name == "float" && - exprClass->name == "int") { - expr = transform(N(N("float"), expr)); - } else if (expectedClass && expectedClass->name == TYPE_OPTIONAL && - exprClass->name != expectedClass->name) { - expr = transform(N(N(TYPE_OPTIONAL), expr)); + type = instantiateType(expectedClass); + fn = [&](Expr *expr) -> Expr * { + return N(N(expr, "__iter__")); + }; + } + + else if (expectedClass && expectedClass->is("float") && exprClass->is("int")) { + type = instantiateType(expectedClass); + fn = [&](Expr *expr) -> Expr * { return N(N("float"), expr); }; + } + + else if (expectedClass && expectedClass->is(TYPE_OPTIONAL) && + !exprClass->is(expectedClass->name)) { + type = + instantiateType(getStdLibType(TYPE_OPTIONAL), std::vector{exprClass}); + fn = [&](Expr *expr) -> Expr * { + return N(N(TYPE_OPTIONAL), expr); + }; } else if (allowUnwrap && expectedClass && exprClass && - exprClass->name == TYPE_OPTIONAL && - exprClass->name != expectedClass->name) { // unwrap optional - expr = transform(N(N(FN_UNWRAP), expr)); - } else if (expectedClass && expectedClass->name == "pyobj" && - exprClass->name != expectedClass->name) { // wrap to pyobj - expr = transform( - N(N("pyobj"), N(N(expr, "__to_py__")))); - } else if (allowUnwrap && expectedClass && exprClass && exprClass->name == "pyobj" && - exprClass->name != expectedClass->name) { // unwrap pyobj + exprClass->is(TYPE_OPTIONAL) && + !exprClass->is(expectedClass->name)) { // unwrap optional + type = instantiateType(extractClassGeneric(exprClass)); + fn = [&](Expr *expr) -> Expr * { return N(N(FN_UNWRAP), expr); }; + } + + else if (expectedClass && expectedClass->is("pyobj") && + !exprClass->is(expectedClass->name)) { // wrap to pyobj + if (findMethod(exprClass, "__to_py__").empty()) + return {false, nullptr, nullptr}; + type = instantiateType(expectedClass); + fn = [&](Expr *expr) -> Expr * { + return N(N("pyobj"), + N(N(expr, "__to_py__"))); + }; + } + + else if (allowUnwrap && expectedClass && exprClass && exprClass->is("pyobj") && + !exprClass->is(expectedClass->name)) { // unwrap pyobj + if (findMethod(expectedClass, "__from_py__").empty()) + return {false, nullptr, nullptr}; + type = instantiateType(expectedClass); auto texpr = N(expectedClass->name); - texpr->setType(expectedType); - expr = - transform(N(N(texpr, "__from_py__"), N(expr, "p"))); - } else if (callee && exprClass && expr->type->getFunc() && - !(expectedClass && expectedClass->name == "Function")) { + texpr->setType(expectedType->shared_from_this()); + fn = [this, texpr](Expr *expr) -> Expr * { + return N(N(texpr, "__from_py__"), N(expr, "p")); + }; + } + + else if (expectedClass && expectedClass->is("ProxyFunc") && exprClass && + (exprClass->getPartial() || exprClass->getFunc() || + exprClass->is("Function"))) { + // Get list of arguments + std::vector argTypes; + Type *retType; + std::shared_ptr fnType = nullptr; + if (!exprClass->getPartial()) { + auto targs = extractClassGeneric(exprClass)->getClass(); + for (size_t i = 0; i < targs->size(); i++) + argTypes.push_back((*targs)[i]); + retType = extractClassGeneric(exprClass, 1); + } else { + fnType = instantiateType(exprClass->getPartial()->getPartialFunc()); + std::vector argumentTypes; + auto known = exprClass->getPartial()->getPartialMask(); + for (size_t i = 0; i < known.size(); i++) { + if (!known[i]) + argTypes.push_back((*fnType)[i]); + } + retType = fnType->getRetType(); + } + auto expectedArgs = extractClassGeneric(expectedClass)->getClass(); + if (argTypes.size() != expectedArgs->size()) + return {false, nullptr, nullptr}; + for (size_t i = 0; i < argTypes.size(); i++) { + if (argTypes[i]->unify((*expectedArgs)[i], nullptr) < 0) + return {false, nullptr, nullptr}; + } + if (retType->unify(extractClassGeneric(expectedClass, 1), nullptr) < 0) + return {false, nullptr, nullptr}; + + type = expectedType->shared_from_this(); + fn = [this, type](Expr *expr) -> Expr * { + auto exprClass = expr->getType()->getClass(); + auto expectedClass = type->getClass(); + + std::vector argTypes; + Type *retType; + std::shared_ptr fnType = nullptr; + if (!exprClass->getPartial()) { + auto targs = extractClassGeneric(exprClass)->getClass(); + for (size_t i = 0; i < targs->size(); i++) + argTypes.push_back((*targs)[i]); + retType = extractClassGeneric(exprClass, 1); + } else { + fnType = instantiateType(exprClass->getPartial()->getPartialFunc()); + std::vector argumentTypes; + auto known = exprClass->getPartial()->getPartialMask(); + for (size_t i = 0; i < known.size(); i++) { + if (!known[i]) + argTypes.push_back((*fnType)[i]); + } + retType = fnType->getRetType(); + } + auto expectedArgs = extractClassGeneric(expectedClass)->getClass(); + for (size_t i = 0; i < argTypes.size(); i++) + unify(argTypes[i], (*expectedArgs)[i]); + auto tr = unify(retType, extractClassGeneric(expectedClass, 1)); + + std::string fname; + Expr *retFn = nullptr, *dataArg = nullptr, *dataType = nullptr; + if (exprClass->getPartial()) { + auto rf = realize(exprClass); + fname = rf->realizedName(); + seqassert(rf, "not realizable"); + retFn = N( + N(N(N("Ptr"), N(rf->realizedName())), + N("data")), + N(0)); + dataType = N("cobj"); + } else if (exprClass->getFunc()) { + auto rf = realize(exprClass); + seqassert(rf, "not realizable"); + fname = rf->realizedName(); + retFn = N(rf->getFunc()->realizedName()); + dataArg = N(N("cobj")); + dataType = N("cobj"); + } else { + seqassert(exprClass->is("Function"), "bad type: {}", exprClass->toString()); + auto rf = realize(exprClass); + seqassert(rf, "not realizable"); + fname = rf->realizedName(); + retFn = N(N(rf->realizedName()), N("data")); + dataType = N("cobj"); + } + fname = fmt::format(".proxy.{}", fname); + + if (!ctx->find(fname)) { + // Create wrapper if needed + auto f = N( + fname, nullptr, + std::vector{ + Param{"data", dataType}, + Param{"args", N(expectedArgs->realizedName())}}, // Tuple[...] + N( + N(N(retFn, N(N("args")))))); + f = cast(transform(f)); + } + auto e = N(N("ProxyFunc"), N(fname), + dataArg ? dataArg : expr); + return e; + }; + } else if (callee && exprClass && exprType->getFunc() && + !(expectedClass && expectedClass->is("Function"))) { // Wrap raw Seq functions into Partial(...) call for easy realization. - expr = partializeFunction(expr->type->getFunc()); - } else if (allowUnwrap && exprClass && expr->type->getUnion() && expectedClass && - !expectedClass->getUnion()) { + // Special case: Seq functions are embedded (via lambda!) + // seqassert(cast(expr) || (cast(expr) && + // cast(cast(expr)->getExpr())), + // "bad partial function: {}", *expr); + auto p = partializeFunction(exprType->getFunc()); + if (expectedClass) + type = instantiateType(expectedClass); + fn = [&](Expr *expr) -> Expr * { + if (auto se = cast(expr)) + return N(se->items, p); + return p; + }; + } else if (expectedClass && expectedClass->is("Function") && exprClass && + exprClass->getPartial() && exprClass->getPartial()->isPartialEmpty()) { + type = instantiateType(expectedClass); + auto fnName = exprClass->getPartial()->getPartialFunc()->ast->name; + auto t = instantiateType(ctx->forceFind(fnName)->getType()); + if (type->unify(t.get(), nullptr) >= 0) + fn = [&](Expr *expr) -> Expr * { return N(fnName); }; + else + type = nullptr; + } + + else if (allowUnwrap && exprClass && exprType->getUnion() && expectedClass && + !expectedClass->getUnion()) { // Extract union types via __internal__.get_union if (auto t = realize(expectedClass)) { - auto e = realize(expr->type); + auto e = realize(exprType); if (!e) - return false; + return {false, nullptr, nullptr}; bool ok = false; for (auto &ut : e->getUnion()->getRealizationTypes()) { - if (ut->unify(t.get(), nullptr) >= 0) { + if (ut->unify(t, nullptr) >= 0) { ok = true; break; } } if (ok) { - expr = transform(N(N("__internal__.get_union:0"), expr, - N(t->realizedName()))); + type = t->shared_from_this(); + fn = [this, type](Expr *expr) -> Expr * { + return N(N("__internal__.get_union:0"), expr, + N(type->realizedName())); + }; } } else { - return false; + return {false, nullptr, nullptr}; } - } else if (exprClass && expectedClass && expectedClass->getUnion()) { + } + + else if (exprClass && expectedClass && expectedClass->getUnion()) { // Make union types via __internal__.new_union if (!expectedClass->getUnion()->isSealed()) { - expectedClass->getUnion()->addType(exprClass); + if (!expectedClass->getUnion()->addType(exprClass)) + E(error::Error::UNION_TOO_BIG, expectedClass->getSrcInfo(), + expectedClass->getUnion()->pendingTypes.size()); } if (auto t = realize(expectedClass)) { - if (expectedClass->unify(exprClass.get(), nullptr) == -1) - expr = transform(N(N("__internal__.new_union:0"), expr, - NT(t->realizedName()))); + if (expectedClass->unify(exprClass, nullptr) == -1) { + type = t->shared_from_this(); + fn = [this, type](Expr *expr) -> Expr * { + return N(N(N("__internal__"), "new_union"), expr, + N(type->realizedName())); + }; + } } else { - return false; + return {false, nullptr, nullptr}; } - } else if (exprClass && expectedClass && exprClass->name != expectedClass->name) { + } + + else if (exprClass && exprClass->is(TYPE_TYPE) && expectedClass && + expectedClass->is("TypeWrap")) { + type = instantiateType(getStdLibType("TypeWrap"), + std::vector{exprClass}); + fn = [this](Expr *expr) -> Expr * { + return N(N("TypeWrap"), expr); + }; + } + + else if (exprClass && expectedClass && !exprClass->is(expectedClass->name)) { // Cast derived classes to base classes - auto &mros = ctx->cache->classes[exprClass->name].mro; + const auto &mros = ctx->cache->getClass(exprClass)->mro; for (size_t i = 1; i < mros.size(); i++) { - auto t = ctx->instantiate(mros[i]->type, exprClass); - if (t->unify(expectedClass.get(), nullptr) >= 0) { - if (!expr->isId("")) { - expr = castToSuperClass(expr, expectedClass, true); - } else { // Just checking can this be done - expr->type = expectedClass; - } + auto t = instantiateType(mros[i].get(), exprClass); + if (t->unify(expectedClass, nullptr) >= 0) { + type = expectedClass->shared_from_this(); + fn = [this, type](Expr *expr) -> Expr * { + return castToSuperClass(expr, type->getClass(), true); + }; break; } } } - return true; + + return {true, type, fn}; } /// Cast derived class to a base class. -ExprPtr TypecheckVisitor::castToSuperClass(ExprPtr expr, ClassTypePtr superTyp, - bool isVirtual) { - ClassTypePtr typ = expr->type->getClass(); - for (auto &field : getClassFields(typ.get())) { - for (auto &parentField : getClassFields(superTyp.get())) +Expr *TypecheckVisitor::castToSuperClass(Expr *expr, ClassType *superTyp, + bool isVirtual) { + ClassType *typ = expr->getClassType(); + for (auto &field : getClassFields(typ)) { + for (auto &parentField : getClassFields(superTyp)) if (field.name == parentField.name) { - unify(ctx->instantiate(field.type, typ), - ctx->instantiate(parentField.type, superTyp)); + auto t = instantiateType(field.getType(), typ); + unify(t.get(), instantiateType(parentField.getType(), superTyp)); } } realize(superTyp); - auto typExpr = N(superTyp->name); - typExpr->setType(superTyp); + auto typExpr = N(superTyp->realizedName()); return transform( N(N(N("__internal__"), "class_super"), expr, typExpr)); } /// Unpack a Tuple or KwTuple expression into (name, type) vector. /// Name is empty when handling Tuple; otherwise it matches names of KwTuple. -std::shared_ptr>> -TypecheckVisitor::unpackTupleTypes(ExprPtr expr) { - auto ret = std::make_shared>>(); - if (auto tup = expr->origExpr->getTuple()) { - for (auto &a : tup->items) { - transform(a); - if (!a->getType()->getClass()) +std::shared_ptr>> +TypecheckVisitor::unpackTupleTypes(Expr *expr) { + auto ret = std::make_shared>>(); + if (auto tup = cast(expr->getOrigExpr())) { + for (auto &a : *tup) { + a = transform(a); + if (!a->getClassType()) return nullptr; - ret->push_back({"", a->getType()}); - } - } else if (auto kw = expr->origExpr->getCall()) { // origExpr? - auto kwCls = in(ctx->cache->classes, expr->getType()->getClass()->name); - seqassert(kwCls, "cannot find {}", expr->getType()->getClass()->name); - for (size_t i = 0; i < kw->args.size(); i++) { - auto &a = kw->args[i].value; - transform(a); - if (!a->getType()->getClass()) + ret->emplace_back("", a->getType()); + } + } else if (auto kw = cast(expr->getOrigExpr())) { + auto val = extractClassType(expr->getType()); + if (!val || !val->is("NamedTuple") || !extractClassGeneric(val, 1)->getClass() || + !extractClassGeneric(val)->canRealize()) + return nullptr; + auto id = getIntLiteral(val); + seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id); + auto names = ctx->cache->generatedTupleNames[id]; + auto types = extractClassGeneric(val, 1)->getClass(); + seqassert(startswith(types->name, "Tuple"), "bad NamedTuple argument"); + for (size_t i = 0; i < types->generics.size(); i++) { + if (!extractClassGeneric(types, i)) return nullptr; - ret->push_back({kwCls->fields[i].name, a->getType()}); + ret->emplace_back(names[i], extractClassGeneric(types, i)); } } else { return nullptr; @@ -529,16 +918,949 @@ TypecheckVisitor::unpackTupleTypes(ExprPtr expr) { return ret; } -std::vector & -TypecheckVisitor::getClassFields(types::ClassType *t) { - seqassert(t && in(ctx->cache->classes, t->name), "cannot find '{}'", - t ? t->name : ""); - if (t->is(TYPE_TUPLE) && !t->getRecord()->args.empty()) { - auto key = ctx->generateTuple(t->getRecord()->args.size()); - return ctx->cache->classes[key].fields; +std::vector> +TypecheckVisitor::extractNamedTuple(Expr *expr) { + std::vector> ret; + + seqassert(expr->getType()->is("NamedTuple") && + extractClassGeneric(expr->getClassType())->canRealize(), + "bad named tuple: {}", *expr); + auto id = getIntLiteral(expr->getClassType()); + seqassert(id >= 0 && id < ctx->cache->generatedTupleNames.size(), "bad id: {}", id); + auto names = ctx->cache->generatedTupleNames[id]; + for (size_t i = 0; i < names.size(); i++) { + ret.emplace_back(names[i], N(N(expr, "args"), N(i))); + } + return ret; +} + +std::vector +TypecheckVisitor::getClassFields(types::ClassType *t) const { + auto f = getClass(t->name)->fields; + if (t->is(TYPE_TUPLE)) + f = std::vector(f.begin(), + f.begin() + t->generics.size()); + return f; +} + +std::vector +TypecheckVisitor::getClassFieldTypes(types::ClassType *cls) { + return withClassGenerics(cls, [&]() { + std::vector result; + for (auto &field : getClassFields(cls)) { + auto ftyp = instantiateType(field.getType(), cls); + if (!ftyp->canRealize() && field.typeExpr) { + auto t = extractType(transform(clean_clone(field.typeExpr))); + unify(ftyp.get(), t); + } + result.push_back(ftyp); + } + return result; + }); +} + +types::Type *TypecheckVisitor::extractType(types::Type *t) { + while (t && t->is(TYPE_TYPE)) + t = extractClassGeneric(t); + return t; +} + +types::Type *TypecheckVisitor::extractType(Expr *e) { + if (cast(e) && cast(e)->getValue() == TYPE_TYPE) + return e->getType(); + if (auto i = cast(e)) + if (cast(i->getExpr()) && + cast(i->getExpr())->getValue() == TYPE_TYPE) + return e->getType(); + return extractType(e->getType()); +} + +types::Type *TypecheckVisitor::extractType(const std::string &s) { + auto c = ctx->forceFind(s); + return s == TYPE_TYPE ? c->getType() : extractType(c->getType()); +} + +types::ClassType *TypecheckVisitor::extractClassType(Expr *e) { + auto t = extractType(e); + return t->getClass(); +} + +types::ClassType *TypecheckVisitor::extractClassType(types::Type *t) { + return extractType(t)->getClass(); +} + +types::ClassType *TypecheckVisitor::extractClassType(const std::string &s) { + return extractType(s)->getClass(); +} + +bool TypecheckVisitor::isUnbound(types::Type *t) const { + return t->getUnbound() != nullptr; +} + +bool TypecheckVisitor::isUnbound(Expr *e) const { return isUnbound(e->getType()); } + +bool TypecheckVisitor::hasOverloads(const std::string &root) { + auto i = in(ctx->cache->overloads, root); + return i && i->size() > 1; +} + +std::vector TypecheckVisitor::getOverloads(const std::string &root) { + auto i = in(ctx->cache->overloads, root); + seqassert(i, "bad root"); + return *i; +} + +std::string TypecheckVisitor::getUnmangledName(const std::string &s) const { + return ctx->cache->rev(s); +} + +Cache::Class *TypecheckVisitor::getClass(const std::string &t) const { + auto i = in(ctx->cache->classes, t); + return i; +} + +Cache::Class *TypecheckVisitor::getClass(types::Type *t) const { + if (t) { + if (auto c = t->getClass()) + return getClass(c->name); + } + seqassert(false, "bad class"); + return nullptr; +} + +Cache::Function *TypecheckVisitor::getFunction(const std::string &n) const { + auto i = in(ctx->cache->functions, n); + return i; +} + +Cache::Function *TypecheckVisitor::getFunction(types::Type *t) const { + seqassert(t->getFunc(), "bad function"); + return getFunction(t->getFunc()->getFuncName()); +} + +Cache::Class::ClassRealization * +TypecheckVisitor::getClassRealization(types::Type *t) const { + seqassert(t->canRealize(), "bad class"); + auto i = in(getClass(t)->realizations, t->getClass()->realizedName()); + seqassert(i, "bad class realization"); + return i->get(); +} + +std::string TypecheckVisitor::getRootName(types::FuncType *t) { + auto i = in(ctx->cache->functions, t->getFuncName()); + seqassert(i && !i->rootName.empty(), "bad function"); + return i->rootName; +} + +bool TypecheckVisitor::isTypeExpr(Expr *e) { + return e && e->getType() && e->getType()->is(TYPE_TYPE); +} + +Cache::Module *TypecheckVisitor::getImport(const std::string &s) { + auto i = in(ctx->cache->imports, s); + seqassert(i, "bad import"); + return i; +} + +std::string TypecheckVisitor::getArgv() const { return ctx->cache->argv0; } + +std::string TypecheckVisitor::getRootModulePath() const { return ctx->cache->module0; } + +std::vector TypecheckVisitor::getPluginImportPaths() const { + return ctx->cache->pluginImportPaths; +} + +bool TypecheckVisitor::isDispatch(const std::string &s) { + return endswith(s, FN_DISPATCH_SUFFIX); +} + +bool TypecheckVisitor::isDispatch(FunctionStmt *ast) { + return ast && isDispatch(ast->name); +} + +bool TypecheckVisitor::isDispatch(types::Type *f) { + return f->getFunc() && isDispatch(f->getFunc()->ast); +} + +void TypecheckVisitor::addClassGenerics(types::ClassType *typ, bool func, + bool onlyMangled, bool instantiate) { + auto addGen = [&](const types::ClassType::Generic &g) { + auto t = g.type; + if (instantiate) { + if (auto l = t->getLink()) + if (l->kind == types::LinkType::Generic) { + auto lx = std::make_shared(*l); + lx->kind = types::LinkType::Unbound; + t = lx; + } + } + seqassert(!g.isStatic || t->isStaticType(), "{} not a static: {}", g.name, + *(g.type)); + if (!g.isStatic && !t->is(TYPE_TYPE)) + t = instantiateTypeVar(t.get()); + auto v = ctx->addType(onlyMangled ? g.name : getUnmangledName(g.name), g.name, t); + v->generic = true; + // LOG("+ {} {} {} {}", getUnmangledName(g.name), g.name, t->debugString(2), + // v->getBaseName()); + }; + + if (func && typ->getFunc()) { + auto tf = typ->getFunc(); + // LOG("// adding {}", tf->debugString(2)); + for (auto parent = tf->funcParent; parent;) { + if (auto f = parent->getFunc()) { + // Add parent function generics + for (auto &g : f->funcGenerics) + addGen(g); + parent = f->funcParent; + } else { + // Add parent class generics + seqassert(parent->getClass(), "not a class: {}", *parent); + for (auto &g : parent->getClass()->generics) + addGen(g); + for (auto &g : parent->getClass()->hiddenGenerics) + addGen(g); + break; + } + } + for (auto &g : tf->funcGenerics) + addGen(g); } else { - return ctx->cache->classes[t->name].fields; + for (auto &g : typ->hiddenGenerics) + addGen(g); + for (auto &g : typ->generics) + addGen(g); + } +} + +types::TypePtr TypecheckVisitor::instantiateTypeVar(types::Type *t) { + return instantiateType(ctx->forceFind(TYPE_TYPE)->getType(), {t}); +} + +void TypecheckVisitor::registerGlobal(const std::string &name, bool initialized) { + if (!in(ctx->cache->globals, name)) { + ctx->cache->globals[name] = {initialized, nullptr}; + } +} + +types::ClassType *TypecheckVisitor::getStdLibType(const std::string &type) { + auto t = getImport(STDLIB_IMPORT)->ctx->forceFind(type)->getType(); + if (type == TYPE_TYPE) + return t->getClass(); + return extractClassType(t); +} + +types::Type *TypecheckVisitor::extractClassGeneric(types::Type *t, int idx) const { + seqassert(t->getClass() && idx < t->getClass()->generics.size(), "bad class"); + return t->getClass()->generics[idx].type.get(); +} + +types::Type *TypecheckVisitor::extractFuncGeneric(types::Type *t, int idx) const { + seqassert(t->getFunc() && idx < t->getFunc()->funcGenerics.size(), "bad function"); + return t->getFunc()->funcGenerics[idx].type.get(); +} + +types::Type *TypecheckVisitor::extractFuncArgType(types::Type *t, int idx) { + seqassert(t->getFunc(), "bad function"); + return extractClassGeneric(extractClassGeneric(t), idx); +} + +std::string TypecheckVisitor::getClassMethod(types::Type *typ, + const std::string &member) { + if (auto cls = getClass(typ)) { + if (auto t = in(cls->methods, member)) + return *t; + } + seqassertn(false, "cannot find '{}' in '{}'", member, *typ); + return ""; +} + +std::string TypecheckVisitor::getTemporaryVar(const std::string &s) { + return ctx->cache->getTemporaryVar(s); +} + +std::string TypecheckVisitor::getStrLiteral(types::Type *t, size_t pos) { + seqassert(t && t->getClass(), "not a class"); + if (t->getStrStatic()) + return t->getStrStatic()->value; + auto ct = extractClassGeneric(t, pos); + seqassert(ct->canRealize() && ct->getStrStatic(), "not a string literal"); + return ct->getStrStatic()->value; +} + +int64_t TypecheckVisitor::getIntLiteral(types::Type *t, size_t pos) { + seqassert(t && t->getClass(), "not a class"); + if (t->getIntStatic()) + return t->getIntStatic()->value; + auto ct = extractClassGeneric(t, pos); + seqassert(ct->canRealize() && ct->getIntStatic(), "not a string literal"); + return ct->getIntStatic()->value; +} + +bool TypecheckVisitor::getBoolLiteral(types::Type *t, size_t pos) { + seqassert(t && t->getClass(), "not a class"); + if (t->getBoolStatic()) + return t->getBoolStatic()->value; + auto ct = extractClassGeneric(t, pos); + seqassert(ct->canRealize() && ct->getBoolStatic(), "not a string literal"); + return ct->getBoolStatic()->value; +} + +bool TypecheckVisitor::isImportFn(const std::string &s) { + return startswith(s, "%_import_"); +} + +int64_t TypecheckVisitor::getTime() { return ctx->time; } + +types::Type *TypecheckVisitor::getUnderlyingStaticType(types::Type *t) { + if (t->getStatic()) { + return t->getStatic()->getNonStaticType(); + } else if (auto c = t->isStaticType()) { + if (c == 1) + return getStdLibType("int"); + if (c == 2) + return getStdLibType("str"); + if (c == 3) + return getStdLibType("bool"); + } + return t; +} + +std::shared_ptr +TypecheckVisitor::instantiateUnbound(const SrcInfo &srcInfo, int level) const { + auto typ = std::make_shared( + ctx->cache, types::LinkType::Unbound, ctx->cache->unboundCount++, level, nullptr); + typ->setSrcInfo(srcInfo); + return typ; +} + +std::shared_ptr +TypecheckVisitor::instantiateUnbound(const SrcInfo &srcInfo) const { + return instantiateUnbound(srcInfo, ctx->typecheckLevel); +} + +std::shared_ptr TypecheckVisitor::instantiateUnbound() const { + return instantiateUnbound(getSrcInfo(), ctx->typecheckLevel); +} + +types::TypePtr TypecheckVisitor::instantiateType(const SrcInfo &srcInfo, + types::Type *type, + types::ClassType *generics) { + seqassert(type, "type is null"); + std::unordered_map genericCache; + if (generics) { + for (auto &g : generics->generics) + if (g.type && + !(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) { + genericCache[g.id] = g.type; + } + // special case: __SELF__ + if (type->getFunc() && !type->getFunc()->funcGenerics.empty() && + type->getFunc()->funcGenerics[0].niceName == "__SELF__") { + genericCache[type->getFunc()->funcGenerics[0].id] = generics->shared_from_this(); + } + } + auto t = type->instantiate(ctx->typecheckLevel, &(ctx->cache->unboundCount), + &genericCache); + for (auto &i : genericCache) { + if (auto l = i.second->getLink()) { + i.second->setSrcInfo(srcInfo); + if (l->defaultType) { + ctx->getBase()->pendingDefaults[0].insert(i.second); + } + } + } + if (auto ft = t->getFunc()) { + if (auto b = ft->ast->getAttribute(Attr::Bindings)) { + auto module = + ft->ast->getAttribute(Attr::Module)->value; + auto imp = getImport(module); + std::unordered_map key; + + // look for generics [todo: speed-up!] + std::unordered_set generics; + for (auto &g : ft->funcGenerics) + generics.insert(getUnmangledName(g.name)); + for (auto parent = ft->funcParent; parent;) { + if (auto f = parent->getFunc()) { + for (auto &g : f->funcGenerics) + generics.insert(getUnmangledName(g.name)); + parent = f->funcParent; + } else { + for (auto &g : parent->getClass()->generics) + generics.insert(getUnmangledName(g.name)); + for (auto &g : parent->getClass()->hiddenGenerics) + generics.insert(getUnmangledName(g.name)); + break; + } + } + for (const auto &[c, _] : b->captures) { + if (!in(generics, c)) { // ignore inherited captures! + if (auto h = imp->ctx->find(c)) { + key[c] = h->canonicalName; + } else { + key.clear(); + break; + } + } + } + if (!key.empty()) { + auto &cm = getFunction(ft->getFuncName())->captureMappings; + size_t idx = 0; + for (; idx < cm.size(); idx++) + if (cm[idx] == key) + break; + if (idx == cm.size()) + cm.push_back(key); + ft->index = idx; + } + } + } + if (t->getUnion() && !t->getUnion()->isSealed()) { + t->setSrcInfo(srcInfo); + ctx->getBase()->pendingDefaults[0].insert(t); + } + return t; +} + +types::TypePtr +TypecheckVisitor::instantiateType(const SrcInfo &srcInfo, types::Type *root, + const std::vector &generics) { + auto c = root->getClass(); + seqassert(c, "root class is null"); + // dummy generic type + auto g = std::make_shared(ctx->cache, "", ""); + if (generics.size() != c->generics.size()) { + E(Error::GENERICS_MISMATCH, srcInfo, getUnmangledName(c->name), c->generics.size(), + generics.size()); + } + for (int i = 0; i < c->generics.size(); i++) { + auto t = generics[i]; + seqassert(c->generics[i].type, "generic is null"); + if (!c->generics[i].isStatic && t->getStatic()) + t = t->getStatic()->getNonStaticType(); + g->generics.emplace_back("", "", t->shared_from_this(), c->generics[i].id, + c->generics[i].isStatic); + } + return instantiateType(srcInfo, root, g.get()); +} + +std::vector TypecheckVisitor::findMethod(types::ClassType *type, + const std::string &method, + bool hideShadowed) { + std::vector vv; + std::unordered_set signatureLoci; + + auto populate = [&](const auto &cls) { + auto t = in(cls.methods, method); + if (!t) + return; + + auto mt = getOverloads(*t); + for (int mti = int(mt.size()) - 1; mti >= 0; mti--) { + auto method = mt[mti]; + auto f = getFunction(method); + if (isDispatch(method) || !f->getType()) + continue; + if (hideShadowed) { + auto sig = f->ast->signature(); + if (!in(signatureLoci, sig)) { + signatureLoci.insert(sig); + vv.emplace_back(f->getType()); + } + } else { + vv.emplace_back(f->getType()); + } + } + }; + if (type->is("Ref")) { + type = extractClassGeneric(type)->getClass(); + } + if (type && type->is(TYPE_TUPLE) && method == "__new__" && !type->generics.empty()) { + generateTuple(type->generics.size()); + auto mc = getClass(TYPE_TUPLE); + populate(*mc); + for (auto f : vv) + if (f->size() == type->generics.size()) + return {f}; + return {}; + } + if (auto cls = getClass(type)) { + for (const auto &pc : cls->mro) { + auto mc = getClass(pc->name == "__NTuple__" ? TYPE_TUPLE : pc->name); + populate(*mc); + } + } + return vv; +} + +Cache::Class::ClassField * +TypecheckVisitor::findMember(types::ClassType *type, const std::string &member) const { + if (type->is("Ref")) { + type = extractClassGeneric(type)->getClass(); + } + if (auto cls = getClass(type)) { + for (const auto &pc : cls->mro) { + auto mc = getClass(pc.get()); + for (auto &mm : mc->fields) { + if (pc->is(TYPE_TUPLE) && (&mm - &(mc->fields[0])) >= type->generics.size()) + break; + if (mm.name == member) + return &mm; + } + } + } + return nullptr; +} + +int TypecheckVisitor::reorderNamedArgs(types::FuncType *func, + const std::vector &args, + const ReorderDoneFn &onDone, + const ReorderErrorFn &onError, + const std::vector &known) { + // See https://docs.python.org/3.6/reference/expressions.html#calls for details. + // Final score: + // - +1 for each matched argument + // - 0 for *args/**kwargs/default arguments + // - -1 for failed match + int score = 0; + + // 0. Find *args and **kwargs + // True if there is a trailing ellipsis (full partial: fn(all_args, ...)) + bool partial = + !args.empty() && cast(args.back().value) && + cast(args.back().value)->getMode() != EllipsisExpr::PIPE && + args.back().name.empty(); + + int starArgIndex = -1, kwstarArgIndex = -1; + for (int i = 0; i < func->ast->size(); i++) { + if (startswith((*func->ast)[i].name, "**")) + kwstarArgIndex = i, score -= 2; + else if (startswith((*func->ast)[i].name, "*")) + starArgIndex = i, score -= 2; + } + + // 1. Assign positional arguments to slots + // Each slot contains a list of arg's indices + std::vector> slots(func->ast->size()); + seqassert(known.empty() || func->ast->size() == known.size(), "bad 'known' string"); + std::vector extra; + std::map namedArgs, + extraNamedArgs; // keep the map--- we need it sorted! + for (int ai = 0, si = 0; ai < args.size() - partial; ai++) { + if (args[ai].name.empty()) { + while (!known.empty() && si < slots.size() && known[si]) + si++; + if (si < slots.size() && (starArgIndex == -1 || si < starArgIndex)) + slots[si++] = {ai}; + else + extra.emplace_back(ai); + } else { + namedArgs[args[ai].name] = ai; + } + } + score += 2 * int(slots.size() - func->funcGenerics.size()); + + for (auto ai : std::vector{std::max(starArgIndex, kwstarArgIndex), + std::min(starArgIndex, kwstarArgIndex)}) + if (ai != -1 && !slots[ai].empty()) { + extra.insert(extra.begin(), ai); + slots[ai].clear(); + } + + // 2. Assign named arguments to slots + if (!namedArgs.empty()) { + std::map slotNames; + for (int i = 0; i < func->ast->size(); i++) + if (known.empty() || !known[i]) { + auto [_, n] = (*func->ast)[i].getNameWithStars(); + slotNames[getUnmangledName(n)] = i; + } + for (auto &n : namedArgs) { + if (!in(slotNames, n.first)) + extraNamedArgs[n.first] = n.second; + else if (slots[slotNames[n.first]].empty()) + slots[slotNames[n.first]].push_back(n.second); + else + return onError(Error::CALL_REPEATED_NAME, args[n.second].value->getSrcInfo(), + Emsg(Error::CALL_REPEATED_NAME, n.first)); + } + } + + // 3. Fill in *args, if present + if (!extra.empty() && starArgIndex == -1) + return onError(Error::CALL_ARGS_MANY, getSrcInfo(), + Emsg(Error::CALL_ARGS_MANY, + // func->ast->getName(), + getUnmangledName(func->ast->getName()), func->ast->size(), + args.size() - partial)); + + if (starArgIndex != -1) + slots[starArgIndex] = extra; + + // 4. Fill in **kwargs, if present + if (!extraNamedArgs.empty() && kwstarArgIndex == -1) + return onError(Error::CALL_ARGS_INVALID, + args[extraNamedArgs.begin()->second].value->getSrcInfo(), + Emsg(Error::CALL_ARGS_INVALID, extraNamedArgs.begin()->first, + getUnmangledName(func->ast->getName()))); + if (kwstarArgIndex != -1) + for (auto &e : extraNamedArgs) + slots[kwstarArgIndex].push_back(e.second); + + // 5. Fill in the default arguments + for (auto i = 0; i < func->ast->size(); i++) + if (slots[i].empty() && i != starArgIndex && i != kwstarArgIndex) { + if ((*func->ast)[i].isValue() && + ((*func->ast)[i].getDefault() || (!known.empty() && known[i]))) { + score -= 2; + } else if (!partial && (*func->ast)[i].isValue()) { + auto [_, n] = (*func->ast)[i].getNameWithStars(); + return onError(Error::CALL_ARGS_MISSING, getSrcInfo(), + Emsg(Error::CALL_ARGS_MISSING, + getUnmangledName(func->ast->getName()), + getUnmangledName(n))); + } + } + auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial); + return s != -1 ? score + s : -1; +} + +bool TypecheckVisitor::isCanonicalName(const std::string &name) const { + return name.rfind('.') != std::string::npos; +} + +types::FuncType *TypecheckVisitor::extractFunction(types::Type *t) const { + if (auto f = t->getFunc()) + return f; + if (auto p = t->getPartial()) + return p->getPartialFunc(); + return nullptr; +} + +class SearchVisitor : public CallbackASTVisitor { + std::function exprPredicate; + std::function stmtPredicate; + +public: + std::vector result; + +public: + SearchVisitor(const std::function &exprPredicate, + const std::function &stmtPredicate) + : exprPredicate(exprPredicate), stmtPredicate(stmtPredicate) {} + void transform(Expr *expr) override { + if (expr && exprPredicate(expr)) { + result.push_back(expr); + } else { + SearchVisitor v(exprPredicate, stmtPredicate); + if (expr) + expr->accept(v); + result.insert(result.end(), v.result.begin(), v.result.end()); + } + } + void transform(Stmt *stmt) override { + if (stmt && stmtPredicate(stmt)) { + SearchVisitor v(exprPredicate, stmtPredicate); + stmt->accept(v); + result.insert(result.end(), v.result.begin(), v.result.end()); + } + } +}; + +ParserErrors TypecheckVisitor::findTypecheckErrors(Stmt *n) { + SearchVisitor v([](Expr *e) { return !e->isDone(); }, + [](Stmt *s) { return !s->isDone(); }); + v.transform(n); + std::vector errors; + for (auto e : v.result) + errors.emplace_back( + fmt::format("cannot typecheck {}", split(e->toString(0), '\n').front()), + e->getSrcInfo()); + return ParserErrors(errors); +} + +ir::PyType TypecheckVisitor::cythonizeClass(const std::string &name) { + auto c = getClass(name); + if (!c->module.empty()) + return {"", ""}; + if (!in(c->methods, "__to_py__") || !in(c->methods, "__from_py__")) + return {"", ""}; + + LOG_USER("[py] Cythonizing {}", name); + ir::PyType py{getUnmangledName(name), c->ast->getDocstr()}; + + auto tc = ctx->forceFind(name)->getType(); + if (!tc->canRealize()) + E(Error::CUSTOM, c->ast, "cannot realize '{}' for Python export", + getUnmangledName(name)); + tc = realize(tc); + seqassertn(tc, "cannot realize '{}'", name); + + // 1. Replace to_py / from_py with _PyWrap.wrap_to_py/from_py + if (auto ofnn = in(c->methods, "__to_py__")) { + auto fnn = getOverloads(*ofnn).front(); // default first overload! + auto fna = getFunction(fnn)->ast; + fna->suite = SuiteStmt::wrap( + N(N(N(format("{}.wrap_to_py:0", CYTHON_PYWRAP)), + N(fna->begin()->name)))); + } + if (auto ofnn = in(c->methods, "__from_py__")) { + auto fnn = getOverloads(*ofnn).front(); // default first overload! + auto fna = getFunction(fnn)->ast; + fna->suite = SuiteStmt::wrap( + N(N(N(format("{}.wrap_from_py:0", CYTHON_PYWRAP)), + N(fna->begin()->name), N(name)))); + } + for (auto &n : std::vector{"__from_py__", "__to_py__"}) { + auto fnn = getOverloads(*in(c->methods, n)).front(); + auto fn = getFunction(fnn); + ir::Func *oldIR = nullptr; + if (!fn->realizations.empty()) + oldIR = fn->realizations.begin()->second->ir; + fn->realizations.clear(); + auto tf = realize(fn->type); + seqassertn(tf, "cannot re-realize '{}'", fnn); + if (oldIR) { + std::vector args; + for (auto it = oldIR->arg_begin(); it != oldIR->arg_end(); ++it) { + args.push_back(ctx->cache->module->Nr(*it)); + } + cast(oldIR)->setBody( + ir::util::series(ir::util::call(fn->realizations.begin()->second->ir, args))); + } + } + for (auto &[rn, r] : + getFunction(format("{}.py_type:0", CYTHON_PYWRAP))->realizations) { + if (r->type->funcGenerics[0].type->unify(tc, nullptr) >= 0) { + py.typePtrHook = r->ir; + break; + } + } + + // 2. Handle methods + auto methods = c->methods; + for (const auto &[n, ofnn] : methods) { + auto canonicalName = getOverloads(ofnn).back(); + auto fn = getFunction(canonicalName); + if (getOverloads(ofnn).size() == 1 && fn->ast->hasAttribute(Attr::AutoGenerated)) + continue; + auto fna = fn->ast; + bool isMethod = fna->hasAttribute(Attr::Method); + bool isProperty = fna->hasAttribute(Attr::Property); + + std::string call = format("{}.wrap_multiple", CYTHON_PYWRAP); + bool isMagic = false; + if (startswith(n, "__") && endswith(n, "__")) { + auto m = n.substr(2, n.size() - 4); + if (m == "new" && c->ast->hasAttribute(Attr::Tuple)) + m = "init"; + auto cls = getClass(CYTHON_PYWRAP); + if (auto i = in(c->methods, "wrap_magic_" + m)) { + call = *i; + isMagic = true; + } + } + if (isProperty) + call = format("{}.wrap_get", CYTHON_PYWRAP); + + auto fnName = call + ":0"; + auto generics = std::vector{tc->shared_from_this()}; + if (isProperty) { + generics.push_back(instantiateStatic(getUnmangledName(canonicalName))); + } else if (!isMagic) { + generics.push_back(instantiateStatic(n)); + generics.push_back(instantiateStatic(int64_t(isMethod))); + } + auto f = realizeIRFunc(getFunction(fnName)->getType(), generics); + if (!f) + continue; + + LOG_USER("[py] {} -> {} ({}; {})", n, call, isMethod, isProperty); + if (isProperty) { + py.getset.push_back({getUnmangledName(canonicalName), "", f, nullptr}); + } else if (n == "__repr__") { + py.repr = f; + } else if (n == "__add__") { + py.add = f; + } else if (n == "__iadd__") { + py.iadd = f; + } else if (n == "__sub__") { + py.sub = f; + } else if (n == "__isub__") { + py.isub = f; + } else if (n == "__mul__") { + py.mul = f; + } else if (n == "__imul__") { + py.imul = f; + } else if (n == "__mod__") { + py.mod = f; + } else if (n == "__imod__") { + py.imod = f; + } else if (n == "__divmod__") { + py.divmod = f; + } else if (n == "__pow__") { + py.pow = f; + } else if (n == "__ipow__") { + py.ipow = f; + } else if (n == "__neg__") { + py.neg = f; + } else if (n == "__pos__") { + py.pos = f; + } else if (n == "__abs__") { + py.abs = f; + } else if (n == "__bool__") { + py.bool_ = f; + } else if (n == "__invert__") { + py.invert = f; + } else if (n == "__lshift__") { + py.lshift = f; + } else if (n == "__ilshift__") { + py.ilshift = f; + } else if (n == "__rshift__") { + py.rshift = f; + } else if (n == "__irshift__") { + py.irshift = f; + } else if (n == "__and__") { + py.and_ = f; + } else if (n == "__iand__") { + py.iand = f; + } else if (n == "__xor__") { + py.xor_ = f; + } else if (n == "__ixor__") { + py.ixor = f; + } else if (n == "__or__") { + py.or_ = f; + } else if (n == "__ior__") { + py.ior = f; + } else if (n == "__int__") { + py.int_ = f; + } else if (n == "__float__") { + py.float_ = f; + } else if (n == "__floordiv__") { + py.floordiv = f; + } else if (n == "__ifloordiv__") { + py.ifloordiv = f; + } else if (n == "__truediv__") { + py.truediv = f; + } else if (n == "__itruediv__") { + py.itruediv = f; + } else if (n == "__index__") { + py.index = f; + } else if (n == "__matmul__") { + py.matmul = f; + } else if (n == "__imatmul__") { + py.imatmul = f; + } else if (n == "__len__") { + py.len = f; + } else if (n == "__getitem__") { + py.getitem = f; + } else if (n == "__setitem__") { + py.setitem = f; + } else if (n == "__contains__") { + py.contains = f; + } else if (n == "__hash__") { + py.hash = f; + } else if (n == "__call__") { + py.call = f; + } else if (n == "__str__") { + py.str = f; + } else if (n == "__iter__") { + py.iter = f; + } else if (n == "__del__") { + py.del = f; + } else if (n == "__init__" || + (c->ast->hasAttribute(Attr::Tuple) && n == "__new__")) { + py.init = f; + } else { + py.methods.push_back(ir::PyFunction{ + n, fna->getDocstr(), f, + fna->hasAttribute(Attr::Method) ? ir::PyFunction::Type::METHOD + : ir::PyFunction::Type::CLASS, + // always use FASTCALL for now; works even for 0- or 1- arg methods + 2}); + py.methods.back().keywords = true; + } + } + + for (auto &m : py.methods) { + if (in(std::set{"__lt__", "__le__", "__eq__", "__ne__", "__gt__", + "__ge__"}, + m.name)) { + py.cmp = realizeIRFunc( + ctx->forceFind(format("{}.wrap_cmp:0", CYTHON_PYWRAP))->type->getFunc(), + {tc->shared_from_this()}); + break; + } + } + + if (c->realizations.size() != 1) + E(Error::CUSTOM, c->ast, "cannot pythonize generic class '{}'", name); + auto r = c->realizations.begin()->second; + py.type = r->ir; + seqassertn(!r->type->is(TYPE_TUPLE), "tuples not yet done"); + for (auto &[mn, mt] : r->fields) { + /// TODO: handle PyMember for tuples + // Generate getters & setters + auto generics = + std::vector{tc->shared_from_this(), instantiateStatic(mn)}; + auto gf = realizeIRFunc( + getFunction(format("{}.wrap_get:0", CYTHON_PYWRAP))->getType(), generics); + ir::Func *sf = nullptr; + if (!c->ast->hasAttribute(Attr::Tuple)) + sf = realizeIRFunc(getFunction(format("{}.wrap_set:0", CYTHON_PYWRAP))->getType(), + generics); + py.getset.push_back({mn, "", gf, sf}); + LOG_USER("[py] {}: {} . {}", "member", name, mn); + } + return py; +} + +ir::PyType TypecheckVisitor::cythonizeIterator(const std::string &name) { + LOG_USER("[py] iterfn: {}", name); + ir::PyType py{name, ""}; + auto cr = ctx->cache->classes[CYTHON_ITER].realizations[name]; + auto tc = cr->getType(); + for (auto &[rn, r] : + getFunction(format("{}.py_type:0", CYTHON_PYWRAP))->realizations) { + if (extractFuncGeneric(r->getType())->unify(tc, nullptr) >= 0) { + py.typePtrHook = r->ir; + break; + } + } + + const auto &methods = getClass(CYTHON_ITER)->methods; + for (auto &n : std::vector{"_iter", "_iternext"}) { + auto fnn = getOverloads(getClass(CYTHON_ITER)->methods[n]).front(); + auto rtv = realize(instantiateType(getFunction(fnn)->getType(), tc->getClass())); + auto f = + getFunction(rtv->getFunc()->getFuncName())->realizations[rtv->realizedName()]; + if (n == "_iter") + py.iter = f->ir; + else + py.iternext = f->ir; + } + py.type = cr->ir; + return py; +} + +ir::PyFunction TypecheckVisitor::cythonizeFunction(const std::string &name) { + auto f = getFunction(name); + if (f->isToplevel) { + auto fnName = format("{}.wrap_multiple:0", CYTHON_PYWRAP); + auto generics = std::vector{ + ctx->forceFind(".toplevel")->type, + instantiateStatic(getUnmangledName(f->ast->getName())), + instantiateStatic(int64_t(0))}; + if (auto ir = realizeIRFunc(getFunction(fnName)->getType(), generics)) { + LOG_USER("[py] {}: {}", "toplevel", name); + ir::PyFunction fn{getUnmangledName(name), f->ast->getDocstr(), ir, + ir::PyFunction::Type::TOPLEVEL, int(f->ast->size())}; + fn.keywords = true; + return fn; + } } + return {"", ""}; } } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index 26f32777..05a720c3 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -11,7 +11,6 @@ #include "codon/parser/ast.h" #include "codon/parser/common.h" -#include "codon/parser/visitors/format/format.h" #include "codon/parser/visitors/typecheck/ctx.h" #include "codon/parser/visitors/visitor.h" @@ -23,43 +22,46 @@ namespace codon::ast { * -> Note: this stage *modifies* the provided AST. Clone it before simplification * if you need it intact. */ -class TypecheckVisitor : public CallbackASTVisitor { +class TypecheckVisitor : public ReplacingCallbackASTVisitor { /// Shared simplification context. std::shared_ptr ctx; /// Statements to prepend before the current statement. - std::shared_ptr> prependStmts; + std::shared_ptr> prependStmts = nullptr; + std::shared_ptr> preamble = nullptr; /// Each new expression is stored here (as @c visit does not return anything) and /// later returned by a @c transform call. - ExprPtr resultExpr; + Expr *resultExpr = nullptr; /// Each new statement is stored here (as @c visit does not return anything) and /// later returned by a @c transform call. - StmtPtr resultStmt; + Stmt *resultStmt = nullptr; public: - static StmtPtr apply(Cache *cache, const StmtPtr &stmts); + // static Stmt * apply(Cache *cache, const Stmt * &stmts); + static Stmt * + apply(Cache *cache, Stmt *node, const std::string &file, + const std::unordered_map &defines = {}, + const std::unordered_map &earlyDefines = {}, + bool barebones = false); + static Stmt *apply(const std::shared_ptr &cache, Stmt *node, + const std::string &file = ""); + +private: + static void loadStdLibrary(Cache *, const std::shared_ptr> &, + const std::unordered_map &, + bool); public: explicit TypecheckVisitor( std::shared_ptr ctx, - const std::shared_ptr> &stmts = nullptr); + const std::shared_ptr> &preamble = nullptr, + const std::shared_ptr> &stmts = nullptr); public: // Convenience transformators - ExprPtr transform(ExprPtr &e) override; - ExprPtr transform(const ExprPtr &expr) override { - auto e = expr; - return transform(e); - } - StmtPtr transform(StmtPtr &s) override; - StmtPtr transform(const StmtPtr &stmt) override { - auto s = stmt; - return transform(s); - } - ExprPtr transformType(ExprPtr &expr); - ExprPtr transformType(const ExprPtr &expr) { - auto e = expr; - return transformType(e); - } + Expr *transform(Expr *e) override; + Expr *transform(Expr *expr, bool allowTypes); + Stmt *transform(Stmt *s) override; + Expr *transformType(Expr *expr, bool allowTypeOf = true); private: void defaultVisit(Expr *e) override; @@ -70,181 +72,422 @@ class TypecheckVisitor : public CallbackASTVisitor { void visit(NoneExpr *) override; void visit(BoolExpr *) override; void visit(IntExpr *) override; + Expr *transformInt(IntExpr *); void visit(FloatExpr *) override; + Expr *transformFloat(FloatExpr *); void visit(StringExpr *) override; /* Identifier access expressions (access.cpp) */ void visit(IdExpr *) override; + bool checkCapture(const TypeContext::Item &); void visit(DotExpr *) override; - ExprPtr transformDot(DotExpr *, std::vector * = nullptr); - ExprPtr getClassMember(DotExpr *, std::vector *); - types::TypePtr findSpecialMember(const std::string &); - types::FuncTypePtr getBestOverload(Expr *, std::vector *); - types::FuncTypePtr getDispatch(const std::string &); + std::pair getImport(const std::vector &); + Expr *getClassMember(DotExpr *); + types::FuncType *getDispatch(const std::string &); /* Collection and comprehension expressions (collections.cpp) */ void visit(TupleExpr *) override; void visit(ListExpr *) override; void visit(SetExpr *) override; void visit(DictExpr *) override; + Expr *transformComprehension(const std::string &, const std::string &, + std::vector &); void visit(GeneratorExpr *) override; - ExprPtr transformComprehension(const std::string &, const std::string &, - std::vector &); /* Conditional expression and statements (cond.cpp) */ + void visit(RangeExpr *) override; void visit(IfExpr *) override; void visit(IfStmt *) override; + void visit(MatchStmt *) override; + Stmt *transformPattern(Expr *, Expr *, Stmt *); /* Operators (op.cpp) */ void visit(UnaryExpr *) override; - ExprPtr evaluateStaticUnary(UnaryExpr *); + Expr *evaluateStaticUnary(UnaryExpr *); void visit(BinaryExpr *) override; - ExprPtr evaluateStaticBinary(BinaryExpr *); - ExprPtr transformBinarySimple(BinaryExpr *); - ExprPtr transformBinaryIs(BinaryExpr *); + Expr *evaluateStaticBinary(BinaryExpr *); + Expr *transformBinarySimple(BinaryExpr *); + Expr *transformBinaryIs(BinaryExpr *); std::pair getMagic(const std::string &); - ExprPtr transformBinaryInplaceMagic(BinaryExpr *, bool); - ExprPtr transformBinaryMagic(BinaryExpr *); + Expr *transformBinaryInplaceMagic(BinaryExpr *, bool); + Expr *transformBinaryMagic(BinaryExpr *); + void visit(ChainBinaryExpr *) override; void visit(PipeExpr *) override; void visit(IndexExpr *) override; - std::pair transformStaticTupleIndex(const types::ClassTypePtr &, - const ExprPtr &, const ExprPtr &); + std::pair transformStaticTupleIndex(types::ClassType *, Expr *, Expr *); int64_t translateIndex(int64_t, int64_t, bool = false); int64_t sliceAdjustIndices(int64_t, int64_t *, int64_t *, int64_t); void visit(InstantiateExpr *) override; void visit(SliceExpr *) override; /* Calls (call.cpp) */ + void visit(PrintStmt *) override; /// Holds partial call information for a CallExpr. struct PartialCallData { - bool isPartial = false; // true if the call is partial - std::string var; // set if calling a partial type itself - std::vector known = {}; // mask of known arguments - ExprPtr args = nullptr, kwArgs = nullptr; // partial *args/**kwargs expressions + bool isPartial = false; // true if the call is partial + std::string var; // set if calling a partial type itself + std::vector known = {}; // mask of known arguments + Expr *args = nullptr, *kwArgs = nullptr; // partial *args/**kwargs expressions }; void visit(StarExpr *) override; void visit(KeywordStarExpr *) override; void visit(EllipsisExpr *) override; void visit(CallExpr *) override; - bool transformCallArgs(std::vector &); - std::pair getCalleeFn(CallExpr *, PartialCallData &); - ExprPtr callReorderArguments(types::FuncTypePtr, CallExpr *, PartialCallData &); - bool typecheckCallArgs(const types::FuncTypePtr &, std::vector &); - std::pair transformSpecialCall(CallExpr *); - ExprPtr transformSuperF(CallExpr *expr); - ExprPtr transformSuper(); - ExprPtr transformPtr(CallExpr *expr); - ExprPtr transformArray(CallExpr *expr); - ExprPtr transformIsInstance(CallExpr *expr); - ExprPtr transformStaticLen(CallExpr *expr); - ExprPtr transformHasAttr(CallExpr *expr); - ExprPtr transformGetAttr(CallExpr *expr); - ExprPtr transformSetAttr(CallExpr *expr); - ExprPtr transformCompileError(CallExpr *expr); - ExprPtr transformTupleFn(CallExpr *expr); - ExprPtr transformTypeFn(CallExpr *expr); - ExprPtr transformRealizedFn(CallExpr *expr); - ExprPtr transformStaticPrintFn(CallExpr *expr); - ExprPtr transformHasRttiFn(CallExpr *expr); - std::pair transformInternalStaticFn(CallExpr *expr); - std::vector getSuperTypes(const types::ClassTypePtr &cls); - void addFunctionGenerics(const types::FuncType *t); - std::string generatePartialStub(const std::vector &mask, types::FuncType *fn); + void validateCall(CallExpr *expr); + bool transformCallArgs(CallExpr *); + std::pair, Expr *> getCalleeFn(CallExpr *, + PartialCallData &); + Expr *callReorderArguments(types::FuncType *, CallExpr *, PartialCallData &); + bool typecheckCallArgs(types::FuncType *, std::vector &, bool); + std::pair transformSpecialCall(CallExpr *); + std::vector getSuperTypes(types::ClassType *); /* Assignments (assign.cpp) */ + void visit(AssignExpr *) override; void visit(AssignStmt *) override; - void transformUpdate(AssignStmt *); + Stmt *unpackAssignment(Expr *lhs, Expr *rhs); + Stmt *transformUpdate(AssignStmt *); + Stmt *transformAssignment(AssignStmt *, bool = false); + void visit(DelStmt *) override; void visit(AssignMemberStmt *) override; - std::pair transformInplaceUpdate(AssignStmt *); + std::pair transformInplaceUpdate(AssignStmt *); + + /* Imports (import.cpp) */ + void visit(ImportStmt *) override; + Stmt *transformSpecialImport(ImportStmt *); + std::vector getImportPath(Expr *, size_t = 0); + Stmt *transformCImport(const std::string &, const std::vector &, Expr *, + const std::string &); + Stmt *transformCVarImport(const std::string &, Expr *, const std::string &); + Stmt *transformCDLLImport(Expr *, const std::string &, const std::vector &, + Expr *, const std::string &, bool); + Stmt *transformPythonImport(Expr *, const std::vector &, Expr *, + const std::string &); + Stmt *transformNewImport(const ImportFile &); /* Loops (loops.cpp) */ void visit(BreakStmt *) override; void visit(ContinueStmt *) override; void visit(WhileStmt *) override; void visit(ForStmt *) override; - StmtPtr transformHeterogenousTupleFor(ForStmt *); - StmtPtr transformStaticForLoop(ForStmt *); + Expr *transformForDecorator(Expr *); + std::pair transformStaticForLoop(ForStmt *); /* Errors and exceptions (error.cpp) */ + void visit(AssertStmt *) override; void visit(TryStmt *) override; void visit(ThrowStmt *) override; + void visit(WithStmt *) override; /* Functions (function.cpp) */ void visit(YieldExpr *) override; void visit(ReturnStmt *) override; void visit(YieldStmt *) override; + void visit(YieldFromStmt *) override; + void visit(LambdaExpr *) override; + void visit(GlobalStmt *) override; void visit(FunctionStmt *) override; - ExprPtr partializeFunction(const types::FuncTypePtr &); - std::shared_ptr getFuncTypeBase(size_t); - -public: - types::FuncTypePtr makeFunctionType(FunctionStmt *); + Stmt *transformPythonDefinition(const std::string &, const std::vector &, + Expr *, Stmt *); + Stmt *transformLLVMDefinition(Stmt *); + std::pair getDecorator(Expr *); + Expr *partializeFunction(types::FuncType *); + std::shared_ptr getFuncTypeBase(size_t); private: /* Classes (class.cpp) */ void visit(ClassStmt *) override; - void parseBaseClasses(ClassStmt *); - std::string generateTuple(size_t, const std::string & = TYPE_TUPLE, - std::vector = {}, bool = true); + std::vector parseBaseClasses(std::vector &, + std::vector &, Stmt *, + const std::string &, Expr *, + types::ClassType *); + void autoDeduceMembers(ClassStmt *, std::vector &); + std::vector getClassMethods(Stmt *s); + void transformNestedClasses(ClassStmt *, std::vector &, std::vector &, + std::vector &); + Stmt *codegenMagic(const std::string &, Expr *, const std::vector &, bool); + int generateKwId(const std::vector & = {}); +public: + types::ClassType *generateTuple(size_t n, bool = true); + +private: /* The rest (typecheck.cpp) */ void visit(SuiteStmt *) override; void visit(ExprStmt *) override; void visit(StmtExpr *) override; void visit(CommentStmt *stmt) override; + void visit(CustomStmt *) override; -private: +public: /* Type inference (infer.cpp) */ - types::TypePtr unify(types::TypePtr &a, const types::TypePtr &b); - types::TypePtr unify(types::TypePtr &&a, const types::TypePtr &b) { - auto x = a; - return unify(x, b); + types::Type *unify(types::Type *a, types::Type *b); + types::Type *unify(types::Type *a, types::TypePtr &&b) { return unify(a, b.get()); } + types::Type *realize(types::Type *); + types::TypePtr &&realize(types::TypePtr &&t) { + realize(t.get()); + return std::move(t); } - StmtPtr inferTypes(StmtPtr, bool isToplevel = false); - types::TypePtr realize(types::TypePtr); - types::TypePtr realizeFunc(types::FuncType *, bool = false); - types::TypePtr realizeType(types::ClassType *); - std::shared_ptr generateSpecialAst(types::FuncType *); + +private: + Stmt *inferTypes(Stmt *, bool isToplevel = false); + types::Type *realizeFunc(types::FuncType *, bool = false); + types::Type *realizeType(types::ClassType *); + SuiteStmt *generateSpecialAst(types::FuncType *); size_t getRealizationID(types::ClassType *, types::FuncType *); codon::ir::types::Type *makeIRType(types::ClassType *); codon::ir::Func * makeIRFunction(const std::shared_ptr &); private: - types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, - const std::string &member, - const std::vector &args); - types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ, - const std::string &member, - const std::vector &args); - types::FuncTypePtr - findBestMethod(const types::ClassTypePtr &typ, const std::string &member, - const std::vector> &args); - int canCall(const types::FuncTypePtr &, const std::vector &, - std::shared_ptr = nullptr); - std::vector - findMatchingMethods(const types::ClassTypePtr &typ, - const std::vector &methods, - const std::vector &args); - bool wrapExpr(ExprPtr &expr, const types::TypePtr &expectedType, - const types::FuncTypePtr &callee = nullptr, bool allowUnwrap = true); - ExprPtr castToSuperClass(ExprPtr expr, types::ClassTypePtr superTyp, bool = false); - StmtPtr prepareVTables(); + types::FuncType *findBestMethod(types::ClassType *typ, const std::string &member, + const std::vector &args); + types::FuncType *findBestMethod(types::ClassType *typ, const std::string &member, + const std::vector &args); + types::FuncType * + findBestMethod(types::ClassType *typ, const std::string &member, + const std::vector> &args); + int canCall(types::FuncType *, const std::vector &, + types::ClassType * = nullptr); + std::vector findMatchingMethods( + types::ClassType *typ, const std::vector &methods, + const std::vector &args, types::ClassType *part = nullptr); + Expr *castToSuperClass(Expr *expr, types::ClassType *superTyp, bool = false); + void prepareVTables(); + std::vector> extractNamedTuple(Expr *); + std::vector getClassFieldTypes(types::ClassType *); + std::vector> findEllipsis(Expr *); public: - bool isTuple(const std::string &s) const { return s == TYPE_TUPLE; } - std::vector &getClassFields(types::ClassType *); + bool wrapExpr(Expr **expr, types::Type *expectedType, + types::FuncType *callee = nullptr, bool allowUnwrap = true); + std::tuple> + canWrapExpr(types::Type *exprType, types::Type *expectedType, + types::FuncType *callee = nullptr, bool allowUnwrap = true, + bool isEllipsis = false); + std::vector getClassFields(types::ClassType *) const; + std::shared_ptr getCtx() const { return ctx; } + Expr *generatePartialCall(const std::vector &, types::FuncType *, + Expr * = nullptr, Expr * = nullptr); friend class Cache; + friend class TypeContext; friend class types::CallableTrait; friend class types::UnionType; private: // Helpers - std::shared_ptr>> - unpackTupleTypes(ExprPtr); - std::pair>> - transformStaticLoopCall(const std::vector &, ExprPtr, - std::function(StmtPtr)>); + std::shared_ptr>> + unpackTupleTypes(Expr *); + std::tuple> + transformStaticLoopCall(Expr *, SuiteStmt **, Expr *, + const std::function &, bool = false); + +public: + template Tn *N(Ts &&...args) { + Tn *t = ctx->cache->N(std::forward(args)...); + t->setSrcInfo(getSrcInfo()); + if (cast(t) && getTime()) + t->setAttribute(Attr::ExprTime, getTime()); + return t; + } + template Tn *NC(Ts &&...args) { + Tn *t = ctx->cache->N(std::forward(args)...); + return t; + } + +private: + template void log(const std::string &prefix, Ts &&...args) { + fmt::print(codon::getLogger().log, "[{}] [{}${}]: " + prefix + "\n", + ctx->getSrcInfo(), ctx->getBaseName(), ctx->getBase()->iteration, + std::forward(args)...); + } + template + void logfile(const std::string &file, const std::string &prefix, Ts &&...args) { + if (in(ctx->getSrcInfo().file, file)) + fmt::print(codon::getLogger().log, "[{}] [{}${}]: " + prefix + "\n", + ctx->getSrcInfo(), ctx->getBaseName(), ctx->getBase()->iteration, + std::forward(args)...); + } + +public: + types::Type *extractType(types::Type *t); + types::Type *extractType(Expr *e); + types::Type *extractType(const std::string &); + types::ClassType *extractClassType(Expr *e); + types::ClassType *extractClassType(types::Type *t); + types::ClassType *extractClassType(const std::string &s); + bool isUnbound(types::Type *t) const; + bool isUnbound(Expr *e) const; + bool hasOverloads(const std::string &root); + std::vector getOverloads(const std::string &root); + std::string getUnmangledName(const std::string &s) const; + Cache::Class *getClass(const std::string &t) const; + Cache::Class *getClass(types::Type *t) const; + Cache::Function *getFunction(const std::string &n) const; + Cache::Function *getFunction(types::Type *t) const; + Cache::Class::ClassRealization *getClassRealization(types::Type *t) const; + std::string getRootName(types::FuncType *t); + bool isTypeExpr(Expr *e); + Cache::Module *getImport(const std::string &s); + std::string getArgv() const; + std::string getRootModulePath() const; + std::vector getPluginImportPaths() const; + bool isDispatch(const std::string &s); + bool isDispatch(FunctionStmt *ast); + bool isDispatch(types::Type *f); + void addClassGenerics(types::ClassType *typ, bool func = false, + bool onlyMangled = false, bool instantiate = false); + template + auto withClassGenerics(types::ClassType *typ, F fn, bool func = false, + bool onlyMangled = false, bool instantiate = false) { + ctx->addBlock(); + addClassGenerics(typ, func, onlyMangled, instantiate); + auto t = fn(); + ctx->popBlock(); + return t; + } + types::TypePtr instantiateTypeVar(types::Type *t); + void registerGlobal(const std::string &s, bool = false); + types::ClassType *getStdLibType(const std::string &type); + types::Type *extractClassGeneric(types::Type *t, int idx = 0) const; + types::Type *extractFuncGeneric(types::Type *t, int idx = 0) const; + types::Type *extractFuncArgType(types::Type *t, int idx = 0); + std::string getClassMethod(types::Type *typ, const std::string &member); + std::string getTemporaryVar(const std::string &s); + bool isImportFn(const std::string &s); + int64_t getTime(); + types::Type *getUnderlyingStaticType(types::Type *t); + + int64_t getIntLiteral(types::Type *t, size_t pos = 0); + bool getBoolLiteral(types::Type *t, size_t pos = 0); + std::string getStrLiteral(types::Type *t, size_t pos = 0); + + Expr *transformNamedTuple(CallExpr *); + Expr *transformFunctoolsPartial(CallExpr *); + Expr *transformSuperF(CallExpr *); + Expr *transformSuper(); + Expr *transformPtr(CallExpr *); + Expr *transformArray(CallExpr *); + Expr *transformIsInstance(CallExpr *); + Expr *transformStaticLen(CallExpr *); + Expr *transformHasAttr(CallExpr *); + Expr *transformGetAttr(CallExpr *); + Expr *transformSetAttr(CallExpr *); + Expr *transformCompileError(CallExpr *); + Expr *transformTupleFn(CallExpr *); + Expr *transformTypeFn(CallExpr *); + Expr *transformRealizedFn(CallExpr *); + Expr *transformStaticPrintFn(CallExpr *); + Expr *transformHasRttiFn(CallExpr *); + Expr *transformStaticFnCanCall(CallExpr *); + Expr *transformStaticFnArgHasType(CallExpr *); + Expr *transformStaticFnArgGetType(CallExpr *); + Expr *transformStaticFnArgs(CallExpr *); + Expr *transformStaticFnHasDefault(CallExpr *); + Expr *transformStaticFnGetDefault(CallExpr *); + Expr *transformStaticFnWrapCallArgs(CallExpr *); + Expr *transformStaticVars(CallExpr *); + Expr *transformStaticTupleType(CallExpr *); + SuiteStmt *generateClassPopulateVTablesAST(); + SuiteStmt *generateBaseDerivedDistAST(types::FuncType *); + FunctionStmt *generateThunkAST(types::FuncType *fp, types::ClassType *base, + types::ClassType *derived); + SuiteStmt *generateFunctionCallInternalAST(types::FuncType *); + SuiteStmt *generateUnionNewAST(types::FuncType *); + SuiteStmt *generateUnionTagAST(types::FuncType *); + SuiteStmt *generateNamedKeysAST(types::FuncType *); + SuiteStmt *generateTupleMulAST(types::FuncType *); + std::vector populateStaticTupleLoop(Expr *, const std::vector &); + std::vector populateSimpleStaticRangeLoop(Expr *, + const std::vector &); + std::vector populateStaticRangeLoop(Expr *, const std::vector &); + std::vector populateStaticFnOverloadsLoop(Expr *, + const std::vector &); + std::vector populateStaticEnumerateLoop(Expr *, + const std::vector &); + std::vector populateStaticVarsLoop(Expr *, const std::vector &); + std::vector populateStaticVarTypesLoop(Expr *, + const std::vector &); + std::vector + populateStaticHeterogenousTupleLoop(Expr *, const std::vector &); + ParserErrors findTypecheckErrors(Stmt *n); + +public: +public: + /// Get the current realization depth (i.e., the number of nested realizations). + size_t getRealizationDepth() const; + /// Get the name of the current realization stack (e.g., `fn1:fn2:...`). + std::string getRealizationStackName() const; + +public: + /// Create an unbound type with the provided typechecking level. + std::shared_ptr instantiateUnbound(const SrcInfo &info, + int level) const; + std::shared_ptr instantiateUnbound(const SrcInfo &info) const; + std::shared_ptr instantiateUnbound() const; + + /// Call `type->instantiate`. + /// Prepare the generic instantiation table with the given generics parameter. + /// Example: when instantiating List[T].foo, generics=List[int].foo will ensure that + /// T=int. + /// @param expr Expression that needs the type. Used to set type's srcInfo. + /// @param setActive If True, add unbounds to activeUnbounds. + types::TypePtr instantiateType(const SrcInfo &info, types::Type *type, + types::ClassType *generics = nullptr); + types::TypePtr instantiateType(const SrcInfo &info, types::Type *root, + const std::vector &generics); + template + std::shared_ptr instantiateType(T *type, types::ClassType *generics = nullptr) { + return std::static_pointer_cast( + instantiateType(getSrcInfo(), std::move(type), generics)); + } + template + std::shared_ptr instantiateType(T *root, + const std::vector &generics) { + return std::static_pointer_cast( + instantiateType(getSrcInfo(), std::move(root), generics)); + } + std::shared_ptr instantiateStatic(int64_t i) { + return std::make_shared(ctx->cache, i); + } + std::shared_ptr instantiateStatic(const std::string &s) { + return std::make_shared(ctx->cache, s); + } + std::shared_ptr instantiateStatic(bool i) { + return std::make_shared(ctx->cache, i); + } + + /// Returns the list of generic methods that correspond to typeName.method. + std::vector findMethod(types::ClassType *type, + const std::string &method, + bool hideShadowed = true); + /// Returns the generic type of typeName.member, if it exists (nullptr otherwise). + /// Special cases: __elemsize__ and __atomic__. + Cache::Class::ClassField *findMember(types::ClassType *, const std::string &) const; + + using ReorderDoneFn = + std::function> &, bool)>; + using ReorderErrorFn = std::function; + /// Reorders a given vector or named arguments (consisting of names and the + /// corresponding types) according to the signature of a given function. + /// Returns the reordered vector and an associated reordering score (missing + /// default arguments' score is half of the present arguments). + /// Score is -1 if the given arguments cannot be reordered. + /// @param known Bitmask that indicated if an argument is already provided + /// (partial function) or not. + int reorderNamedArgs(types::FuncType *func, const std::vector &args, + const ReorderDoneFn &onDone, const ReorderErrorFn &onError, + const std::vector &known = std::vector()); + + bool isCanonicalName(const std::string &name) const; + types::FuncType *extractFunction(types::Type *t) const; + + ir::PyType cythonizeClass(const std::string &name); + ir::PyType cythonizeIterator(const std::string &name); + ir::PyFunction cythonizeFunction(const std::string &name); + ir::Func *realizeIRFunc(types::FuncType *fn, + const std::vector &generics = {}); + // types::Type *getType(const std::string &); }; } // namespace codon::ast diff --git a/codon/parser/visitors/visitor.cpp b/codon/parser/visitors/visitor.cpp index 40e771c8..c823c55c 100644 --- a/codon/parser/visitors/visitor.cpp +++ b/codon/parser/visitors/visitor.cpp @@ -22,7 +22,6 @@ void ASTVisitor::visit(ListExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(SetExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(DictExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(GeneratorExpr *expr) { defaultVisit(expr); } -void ASTVisitor::visit(DictGeneratorExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(IfExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(UnaryExpr *expr) { defaultVisit(expr); } void ASTVisitor::visit(BinaryExpr *expr) { defaultVisit(expr); } @@ -57,6 +56,7 @@ void ASTVisitor::visit(IfStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(MatchStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(ImportStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(TryStmt *stmt) { defaultVisit(stmt); } +void ASTVisitor::visit(ExceptStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(GlobalStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(ThrowStmt *stmt) { defaultVisit(stmt); } void ASTVisitor::visit(FunctionStmt *stmt) { defaultVisit(stmt); } diff --git a/codon/parser/visitors/visitor.h b/codon/parser/visitors/visitor.h index b7a175b2..f10b6055 100644 --- a/codon/parser/visitors/visitor.h +++ b/codon/parser/visitors/visitor.h @@ -36,7 +36,6 @@ struct ASTVisitor { virtual void visit(SetExpr *); virtual void visit(DictExpr *); virtual void visit(GeneratorExpr *); - virtual void visit(DictGeneratorExpr *); virtual void visit(IfExpr *); virtual void visit(UnaryExpr *); virtual void visit(BinaryExpr *); @@ -71,6 +70,7 @@ struct ASTVisitor { virtual void visit(MatchStmt *); virtual void visit(ImportStmt *); virtual void visit(TryStmt *); + virtual void visit(ExceptStmt *); virtual void visit(GlobalStmt *); virtual void visit(ThrowStmt *); virtual void visit(FunctionStmt *); @@ -81,23 +81,17 @@ struct ASTVisitor { virtual void visit(CommentStmt *); }; -template /** * Callback AST visitor. * This visitor extends base ASTVisitor and stores node's source location (SrcObject). - * Function simplify() will visit a node and return the appropriate transformation. As + * Function transform() will visit a node and return the appropriate transformation. As * each node type (expression or statement) might return a different type, * this visitor is generic for each different return type. */ +template struct CallbackASTVisitor : public ASTVisitor, public SrcObject { - virtual TE transform(const std::shared_ptr &expr) = 0; - virtual TE transform(std::shared_ptr &expr) { - return transform(static_cast &>(expr)); - } - virtual TS transform(const std::shared_ptr &stmt) = 0; - virtual TS transform(std::shared_ptr &stmt) { - return transform(static_cast &>(stmt)); - } + virtual TE transform(Expr *expr) = 0; + virtual TS transform(Stmt *stmt) = 0; /// Convenience method that transforms a vector of nodes. template auto transform(const std::vector &ts) { @@ -107,46 +101,6 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { return r; } - /// Convenience method that constructs a clone of a node. - template auto N(const Tn &ptr) { return std::make_shared(ptr); } - /// Convenience method that constructs a node. - /// @param s source location. - template auto N(codon::SrcInfo s, Ts &&...args) { - auto t = std::make_shared(std::forward(args)...); - t->setSrcInfo(s); - return t; - } - /// Convenience method that constructs a node with the visitor's source location. - template auto N(Ts &&...args) { - auto t = std::make_shared(std::forward(args)...); - t->setSrcInfo(getSrcInfo()); - return t; - } - template auto NT(Ts &&...args) { - auto t = std::make_shared(std::forward(args)...); - t->setSrcInfo(getSrcInfo()); - t->markType(); - return t; - } - - /// Convenience method that raises an error at the current source location. - template void error(const char *format, TArgs &&...args) { - error::raise_error(-1, getSrcInfo(), fmt::format(format, args...).c_str()); - } - - /// Convenience method that raises an error at the source location of p. - template - void error(const T &p, const char *format, TArgs &&...args) { - error::raise_error(-1, p->getSrcInfo(), fmt::format(format, args...).c_str()); - } - - /// Convenience method that raises an internal error. - template - void internalError(const char *format, TArgs &&...args) { - throw exc::ParserException( - fmt::format("INTERNAL: {}", fmt::format(format, args...), getSrcInfo())); - } - public: void visit(NoneExpr *expr) override {} void visit(BoolExpr *expr) override {} @@ -154,8 +108,8 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { void visit(FloatExpr *expr) override {} void visit(StringExpr *expr) override {} void visit(IdExpr *expr) override {} - void visit(StarExpr *expr) override { transform(expr->what); } - void visit(KeywordStarExpr *expr) override { transform(expr->what); } + void visit(StarExpr *expr) override { transform(expr->expr); } + void visit(KeywordStarExpr *expr) override { transform(expr->expr); } void visit(TupleExpr *expr) override { for (auto &i : expr->items) transform(i); @@ -172,25 +126,7 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { for (auto &i : expr->items) transform(i); } - void visit(GeneratorExpr *expr) override { - transform(expr->expr); - for (auto &l : expr->loops) { - transform(l.vars); - transform(l.gen); - for (auto &c : l.conds) - transform(c); - } - } - void visit(DictGeneratorExpr *expr) override { - transform(expr->key); - transform(expr->expr); - for (auto &l : expr->loops) { - transform(l.vars); - transform(l.gen); - for (auto &c : l.conds) - transform(c); - } - } + void visit(GeneratorExpr *expr) override { transform(expr->loops); } void visit(IfExpr *expr) override { transform(expr->cond); transform(expr->ifexpr); @@ -215,7 +151,7 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { } void visit(CallExpr *expr) override { transform(expr->expr); - for (auto &a : expr->args) + for (auto &a : expr->items) transform(a.value); } void visit(DotExpr *expr) override { transform(expr->expr); } @@ -225,7 +161,13 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { transform(expr->step); } void visit(EllipsisExpr *expr) override {} - void visit(LambdaExpr *expr) override { transform(expr->expr); } + void visit(LambdaExpr *expr) override { + for (auto &a : expr->items) { + transform(a.type); + transform(a.defaultValue); + } + transform(expr->expr); + } void visit(YieldExpr *expr) override {} void visit(AssignExpr *expr) override { transform(expr->var); @@ -236,17 +178,17 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { transform(expr->stop); } void visit(InstantiateExpr *expr) override { - transform(expr->typeExpr); - for (auto &e : expr->typeParams) + transform(expr->expr); + for (auto &e : expr->items) transform(e); } void visit(StmtExpr *expr) override { - for (auto &s : expr->stmts) + for (auto &s : expr->items) transform(s); transform(expr->expr); } void visit(SuiteStmt *stmt) override { - for (auto &s : stmt->stmts) + for (auto &s : stmt->items) transform(s); } void visit(BreakStmt *stmt) override {} @@ -292,8 +234,8 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { transform(stmt->elseSuite); } void visit(MatchStmt *stmt) override { - transform(stmt->what); - for (auto &m : stmt->cases) { + transform(stmt->expr); + for (auto &m : stmt->items) { transform(m.pattern); transform(m.guard); transform(m.suite); @@ -310,17 +252,23 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { } void visit(TryStmt *stmt) override { transform(stmt->suite); - for (auto &a : stmt->catches) { - transform(a.exc); - transform(a.suite); - } + for (auto &a : stmt->items) + transform(a); + transform(stmt->elseSuite); transform(stmt->finally); } + void visit(ExceptStmt *stmt) override { + transform(stmt->exc); + transform(stmt->suite); + } void visit(GlobalStmt *stmt) override {} - void visit(ThrowStmt *stmt) override { transform(stmt->expr); } + void visit(ThrowStmt *stmt) override { + transform(stmt->expr); + transform(stmt->from); + } void visit(FunctionStmt *stmt) override { transform(stmt->ret); - for (auto &a : stmt->args) { + for (auto &a : stmt->items) { transform(a.type); transform(a.defaultValue); } @@ -329,7 +277,7 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { transform(d); } void visit(ClassStmt *stmt) override { - for (auto &a : stmt->args) { + for (auto &a : stmt->items) { transform(a.type); transform(a.defaultValue); } @@ -353,4 +301,204 @@ struct CallbackASTVisitor : public ASTVisitor, public SrcObject { } }; +/** + * Callback AST visitor. + * This visitor extends base ASTVisitor and stores node's source location (SrcObject). + * Function transform() will visit a node and return the appropriate transformation. As + * each node type (expression or statement) might return a different type, + * this visitor is generic for each different return type. + */ +struct ReplacingCallbackASTVisitor : public CallbackASTVisitor { +public: + void visit(StarExpr *expr) override { expr->expr = transform(expr->expr); } + void visit(KeywordStarExpr *expr) override { expr->expr = transform(expr->expr); } + void visit(TupleExpr *expr) override { + for (auto &i : expr->items) + i = transform(i); + } + void visit(ListExpr *expr) override { + for (auto &i : expr->items) + i = transform(i); + } + void visit(SetExpr *expr) override { + for (auto &i : expr->items) + i = transform(i); + } + void visit(DictExpr *expr) override { + for (auto &i : expr->items) + i = transform(i); + } + void visit(GeneratorExpr *expr) override { expr->loops = transform(expr->loops); } + void visit(IfExpr *expr) override { + expr->cond = transform(expr->cond); + expr->ifexpr = transform(expr->ifexpr); + expr->elsexpr = transform(expr->elsexpr); + } + void visit(UnaryExpr *expr) override { expr->expr = transform(expr->expr); } + void visit(BinaryExpr *expr) override { + expr->lexpr = transform(expr->lexpr); + expr->rexpr = transform(expr->rexpr); + } + void visit(ChainBinaryExpr *expr) override { + for (auto &e : expr->exprs) + e.second = transform(e.second); + } + void visit(PipeExpr *expr) override { + for (auto &e : expr->items) + e.expr = transform(e.expr); + } + void visit(IndexExpr *expr) override { + expr->expr = transform(expr->expr); + expr->index = transform(expr->index); + } + void visit(CallExpr *expr) override { + expr->expr = transform(expr->expr); + for (auto &a : expr->items) + a.value = transform(a.value); + } + void visit(DotExpr *expr) override { expr->expr = transform(expr->expr); } + void visit(SliceExpr *expr) override { + expr->start = transform(expr->start); + expr->stop = transform(expr->stop); + expr->step = transform(expr->step); + } + void visit(EllipsisExpr *expr) override {} + void visit(LambdaExpr *expr) override { + for (auto &a : expr->items) { + a.type = transform(a.type); + a.defaultValue = transform(a.defaultValue); + } + expr->expr = transform(expr->expr); + } + void visit(YieldExpr *expr) override {} + void visit(AssignExpr *expr) override { + expr->var = transform(expr->var); + expr->expr = transform(expr->expr); + } + void visit(RangeExpr *expr) override { + expr->start = transform(expr->start); + expr->stop = transform(expr->stop); + } + void visit(InstantiateExpr *expr) override { + expr->expr = transform(expr->expr); + for (auto &e : expr->items) + e = transform(e); + } + void visit(StmtExpr *expr) override { + for (auto &s : expr->items) + s = transform(s); + expr->expr = transform(expr->expr); + } + void visit(SuiteStmt *stmt) override { + for (auto &s : stmt->items) + s = transform(s); + } + void visit(ExprStmt *stmt) override { stmt->expr = transform(stmt->expr); } + void visit(AssignStmt *stmt) override { + stmt->lhs = transform(stmt->lhs); + stmt->rhs = transform(stmt->rhs); + stmt->type = transform(stmt->type); + } + void visit(AssignMemberStmt *stmt) override { + stmt->lhs = transform(stmt->lhs); + stmt->rhs = transform(stmt->rhs); + } + void visit(DelStmt *stmt) override { stmt->expr = transform(stmt->expr); } + void visit(PrintStmt *stmt) override { + for (auto &e : stmt->items) + e = transform(e); + } + void visit(ReturnStmt *stmt) override { stmt->expr = transform(stmt->expr); } + void visit(YieldStmt *stmt) override { stmt->expr = transform(stmt->expr); } + void visit(AssertStmt *stmt) override { + stmt->expr = transform(stmt->expr); + stmt->message = transform(stmt->message); + } + void visit(WhileStmt *stmt) override { + stmt->cond = transform(stmt->cond); + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); + } + void visit(ForStmt *stmt) override { + stmt->var = transform(stmt->var); + stmt->iter = transform(stmt->iter); + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); + stmt->decorator = transform(stmt->decorator); + for (auto &a : stmt->ompArgs) + a.value = transform(a.value); + } + void visit(IfStmt *stmt) override { + stmt->cond = transform(stmt->cond); + stmt->ifSuite = SuiteStmt::wrap(transform(stmt->ifSuite)); + stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); + } + void visit(MatchStmt *stmt) override { + stmt->expr = transform(stmt->expr); + for (auto &m : stmt->items) { + m.pattern = transform(m.pattern); + m.guard = transform(m.guard); + m.suite = SuiteStmt::wrap(transform(m.suite)); + } + } + void visit(ImportStmt *stmt) override { + stmt->from = transform(stmt->from); + stmt->what = transform(stmt->what); + for (auto &a : stmt->args) { + a.type = transform(a.type); + a.defaultValue = transform(a.defaultValue); + } + stmt->ret = transform(stmt->ret); + } + void visit(TryStmt *stmt) override { + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + for (auto &a : stmt->items) + a = (ExceptStmt *)transform(a); + stmt->elseSuite = SuiteStmt::wrap(transform(stmt->elseSuite)); + stmt->finally = SuiteStmt::wrap(transform(stmt->finally)); + } + void visit(ExceptStmt *stmt) override { + stmt->exc = transform(stmt->exc); + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + } + void visit(GlobalStmt *stmt) override {} + void visit(ThrowStmt *stmt) override { + stmt->expr = transform(stmt->expr); + stmt->from = transform(stmt->from); + } + void visit(FunctionStmt *stmt) override { + stmt->ret = transform(stmt->ret); + for (auto &a : stmt->items) { + a.type = transform(a.type); + a.defaultValue = transform(a.defaultValue); + } + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + for (auto &d : stmt->decorators) + d = transform(d); + } + void visit(ClassStmt *stmt) override { + for (auto &a : stmt->items) { + a.type = transform(a.type); + a.defaultValue = transform(a.defaultValue); + } + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + for (auto &d : stmt->decorators) + d = transform(d); + for (auto &d : stmt->baseClasses) + d = transform(d); + for (auto &d : stmt->staticBaseClasses) + d = transform(d); + } + void visit(YieldFromStmt *stmt) override { stmt->expr = transform(stmt->expr); } + void visit(WithStmt *stmt) override { + for (auto &a : stmt->items) + a = transform(a); + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + } + void visit(CustomStmt *stmt) override { + stmt->expr = transform(stmt->expr); + stmt->suite = SuiteStmt::wrap(transform(stmt->suite)); + } +}; + } // namespace codon::ast diff --git a/codon/runtime/exc.cpp b/codon/runtime/exc.cpp index 2e0c4c76..6a07f7eb 100644 --- a/codon/runtime/exc.cpp +++ b/codon/runtime/exc.cpp @@ -157,6 +157,7 @@ struct SeqExcHeader_t { seq_int_t line; seq_int_t col; void *python_type; + void *cause; }; void seq_exc_init(int flags) { diff --git a/codon/runtime/lib.cpp b/codon/runtime/lib.cpp index 85c6f560..53d10ca0 100644 --- a/codon/runtime/lib.cpp +++ b/codon/runtime/lib.cpp @@ -23,6 +23,7 @@ #define GC_THREADS #include "codon/runtime/lib.h" +#include #include #define FASTFLOAT_ALLOWS_LEADING_PLUS @@ -33,6 +34,21 @@ * General */ +/// ORC patches + +#include "llvm/BinaryFormat/MachO.h" + +// Define a minimal mach header for JIT'd code. +static llvm::MachO::mach_header_64 fake_mach_header = { + .magic = llvm::MachO::MH_MAGIC_64, + .cputype = llvm::MachO::CPU_TYPE_ARM64, + .cpusubtype = llvm::MachO::CPU_SUBTYPE_ARM64_ALL, + .filetype = llvm::MachO::MH_DYLIB, + .ncmds = 0, + .sizeofcmds = 0, + .flags = 0, + .reserved = 0}; + // OpenMP patch with GC callbacks typedef int (*gc_setup_callback)(GC_stack_base *); typedef void (*gc_roots_callback)(void *, void *); @@ -255,9 +271,10 @@ template seq_str_t fmt_conv(T n, seq_str_t format, bool *error) { if (format.len == 0) { return string_conv(default_format(n)); } else { + auto locale = std::locale("en_US.UTF-8"); std::string fstr(format.str, format.len); return string_conv( - fmt::format(fmt::runtime(fmt::format(FMT_STRING("{{:{}}}"), fstr)), n)); + fmt::format(locale, fmt::runtime(fmt::format(FMT_STRING("{{:{}}}"), fstr)), n)); } } catch (const std::runtime_error &f) { *error = true; diff --git a/codon/util/common.cpp b/codon/util/common.cpp index 39a545e1..c0db1828 100644 --- a/codon/util/common.cpp +++ b/codon/util/common.cpp @@ -40,10 +40,13 @@ void compilationMessage(const std::string &header, const std::string &msg, } if (line > 0) fmt::print(out, ":{}", line); - if (col > 0) - fmt::print(out, ":{}", col); - if (len > 0) - fmt::print(out, "-{}", col + len); + if (col > 0) { + fmt::print(out, " ({}", col); + if (len > 0) + fmt::print(out, "-{})", col + len); + else + fmt::print(out, ")"); + } if (!file.empty()) fmt::print(out, ": "); fmt::print(out, "{}\033[1m {}\033[0m{}\n", header, msg, diff --git a/codon/util/common.h b/codon/util/common.h index ba5827f7..4baba52a 100644 --- a/codon/util/common.h +++ b/codon/util/common.h @@ -17,10 +17,10 @@ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#define DBG(c, ...) \ +#define DBGI(c, ...) \ fmt::print(codon::getLogger().log, "{}" c "\n", \ - std::string(size_t(2) * size_t(codon::getLogger().level), ' '), \ - ##__VA_ARGS__) + std::string(2 * codon::getLogger().level, ' '), ##__VA_ARGS__) +#define DBG(c, ...) fmt::print(codon::getLogger().log, c "\n", ##__VA_ARGS__) #define LOG(c, ...) DBG(c, ##__VA_ARGS__) #define LOG_TIME(c, ...) \ { \ @@ -136,7 +136,10 @@ struct SrcObject { SrcInfo getSrcInfo() const { return info; } - void setSrcInfo(SrcInfo info) { this->info = std::move(info); } + SrcObject *setSrcInfo(SrcInfo info) { + this->info = std::move(info); + return this; + } }; template void E(error::Error e, codon::SrcObject *o, const TA &...args) { E(e, o->getSrcInfo(), args...); diff --git a/codon/util/serialize.h b/codon/util/serialize.h new file mode 100644 index 00000000..bd3e9e74 --- /dev/null +++ b/codon/util/serialize.h @@ -0,0 +1,72 @@ +// Copyright (C) 2022-2024 Exaloop Inc. + +#pragma once + +#include +#include +#include +#include + +namespace codon { + +template struct PolymorphicSerializer { + struct Serializer { + std::function save; + std::function load; + }; + template static Serializer serializerFor() { + return {[](Base *b, Archive &a) { a.save(*(static_cast(b))); }, + [](Base *&b, Archive &a) { + b = new Derived(); + a.template load(static_cast(*b)); + }}; + } + + static inline std::unordered_map _serializers; + static inline std::unordered_map _factory; + template static void register_types() { + (_serializers.emplace((void *)(Derived::nodeId()), Derived::_typeName), ...); + (_factory.emplace(std::string(Derived::_typeName), serializerFor()), ...); + } + static void save(const std::string &s, Base *b, Archive &a) { + auto i = _factory.find(s); + assert(i != _factory.end() && "bad op"); + i->second.save(b, a); + } + static void load(const std::string &s, Base *&b, Archive &a) { + auto i = _factory.find(s); + assert(i != _factory.end() && "bad op"); + i->second.load(b, a); + } +}; +} // namespace codon + +#define SERIALIZE(Type, ...) \ + inline decltype(auto) members() const { return std::tie(__VA_ARGS__); } \ + inline decltype(auto) members() { return std::tie(__VA_ARGS__); } \ + static constexpr std::array \ + _memberNameData = []() { \ + std::array chars{'\0'}; \ + size_t _idx = 0; \ + constexpr auto *ini(#__VA_ARGS__); \ + for (char const *_c = ini; *_c; ++_c, ++_idx) \ + if (*_c != ',' && *_c != ' ') \ + chars[_idx] = *_c; \ + return chars; \ + }(); \ + static constexpr const char *_typeName = #Type; \ + static constexpr std::array \ + _memberNames = []() { \ + std::array out{}; \ + for (size_t _i = 0, nArgs = 0; nArgs < tser::detail::n_args(#__VA_ARGS__); \ + ++_i) { \ + while (Type::_memberNameData[_i] == '\0') \ + _i++; \ + out[nArgs++] = &Type::_memberNameData[_i]; \ + while (Type::_memberNameData[++_i] != '\0') \ + ; \ + } \ + return out; \ + }() + +#define BASE(T) tser::base(this) diff --git a/docs/intro/differences.md b/docs/intro/differences.md index 41118ff6..3897744b 100644 --- a/docs/intro/differences.md +++ b/docs/intro/differences.md @@ -51,3 +51,6 @@ does *not* change `int`s from 64-bit. While most of the commonly used builtin modules have Codon-native implementations, a few are not yet implemented. However these can still be used within Codon via `from python import`. + + +Many other missing features are also described in [Roadmap]. diff --git a/jit/codon/__init__.py b/jit/codon/__init__.py index 568e4bf9..e6ef34f9 100644 --- a/jit/codon/__init__.py +++ b/jit/codon/__init__.py @@ -3,3 +3,5 @@ __all__ = ["jit", "convert", "JITError"] from .decorator import jit, convert, execute, JITError + +__codon__ = False diff --git a/jit/codon/decorator.py b/jit/codon/decorator.py index 60e34fc9..de354118 100644 --- a/jit/codon/decorator.py +++ b/jit/codon/decorator.py @@ -34,6 +34,8 @@ "Cannot locate Codon. Please install Codon or set CODON_PATH." ) +debug_override = int(os.environ.get("CODON_JIT_DEBUG", 0)) + pod_conversions = { type(None): "pyobj", int: "int", @@ -137,8 +139,8 @@ def _codon_type(arg, **kwargs): j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__) return "{}[{}]".format(s, j) - debug = kwargs.get("debug", None) - if debug: + debug = kwargs.get("debug", 0) + if debug > 0: msg = "cannot convert " + t.__name__ if msg not in _error_msgs: print("[python]", msg, file=sys.stderr) @@ -160,7 +162,9 @@ def _reset_jit(): "import numpy as np\n" "import numpy.pybridge\n" ) - _jit.execute(init_code, "", 0, False) + if debug_override == 2: + print(f"[jit_debug] execute:\n{init_code}", file=sys.stderr) + _jit.execute(init_code, "", 0, int(debug_override > 0)) return _jit @@ -241,7 +245,9 @@ def convert(t): name, ", ".join("a{}".format(i) for i in range(len(slots))) ) - _jit.execute(code, "", 0, False) + if debug_override == 2: + print(f"[jit_debug] execute:\n{code}", file=sys.stderr) + _jit.execute(code, "", 0, int(debug_override > 0)) custom_conversions[t] = name return t @@ -252,38 +258,45 @@ def _jit_register_fn(f, pyvars, debug): fn, fl = "", 1 if hasattr(f, "__code__"): fn, fl = f.__code__.co_filename, f.__code__.co_firstlineno - _jit.execute(obj_str, fn, fl, 1 if debug else 0) + if debug == 2: + print(f"[jit_debug] execute:\n{obj_str}", file=sys.stderr) + _jit.execute(obj_str, fn, fl, int(debug > 0)) return obj_name except JITError: _reset_jit() raise -def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs): +def _jit_callback_fn(obj_name, module, debug=0, sample_size=5, pyvars=None, *args, **kwargs): try: args = (*args, *kwargs.values()) types = _codon_types(args, debug=debug, sample_size=sample_size) - if debug: + if debug > 0: print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr) return _jit.run_wrapper( - obj_name, list(types), module, list(pyvars), args, 1 if debug else 0 + obj_name, list(types), module, list(pyvars), args, int(debug > 0) ) except JITError: _reset_jit() raise -def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None): +def _jit_str_fn(fstr, debug=0, sample_size=5, pyvars=None): obj_name = _jit_register_fn(fstr, pyvars, debug) def wrapped(*args, **kwargs): return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs) return wrapped -def jit(fn=None, debug=None, sample_size=5, pyvars=None): +def jit(fn=None, debug=0, sample_size=5, pyvars=None): + if debug is None: + debug = 0 if not pyvars: pyvars = [] if not isinstance(pyvars, list): raise ArgumentError("pyvars must be a list") + if debug_override: + debug = debug_override + if fn and isinstance(fn, str): return _jit_str_fn(fn, debug, sample_size, pyvars) @@ -296,8 +309,14 @@ def wrapped(*args, **kwargs): return _decorate(fn) if fn else _decorate -def execute(code, debug=False): +def execute(code, debug=0): + if debug is None: + debug = 0 + if debug_override: + debug = debug_override try: + if debug == 2: + print(f"[jit_debug] execute:\n{code}", file=sys.stderr) _jit.execute(code, "", 0, int(debug)) except JITError: _reset_jit() diff --git a/scripts/Dockerfile.codon-build b/scripts/Dockerfile.codon-build index ad8d09b9..7d0e6e47 100644 --- a/scripts/Dockerfile.codon-build +++ b/scripts/Dockerfile.codon-build @@ -1,38 +1,43 @@ -FROM codon:llvm as codon-llvm +FROM exaloop/codon-llvm as codon-llvm +FROM nvidia/cuda:12.4.0-devel-centos7 +COPY --from=codon-llvm /opt/llvm-codon /opt/llvm-codon -FROM nvidia/cuda:11.8.0-devel-centos7 RUN yum -y update RUN yum -y install centos-release-scl-rh epel-release RUN yum -y install \ + devtoolset-7 \ ninja-build libuuid-devel openssl openssl-devel \ libsodium-devel cmake3 zlib-devel git patch perl-Data-Dumper -COPY --from=codon-llvm /opt/llvm-codon /opt/llvm-codon -RUN mkdir -p /github/codon -COPY cmake /github/codon/cmake -COPY codon /github/codon/codon -COPY docs /github/codon/docs -COPY jit /github/codon/jit -COPY stdlib /github/codon/stdlib -COPY test /github/codon/test -COPY CMakeLists.txt /github/codon -RUN cmake3 -S /github/codon -B /github/codon/build \ +RUN scl enable devtoolset-7 -- g++ -v + +RUN git clone -b develop https://github.com/exaloop/codon /github/codon +RUN scl enable devtoolset-7 -- cmake3 -S /github/codon -B /github/codon/build \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=/opt/llvm-codon/bin/clang \ + -DCMAKE_CXX_COMPILER=/opt/llvm-codon/bin/clang++ \ + -DLLVM_DIR=/opt/llvm-codon/lib/cmake/llvm \ + -DCODON_GPU=ON +RUN LD_LIBRARY_PATH=/usr/local/cuda-12.4/compat:${LD_LIBRARY_PATH} scl enable devtoolset-7 -- cmake3 --build /github/codon/build +RUN scl enable devtoolset-7 -- cmake3 --install /github/codon/build --prefix /opt/codon + +RUN scl enable devtoolset-7 -- cmake3 -S /github/codon/jupyter -B /github/codon/jupyter/build \ -G Ninja \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER=/opt/llvm-codon/bin/clang \ -DCMAKE_CXX_COMPILER=/opt/llvm-codon/bin/clang++ \ -DLLVM_DIR=/opt/llvm-codon/lib/cmake/llvm \ - -DCODON_GPU=ON \ - -DCODON_JUPYTER=ON \ + -DCODON_PATH=/opt/codon \ -DOPENSSL_ROOT_DIR=$(openssl version -d | cut -d' ' -f2 | tr -d '"') \ -DOPENSSL_CRYPTO_LIBRARY=/usr/lib64/libssl.so \ - -DCMAKE_INSTALL_PREFIX=/opt/codon \ -DXEUS_USE_DYNAMIC_UUID=ON -RUN LD_LIBRARY_PATH=/usr/local/cuda-11.8/compat:${LD_LIBRARY_PATH} cmake3 --build /github/codon/build -# TODO: fix install RUN cmake3 --install /github/codon/build -RUN mkdir -p /opt/codon/bin -RUN cp /github/codon/build/codon /opt/codon/bin/ -RUN mkdir -p /opt/codon/lib/codon -RUN cp -r /github/codon/build/lib*.so /opt/codon/lib/codon/ -RUN cp -r /github/codon/stdlib /opt/codon/lib/codon/ -RUN cd /github/codon && tar cjvf /opt/codon-$(git rev-parse --short HEAD).tar.bz2 -C /opt codon/ -CMD cp /opt/codon-*.tar.bz2 /mnt/ +RUN scl enable devtoolset-7 -- cmake3 --build /github/codon/jupyter/build +RUN scl enable devtoolset-7 -- cmake3 --install /github/codon/jupyter/build + +# RUN mkdir -p /opt/codon/bin +# RUN cp /github/codon/build/codon /opt/codon/bin/ +# RUN mkdir -p /opt/codon/lib/codon +# RUN cp -r /github/codon/build/lib*.so /opt/codon/lib/codon/ +# RUN cp -r /github/codon/stdlib /opt/codon/lib/codon/ +RUN cd /github/codon && tar czvf /opt/codon-$(git rev-parse --short HEAD).tar.gz -C /opt codon/ +CMD cp /opt/codon-*.tar.gz /mnt/ diff --git a/scripts/get_system_libs.sh b/scripts/get_system_libs.sh index 89ac23a6..d9c8d1c4 100755 --- a/scripts/get_system_libs.sh +++ b/scripts/get_system_libs.sh @@ -50,6 +50,9 @@ if [ "$UNAME" = "Darwin" ]; then install_name_tool -id "@rpath/${LIBGFORTRAN_BASE}" ${LIBGFORTRAN} install_name_tool -id "@rpath/${LIBQUADMATH_BASE}" ${LIBQUADMATH} install_name_tool -id "@rpath/${LIBGCC_BASE}" ${LIBGCC} + codesign -f -s - ${LIBGFORTRAN} + codesign -f -s - ${LIBQUADMATH} + codesign -f -s - ${LIBGCC} else patchelf --set-rpath '$ORIGIN' ${LIBGFORTRAN} patchelf --set-rpath '$ORIGIN' ${LIBQUADMATH} diff --git a/stdlib/collections.codon b/stdlib/collections.codon index e3029cb3..e957d914 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -2,6 +2,7 @@ from internal.types.optional import unwrap +@dataclass(init=False) class deque: _arr: Array[T] _head: int @@ -187,13 +188,13 @@ class Counter(Static[Dict[T, int]]): def most_common(self, n: Optional[int] = None) -> List[Tuple[T, int]]: if len(self) == 0: - return List[_CounterItem](capacity=0) + return [] if n is None: - v = List[_CounterItem](capacity=len(self)) + v = List(capacity=len(self)) for t in self.items(): v.append(t) - v.sort(reverse=True) + v.sort(reverse=True, key=lambda i: i[1]) return v else: from heapq import heapify, heapreplace @@ -201,27 +202,28 @@ class Counter(Static[Dict[T, int]]): n: int = n if n == 1: - top: Optional[_CounterItem] = None + top: Optional[Tuple[T, int]] = None for t in self.items(): - if top is None or t[1] > top.count: + if top is None or t[1] > top[1]: top = t return [unwrap(top)] if n <= 0: - return List[_CounterItem](capacity=0) + return [] result = List[_CounterItem](capacity=n) for t in self.items(): + ct = _CounterItem(*t) if len(result) < n: - result.append(t) + result.append(ct) if len(result) == n: heapify(result) else: - if result[0] < t: - heapreplace(result, t) + if result[0] < ct: + heapreplace(result, ct) result.sort(reverse=True) - return result + return [tuple(i) for i in result] def subtract(self, elements: Generator[T]): for a in elements: diff --git a/stdlib/datetime.codon b/stdlib/datetime.codon index 2e1f09cb..4f0fbf1c 100644 --- a/stdlib/datetime.codon +++ b/stdlib/datetime.codon @@ -418,7 +418,7 @@ class timedelta: _microseconds: int def _new(microseconds: int) -> timedelta: - return (microseconds,) + return __internal__.tuple_cast_unsafe((microseconds,), timedelta) @inline def _accum(sofar: int, leftover: float, num: int, factor: int) -> Tuple[int, float]: @@ -442,10 +442,6 @@ class timedelta: y = s + int(intpart) return y, leftover + fracpart - # override default constructor - def __new__(days: int) -> timedelta: - return timedelta(days, 0) - def __new__( days: float = 0, seconds: float = 0, @@ -475,7 +471,11 @@ class timedelta: whole_us = 2.0 * ((leftover + is_odd) * 0.5).__round__() - is_odd us += int(whole_us) - return (us,) + return timedelta._new(us) + + # override default constructor + def __new__(days: int) -> timedelta: + return timedelta(days, 0) @property def days(self) -> int: @@ -772,7 +772,7 @@ class time: ) -> time: _check_time_args(hour, minute, second, microsecond) v = (hour << 40) | (minute << 32) | (second << 24) | microsecond - return (v,) + return superf(v) @property def hour(self) -> int: diff --git a/stdlib/functools.codon b/stdlib/functools.codon index 2ef86998..e66073a3 100644 --- a/stdlib/functools.codon +++ b/stdlib/functools.codon @@ -1,4 +1,5 @@ # Copyright (C) 2022-2025 Exaloop Inc. +@no_arg_reorder def partial(): # internal pass diff --git a/stdlib/internal/__init__.codon b/stdlib/internal/__init__.codon index 70167c86..15b9837c 100644 --- a/stdlib/internal/__init__.codon +++ b/stdlib/internal/__init__.codon @@ -1,7 +1,7 @@ # Copyright (C) 2022-2025 Exaloop Inc. # Core library - +# from internal.core import * # done automatically by compiler from internal.attributes import * from internal.static import static_print as __static_print__ from internal.types.ptr import * @@ -15,10 +15,14 @@ from internal.types.float import * from internal.types.byte import * from internal.types.generator import * from internal.types.optional import * + +import internal.c_stubs as _C +from internal.format import * +from internal.internal import * + from internal.types.slice import * from internal.types.range import * from internal.types.complex import * -from internal.internal import * __argv__ = Array[str](0) @@ -28,22 +32,19 @@ from internal.types.collections.set import * from internal.types.collections.dict import * from internal.types.collections.tuple import * -# Extended core library - -import internal.c_stubs as _C -from internal.format import * from internal.builtin import * from internal.builtin import _jit_display from internal.str import * from internal.sort import sorted -from openmp import Ident as __OMPIdent, for_par +from openmp import Ident as __OMPIdent, for_par, for_par as par from gpu import _gpu_loop_outline_template from internal.file import File, gzFile, open, gzopen from pickle import pickle, unpickle from internal.dlopen import dlsym as _dlsym import internal.python +from internal.python import PyError if __py_numerics__: import internal.pynumerics diff --git a/stdlib/internal/__init_test__.codon b/stdlib/internal/__init_test__.codon index 93bdf637..a7b4f14b 100644 --- a/stdlib/internal/__init_test__.codon +++ b/stdlib/internal/__init_test__.codon @@ -2,8 +2,9 @@ # Core library -from internal.core import * from internal.attributes import * +from internal.static import static_print as __static_print__ + from internal.types.ptr import * from internal.types.str import * from internal.types.int import * @@ -15,10 +16,10 @@ from internal.types.float import * from internal.types.byte import * from internal.types.generator import * from internal.types.optional import * +from internal.internal import * from internal.types.slice import * from internal.types.range import * from internal.types.complex import * -from internal.internal import * from internal.types.strbuf import strbuf as _strbuf from internal.types.collections.list import * import internal.c_stubs as _C @@ -144,10 +145,31 @@ class str: return str(self.ptr + i, j - i) + def join(self, l: Generator[str]) -> str: + buf = _strbuf() + if len(self) == 0: + for a in l: + buf.append(a) + else: + first = True + for a in l: + if first: + first = False + else: + buf.append(self) + buf.append(a) + return buf.__str__() + def __repr__(self) -> str: return f"'{self}'" + def _isdigit(a: byte) -> bool: + return _C.isdigit(i32(int(a))) != i32(0) + +set = Set +dict = Dict + from internal.builtin import * -from openmp import Ident as __OMPIdent, for_par +# from openmp import Ident as __OMPIdent, for_par from internal.dlopen import dlsym as _dlsym diff --git a/stdlib/internal/attributes.codon b/stdlib/internal/attributes.codon index f194274f..50c8d3f0 100644 --- a/stdlib/internal/attributes.codon +++ b/stdlib/internal/attributes.codon @@ -16,10 +16,6 @@ def inline(): def noinline(): pass -@__attribute__ -def pure(): - pass - @__attribute__ def nonpure(): pass @@ -36,10 +32,6 @@ def nocapture(): def pycapture(): pass -@__attribute__ -def derives(): - pass - @__attribute__ def self_captures(): pass @@ -71,3 +63,7 @@ def no_argument_wrap(): @__attribute__ def no_type_wrap(): pass + +@__attribute__ +def no_arg_reorder(): + pass diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index 71f6d754..77bd5d6d 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -1,15 +1,31 @@ # Copyright (C) 2022-2025 Exaloop Inc. +@tuple @__internal__ -class __internal__: +class type[T]: + # __new__ / __init__: always done internally + # __repr__ + # __call__ + # __name__ + # __module__ + # __doc__ pass + +@tuple @__internal__ -class __magic__: +class unrealized_type[T]: pass @tuple @__internal__ -class NoneType: +class TypeWrap[T]: + pass + +@__internal__ +class __internal__: + pass +@__internal__ +class __magic__: pass @tuple @@ -68,18 +84,18 @@ class float128: @tuple @__internal__ -class type: +@__notuple__ +class Function[T, TR]: pass @tuple @__internal__ -@__notuple__ -class Function[T, TR]: +class Callable[T, TR]: pass @tuple @__internal__ -class Callable[T, TR]: +class NoneType: pass @tuple @@ -89,6 +105,12 @@ class Ptr[T]: pass cobj = Ptr[byte] +@tuple +@__internal__ +@__notuple__ +class Ref[T]: + val: Ptr[T] + @tuple @__internal__ @__notuple__ @@ -123,58 +145,78 @@ class str: ptr: Ptr[byte] len: int - @tuple @__internal__ class Tuple: - @__internal__ + @llvm def __new__() -> Tuple: - pass - def __add__(self, obj): + ret {} {} + def __add__(self: __SELF__, obj, __SELF__: type): return __magic__.add(self, obj) - def __mul__(self, n: Static[int]): + def __mul__(self: __SELF__, n: Static[int], __SELF__: type): return __magic__.mul(self, n) - def __contains__(self, obj) -> bool: + def __contains__(self: __SELF__, obj, __SELF__: type) -> bool: return __magic__.contains(self, obj) - def __getitem__(self, idx: int): + def __getitem__(self: __SELF__, idx: int, __SELF__: type): return __magic__.getitem(self, idx) - def __iter__(self): + def __iter__(self: __SELF__, __SELF__: type): yield from __magic__.iter(self) - def __hash__(self) -> int: + def __hash__(self: __SELF__, __SELF__: type) -> int: return __magic__.hash(self) - def __repr__(self) -> str: + def __repr__(self: __SELF__, __SELF__: type) -> str: return __magic__.repr(self) - def __len__(self) -> int: + def __len__(self: __SELF__, __SELF__: type) -> int: return __magic__.len(self) - def __eq__(self, obj: Tuple) -> bool: + def __eq__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.eq(self, obj) - def __ne__(self, obj: Tuple) -> bool: + def __ne__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.ne(self, obj) - def __gt__(self, obj: Tuple) -> bool: + def __gt__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.gt(self, obj) - def __ge__(self, obj: Tuple) -> bool: + def __ge__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.ge(self, obj) - def __lt__(self, obj: Tuple) -> bool: + def __lt__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.lt(self, obj) - def __le__(self, obj: Tuple) -> bool: + def __le__(self: __SELF__, obj: __SELF__, __SELF__: type) -> bool: return __magic__.le(self, obj) - def __pickle__(self, dest: Ptr[byte]): + def __pickle__(self: __SELF__, dest: Ptr[byte], __SELF__: type): return __magic__.pickle(self, dest) - def __unpickle__(src: Ptr[byte]) -> Tuple: + def __unpickle__(src: Ptr[byte], __SELF__: type) -> __SELF__: return __magic__.unpickle(src) - def __to_py__(self) -> Ptr[byte]: + def __to_py__(self: __SELF__, __SELF__: type) -> Ptr[byte]: return __magic__.to_py(self) - def __from_py__(src: Ptr[byte]) -> Tuple: + def __from_py__(src: Ptr[byte], __SELF__: type) -> __SELF__: return __magic__.from_py(src) - def __to_gpu__(self, cache) -> Tuple: + def __to_gpu__(self: __SELF__, cache, __SELF__: type) -> __SELF__: return __magic__.to_gpu(self, cache) - def __from_gpu__(self, other: Tuple): + def __from_gpu__(self: __SELF__, other: __SELF__, __SELF__: type): return __magic__.from_gpu(self, other) - def __from_gpu_new__(other: Tuple) -> Tuple: + def __from_gpu_new__(other: __SELF__, __SELF__: type) -> __SELF__: return __magic__.from_gpu_new(other) - def __tuplesize__(self) -> int: + def __tuplesize__(self: __SELF__, __SELF__: type) -> int: return __magic__.tuplesize(self) +tuple = Tuple + +@tuple +@__internal__ +class __NTuple__[N: Static[int], T]: + pass +@__attribute__ +def pure(): + pass + +@__attribute__ +def derives(): + pass + +@extend +class NoneType: + @pure + @derives + @llvm + def __new__() -> NoneType: + ret {} {} @tuple @__internal__ @@ -189,10 +231,6 @@ class type: pass function = Function -@__internal__ -class Ref[T]: - pass - @tuple @__internal__ @__notuple__ @@ -203,6 +241,14 @@ class Union[TU]: def __call__(self, *args, **kwargs): return __internal__.union_call(self, args, kwargs) +@extend +class Function: + @pure + @derives + @llvm + def __new__() -> Function[T, TR]: + ret ptr null + # dummy @__internal__ class TypeVar[T]: pass @@ -221,7 +267,13 @@ class RTTI: @__internal__ @tuple -class ellipsis: pass +class ellipsis: + @pure + @derives + @llvm + def __new__() -> ellipsis: + ret {} {} +Ellipsis = ellipsis() @tuple @__internal__ @@ -230,6 +282,26 @@ class __array__: def __new__(sz: Static[int]) -> Array[T]: pass +@tuple +@__internal__ +class Import: + loaded: bool + file: Static[str] + name: Static[str] + + @pure + @derives + @llvm + def __new__(loaded: bool, file: Static[str], name: Static[str]) -> Import[file, name]: + %0 = insertvalue { {=bool} } undef, {=bool} %loaded, 0 + ret { {=bool} } %0 + + def _set_loaded(i: Ptr[Import]): + Ptr[bool](i.as_byte())[0] = True + + def __repr__(self) -> str: + return f"" + def __ptr__(var): pass @@ -256,9 +328,6 @@ def getattr(obj, attr: Static[str]): def setattr(obj, attr: Static[str], what): pass -def tuple(iterable): - pass - def super(): pass @@ -275,20 +344,78 @@ def statictuple(*args): def __has_rtti__(T: type): pass -@dataclass(init=False) +#(init=False, repr=False, eq=False, order=False, hash=False, pickle=False, python=False, gpu=False, container=False) +@__internal__ @tuple -class Import: - loaded: bool - name: Static[str] - file: Static[str] +class NamedTuple: + args: T + N: Static[int] # name cache ID + T: type - def __new__(name: Static[str], - file: Static[str], - loaded: bool) -> Import[name, file]: - return (loaded,) + @pure + @derives + @llvm + def __new__(args: T = (), N: Static[int] = 0, T: type) -> NamedTuple[N, T]: + %0 = insertvalue { {=T} } undef, {=T} %args, 0 + ret { {=T} } %0 - def _set_loaded(i: Ptr[Import]): - Ptr[bool](i.as_byte())[0] = True + def __getitem__(self, key: Static[str]): + return getattr(self, key) - def __repr__(self) -> str: - return f"" + def __contains__(self, key: Static[str]): + return hasattr(self, key) + + def get(self, key: Static[str], default = NoneType()): + return __internal__.kwargs_get(self, key, default) + + def __keys__(self): + return __internal__.namedkeys(N) + + def __repr__(self): + keys = self.__keys__() + values = [v.__repr__() for v in self.args] + s = ', '.join(f"{keys[i]}: {values[i]}" for i in range(len(keys))) + return f"({s})" + +@__internal__ +@tuple +class Partial: + args: T # format: (arg1, arg2, ..., (star_args...)) + kwargs: K + + M: Static[str] # mask + T: type # Tuple + K: type # NamedTuple + F: type # must be unrealized_type + + @pure + @derives + @llvm + def __new__(args: T, kwargs: K, M: Static[str], F: type, T: type, K: type) -> Partial[M, T, K, F]: + %0 = insertvalue { {=T}, {=K} } undef, {=T} %args, 0 + %1 = insertvalue { {=T}, {=K} } %0, {=K} %kwargs, 1 + ret { {=T}, {=K} } %1 + + def __repr__(self): + return __magic__.repr_partial(self) + + def __call__(self, *args, **kwargs): + return self(*args, **kwargs) + + @property + def __fn_name__(self): + return F.__name__[16:-1] # chop off unrealized_type + + def __raw__(self): + # TODO: better error message + return F.T.__raw__() + +@__internal__ +@tuple +class ProxyFunc: + fn: Function[[Ptr[byte], T], TR] + data: Ptr[byte] + T: type + TR: type + +__codon__: Static[bool] = True diff --git a/stdlib/internal/file.codon b/stdlib/internal/file.codon index 730e37e3..215605ae 100644 --- a/stdlib/internal/file.codon +++ b/stdlib/internal/file.codon @@ -17,6 +17,12 @@ class File: raise IOError(f"file {path} could not be opened") self._reset() + def __init__(self, fd: int, mode: str): + self.fp = _C.fdopen(fd, mode.c_str()) + if not self.fp: + raise IOError(f"file descriptor {fd} could not be opened") + self._reset() + def _errcheck(self, msg: str): err = int(_C.ferror(self.fp)) if err: @@ -392,7 +398,7 @@ class bzFile: return i -def open(path: str, mode: str = "r") -> File: +def open(path, mode: str = "r") -> File: return File(path, mode) def gzopen(path: str, mode: str = "r") -> gzFile: diff --git a/stdlib/internal/format.codon b/stdlib/internal/format.codon index 5e5df80f..9d0c87ec 100644 --- a/stdlib/internal/format.codon +++ b/stdlib/internal/format.codon @@ -3,11 +3,88 @@ def _format_error(ret: str): raise ValueError(f"invalid format specifier: {ret}") +def python_to_fmt_format(s): + i = 0 + + fill, align = None, None + if i + 1 < len(s) and (s[i + 1] == '<' or s[i + 1] == '>' or s[i + 1] == '=' or s[i + 1] == '^'): + fill = s[i] + align = s[i + 1] + i += 2 + elif i < len(s) and (s[i] == '<' or s[i] == '>' or s[i] == '=' or s[i] == '^'): + align = s[i] + i += 1 + if align and align == '=': + raise NotImplementedError("'=' alignment not yet supported") + + sign = None + if i < len(s) and (s[i] == '+' or s[i] == '-' or s[i] == ' '): + sign = s[i] + i += 1 + + coerce_negative_float = False + if i < len(s) and s[i] == 'z': + coerce_negative_float = True + i += 1 + if coerce_negative_float: + raise NotImplementedError("'z' not yet supported") + + alternate_form = False + if i < len(s) and s[i] == '#': + alternate_form = True + i += 1 + + width_pre_zero = False + if i < len(s) and s[i] == '#': + width_pre_zero = True + i += 1 + width = 0 + while i < len(s) and str._isdigit(s.ptr[i]): + width = 10 * width + ord(s[i]) - ord('0') + i += 1 + + grouping = None + if i < len(s) and (s[i] == '_' or s[i] == ','): + grouping = s[i] + i += 1 + if grouping == '_': + raise NotImplementedError("'_' grouping not yet supported") + + precision = None + if i < len(s) and s[i] == '.': + i += 1 + precision = 0 + while i < len(s) and str._isdigit(s.ptr[i]): + precision = 10 * precision + ord(s[i]) - ord('0') + i += 1 + + type = None + if i < len(s): + type = s[i] + i += 1 + + if i != len(s): + raise ValueError("bad format string") + + # Construct fmt::format-compatible string + ns = "" + if align: + if fill: ns += fill + ns += align + if sign: ns += sign + if alternate_form: ns += "#" + if width_pre_zero: ns += "0" + if width: ns += str(width) + if precision is not None: ns += f".{precision}" + if grouping: ns += "L" + if type: ns += type + return ns + @extend class int: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_int(self, format_spec, __ptr__(err)) + ret = _C.seq_str_int(self, python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret @@ -16,7 +93,7 @@ class int: class Int: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_int(self.__int__(), format_spec, __ptr__(err)) + ret = _C.seq_str_int(self.__int__(), python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret @@ -25,7 +102,7 @@ class Int: class UInt: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_uint(self.__int__(), format_spec, __ptr__(err)) + ret = _C.seq_str_uint(self.__int__(), python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret @@ -34,7 +111,7 @@ class UInt: class float: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_float(self, format_spec, __ptr__(err)) + ret = _C.seq_str_float(self, python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret if ret != "-nan" else "nan" @@ -43,7 +120,7 @@ class float: class str: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_str(self, format_spec, __ptr__(err)) + ret = _C.seq_str_str(self, python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret @@ -52,7 +129,7 @@ class str: class Ptr: def __format__(self, format_spec: str) -> str: err = False - ret = _C.seq_str_ptr(self.as_byte(), format_spec, __ptr__(err)) + ret = _C.seq_str_ptr(self.as_byte(), python_to_fmt_format(format_spec), __ptr__(err)) if format_spec and err: _format_error(ret) return ret diff --git a/stdlib/internal/gc.codon b/stdlib/internal/gc.codon index b5c35864..f6c2e0b6 100644 --- a/stdlib/internal/gc.codon +++ b/stdlib/internal/gc.codon @@ -105,7 +105,7 @@ def register_finalizer(p): def f(x: cobj, data: cobj, T: type): Ptr[T](__ptr__(x).as_byte())[0].__del__() - seq_register_finalizer(p.__raw__(), f(T=type(p), ...).__raw__()) + seq_register_finalizer(p.__raw__(), f(T=type(p), ...).F.T.__raw__()) def construct_ref[T](args) -> T: p = T.__new__() diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index e5582893..c989fbc5 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -6,19 +6,12 @@ from internal.gc import ( ) from internal.static import vars_types, tuple_type, vars as _vars, fn_overloads, fn_can_call -def vars(obj, with_index: Static[int] = 0): +def vars(obj, with_index: Static[bool] = False): return _vars(obj, with_index) __vtables__ = Ptr[Ptr[cobj]]() __vtable_size__ = 0 -@extend -class ellipsis: - def __new__() -> ellipsis: - return () - -Ellipsis = ellipsis() - @extend class __internal__: def yield_final(val): @@ -73,6 +66,13 @@ class __internal__: __vtables__ = Ptr[Ptr[cobj]](p) __internal__.class_populate_vtables() + def _print(a): + from C import seq_print(str) + if hasattr(a, '__repr__'): + seq_print(a.__repr__()) + else: + seq_print(a.__str__()) + def class_populate_vtables() -> None: """ Populate content of vtables. Compiler generated. @@ -148,6 +148,24 @@ class __internal__: pr = RTTI(B.__id__).__raw__() return __internal__.to_class_ptr_rtti((ptr, pr), B) + @llvm + def _ref_make_helper(val: Ptr[T], T: type) -> Ref[T]: + %0 = insertvalue { ptr } undef, ptr %val, 0 + ret { ptr } %0 + + def ref_make(val: T, T: type) -> Ref[T]: + p = Ptr[T](1) + p[0] = val + return __internal__._ref_make_helper(p) + + @llvm + def ref_get(ref: Ref[T], T: type) -> T: + %0 = extractvalue { ptr } %ref, 0 + %1 = getelementptr {=T}, ptr %0, i64 0 + %2 = load {=T}, ptr %1 + ret {=T} %2 + + # Unions def get_union_tag(u, tag: Static[int]): # compiler-generated @@ -177,7 +195,7 @@ class __internal__: return u def new_union(value, U: type) -> U: - for tag, T in vars_types(U, with_index=1): + for tag, T in vars_types(U, with_index=True): if isinstance(value, T): return __internal__.union_make(tag, value, U) if isinstance(value, Union[T]): @@ -186,14 +204,14 @@ class __internal__: raise TypeError("invalid union constructor") def get_union(union, T: type) -> T: - for tag, TU in vars_types(union, with_index=1): + for tag, TU in vars_types(union, with_index=True): if isinstance(TU, T): if __internal__.union_get_tag(union) == tag: return __internal__.union_get_data(union, TU) raise TypeError(f"invalid union getter for type '{T.__class__.__name__}'") def _union_member_helper(union, member: Static[str]) -> Union: - for tag, T in vars_types(union, with_index=1): + for tag, T in vars_types(union, with_index=True): if hasattr(T, member): if __internal__.union_get_tag(union) == tag: return getattr(__internal__.union_get_data(union, T), member) @@ -207,7 +225,7 @@ class __internal__: return t def _union_call_helper(union, args, kwargs) -> Union: - for tag, T in vars_types(union, with_index=1): + for tag, T in vars_types(union, with_index=True): if fn_can_call(T, *args, **kwargs): if __internal__.union_get_tag(union) == tag: return __internal__.union_get_data(union, T)(*args, **kwargs) @@ -225,7 +243,7 @@ class __internal__: return t def union_str(union): - for tag, T in vars_types(union, with_index=1): + for tag, T in vars_types(union, with_index=True): if hasattr(T, '__str__'): if __internal__.union_get_tag(union) == tag: return __internal__.union_get_data(union, T).__str__() @@ -236,6 +254,9 @@ class __internal__: # Tuples + def namedkeys(N: Static[int]): + pass + @pure @derives @llvm @@ -258,6 +279,10 @@ class __internal__: t, __internal__.tuple_fix_index(idx, staticlen(t)), T, E ) + @llvm + def tuple_cast_unsafe(t, U: type) -> U: + ret {=U} %t + @pure @derives @llvm @@ -309,22 +334,15 @@ class __internal__: if msg: raise OSError(prefix + msg) - @pure - @llvm - def opt_tuple_new(T: type) -> Optional[T]: - ret { i1, {=T} } { i1 false, {=T} undef } - @pure @llvm def opt_ref_new(T: type) -> Optional[T]: ret ptr null + @pure + @llvm def opt_ref_new_rtti(T: type) -> Optional[T]: - obj = Ptr[byte]() - rsz = sizeof(tuple(T)) - rtti = alloc_atomic(rsz) if RTTI.__contents_atomic__ else alloc(rsz) - __internal__.to_class_ptr(rtti, RTTI).id = T.__id__ - return __internal__.opt_ref_new_arg(__internal__.to_class_ptr_rtti((obj, rtti), T)) + ret { ptr, ptr } { ptr null, ptr null } @pure @derives @@ -361,7 +379,7 @@ class __internal__: @pure def opt_ref_bool_rtti(what: Optional[T], T: type) -> bool: - return __internal__.class_raw_rtti_ptr() != cobj() + return __internal__.class_raw_rtti_ptr(what) != cobj() @pure @derives @@ -481,10 +499,10 @@ class __internal__: def undef(v, s): if not v: - raise NameError(f"variable '{s}' not yet defined") + raise NameError(f"name '{s}' is not defined") @__hidden__ - def set_header(e, func, file, line, col): + def set_header(e, func, file, line, col, cause): if not isinstance(e, BaseException): compile_error("exceptions must derive from BaseException") @@ -492,6 +510,8 @@ class __internal__: e.file = file e.line = line e.col = col + if cause is not None: + e.cause = cause return e def kwargs_get(kw, key: Static[str], default): @@ -625,7 +645,7 @@ class __magic__: def hash(slf) -> int: seed = 0 for _, v in vars(slf): - seed = seed ^ ((v.__hash__() + 2654435769) + ((seed << 6) + (seed >> 2))) + seed = seed ^ ((v.__hash__() + 2_654_435_769) + ((seed << 6) + (seed >> 2))) return seed # @dataclass parameter: pickle=True @@ -636,7 +656,7 @@ class __magic__: # @dataclass parameter: pickle=True def unpickle(src: Ptr[byte], T: type) -> T: if isinstance(T, ByVal): - return tuple(type(t).__unpickle__(src) for t in vars_types(T)) + return __internal__.tuple_cast_unsafe(tuple(type(t).__unpickle__(src) for t in vars_types(T)), T) else: obj = T.__new__() for k, v in vars(obj): @@ -646,17 +666,17 @@ class __magic__: # @dataclass parameter: python=True def to_py(slf) -> Ptr[byte]: o = pyobj._tuple_new(staticlen(slf)) - for i, _, v in vars(slf, with_index=1): + for i, _, v in vars(slf, with_index=True): pyobj._tuple_set(o, i, v.__to_py__()) return o # @dataclass parameter: python=True def from_py(src: Ptr[byte], T: type) -> T: if isinstance(T, ByVal): - return tuple( + return __internal__.tuple_cast_unsafe(tuple( type(t).__from_py__(pyobj._tuple_get(src, i)) - for i, t in vars_types(T, with_index=1) - ) + for i, t in vars_types(T, with_index=True) + ), T) else: obj = T.__new__() for i, k, v in vars(obj, with_index=True): @@ -749,3 +769,89 @@ class ellipsis: return 269626442 # same as CPython __internal__.class_init_vtables() + + +class __cast__: + def cast(obj: T, T: type) -> Generator[T]: + return obj.__iter__() + + def cast(obj: int) -> float: + return float(obj) + + def cast(obj: T, T: type) -> Optional[T]: + return Optional[T](obj) + + def cast(obj: Optional[T], T: type) -> T: + return obj.unwrap() + + def cast(obj: T, T: type) -> pyobj: + return obj.__to_py__() + + def cast(obj: pyobj, T: type) -> T: + return T.__from_py__(obj) + + # Function[[T...], R] + # ExternFunction[[T...], R] + # CodonFunction[[T...], R] + # Partial[foo, [T...], R] + + # function into partial (if not Function) / fn(foo) -> fn(foo(...)) + # empty partial (!!) into Function[] + # union extract + # any into Union[] + # derived to base + + def conv_float(obj: float) -> int: + return int(obj) + +def __type_repr__(T: type): + return f"" + +@extend +class TypeWrap: + def __new__(T: type) -> TypeWrap[T]: + return __internal__.tuple_cast_unsafe((), TypeWrap[T]) + + def __call_no_self__(*args, **kwargs) -> T: + return T(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> T: + return T(*args, **kwargs) + + def __repr__(self): + return __type_repr__(T) + + @property + def __name__(self): + return T.__name__ + +@extend +class Ref: + def __init__(self, val: T): + self.val = val + +@extend +class ProxyFunc: + def __new__(fn: Function[[Ptr[byte], T], TR], data: Ptr[byte]) -> ProxyFunc[T, TR]: + return __internal__.tuple_cast_unsafe((fn, data), ProxyFunc[T, TR]) + + def __new__(fn: Function[[Ptr[byte], T], TR], data: Partial[M,PT,K,F], + T: type, TR: type, + M: Static[str], PT: type, F: type, K: type) -> ProxyFunc[T, TR]: + p = Ptr[Partial[M,PT,K,F]](1) + p[0] = data + return ProxyFunc(fn, p.as_byte()) + + def __new__(fn: Function[[Ptr[byte], T], TR], data: Function[T, TR]) -> ProxyFunc[T, TR]: + return ProxyFunc(fn, data.__raw__()) + + def __new__(fn: Function[T, TR]) -> ProxyFunc[T, TR]: + def _wrap(data: Ptr[byte], args, f: type): + return f(data)(*args) + return ProxyFunc( + __realized__(_wrap(f=Function[T, TR], ...), (Ptr[byte], T)), + fn.__raw__() + ) + + def __call__(self, *args): + return self.fn.__call__(self.data, args) diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 9d8ecefd..384a6cfd 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -940,6 +940,17 @@ class _PyArg_Parser: o = cobj() return _PyArg_Parser(z, format, keywords, fname, o, z, z, z, o, o) +@dataclass(init=False) +class PyError(Static[Exception]): + pytype: pyobj + + def __init__(self, message: str): + super().__init__("PyError", message) + self.pytype = pyobj(cobj(), steal=True) + def __init__(self, message: str, pytype: pyobj): + super().__init__("PyError", message) + self.pytype = pytype + @extend class pyobj: def __new__() -> pyobj: @@ -959,6 +970,10 @@ class pyobj: def _getattr(self, name: str) -> pyobj: return pyobj(pyobj.exc_wrap(PyObject_GetAttrString(self.p, name.c_str())), steal=True) + def exc_wrap(_retval: T, T: type) -> T: + pyobj.exc_check() + return _retval + def __add__(self, other): return pyobj(pyobj.exc_wrap(PyNumber_Add(self.p, other.__to_py__())), steal=True) @@ -1197,10 +1212,6 @@ class pyobj: # pyobj.decref(pvalue) raise PyError(msg, pyobj(pvalue)) - def exc_wrap(_retval: T, T: type) -> T: - pyobj.exc_check() - return _retval - def incref(self): Py_IncRef(self.p) return self @@ -1218,10 +1229,10 @@ class pyobj: def __call__(self, *args, **kwargs): args_py = args.__to_py__() kws_py = cobj() - if staticlen(kwargs) > 0: - names = iter(kwargs.__dict__()) - kws = {next(names): pyobj(i.__to_py__(), steal=True) for i in kwargs} - kws_py = kws.__to_py__() + if staticlen(kwargs.args) > 0: + keys = kwargs.__keys__() + values = [pyobj(v.__to_py__(), steal=True) for v in kwargs.args] + kws_py = {keys[i]: values[i] for i in range(len(keys))}.__to_py__() return pyobj(pyobj.exc_wrap(PyObject_Call(self.p, args_py, kws_py)), steal=True) def _tuple_new(length: int): diff --git a/stdlib/internal/static.codon b/stdlib/internal/static.codon index 1b24565d..b0f4ecb9 100644 --- a/stdlib/internal/static.codon +++ b/stdlib/internal/static.codon @@ -34,10 +34,10 @@ def fn_get_default(F, i: Static[int]): def static_print(*args): pass -def vars(obj, with_index: Static[int] = 0): +def vars(obj, with_index: Static[bool] = False): pass -def vars_types(T: type, with_index: Static[int] = 0): +def vars_types(T: type, with_index: Static[bool] = False): pass def tuple_type(T: type, N: Static[int]): diff --git a/stdlib/internal/types/array.codon b/stdlib/internal/types/array.codon index 78cfc4f0..a40becfa 100644 --- a/stdlib/internal/types/array.codon +++ b/stdlib/internal/types/array.codon @@ -4,16 +4,21 @@ from internal.gc import sizeof @extend class Array: + @llvm + @pure + @derives def __new__(ptr: Ptr[T], sz: int) -> Array[T]: - return (sz, ptr) + %0 = insertvalue { i64, ptr } undef, i64 %sz, 0 + %1 = insertvalue { i64, ptr } %0, ptr %ptr, 1 + ret { i64, ptr } %1 def __new__(sz: int) -> Array[T]: - return (sz, Ptr[T](sz)) + return Array[T](Ptr[T](sz), sz) def __copy__(self) -> Array[T]: p = Ptr[T](self.len) str.memcpy(p.as_byte(), self.ptr.as_byte(), self.len * sizeof(T)) - return (self.len, p) + return Array[T](p, self.len) def __deepcopy__(self) -> Array[T]: p = Ptr[T](self.len) @@ -21,7 +26,7 @@ class Array: while i < self.len: p[i] = self.ptr[i].__deepcopy__() i += 1 - return (self.len, p) + return Array[T](p, self.len) def __len__(self) -> int: return self.len @@ -36,6 +41,13 @@ class Array: self.ptr[index] = what def slice(self, s: int, e: int) -> Array[T]: - return (e - s, self.ptr + s) + return Array[T](self.ptr + s, e - s) array = Array + +# Forward declarations +@dataclass(init=False) +class List: + len: int + arr: Array[T] + T: type diff --git a/stdlib/internal/types/collections/dict.codon b/stdlib/internal/types/collections/dict.codon index 64d76c56..5ad313cc 100644 --- a/stdlib/internal/types/collections/dict.codon +++ b/stdlib/internal/types/collections/dict.codon @@ -198,6 +198,10 @@ class Dict: for k, v in other: self[k] = v + def update(self, other: Dict[K, V]): + for k, v in other.items(): + self[k] = v + def pop(self, key: K) -> V: x = self._kh_get(key) if x != self._kh_end(): @@ -283,6 +287,7 @@ class Dict: if new_n_buckets < 4: new_n_buckets = 4 + if self._size >= int(new_n_buckets * HASH_UPPER + 0.5): j = 0 else: diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index ff56b738..494a2db4 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -87,7 +87,10 @@ class List: other.append(self._get(i)) return other - def __setitem__(self, s: Slice, other): + def __setitem__(self, s: Slice, other: Generator[T]): + return self.__setitem__(s, [a for a in other]) + + def __setitem__(self, s: Slice, other: List[T]): if s.start is None and s.stop is None and s.step is None: self.clear() for a in other: @@ -96,25 +99,14 @@ class List: start, stop, step, length = s.adjust_indices(self.__len__()) if s.step is None or step == 1: - if isinstance(other, List[T]): - if other is self: - other = other.__copy__() - self._assign_slice(start, stop, other.arr.ptr, other.__len__()) - else: - items = [a for a in other] - self._assign_slice(start, stop, items.arr.ptr, items.__len__()) + items = [a for a in other] + self._assign_slice(start, stop, items.arr.ptr, items.__len__()) else: if (step < 0 and start < stop) or (step > 0 and start > stop): stop = start seq: Optional[List[T]] = None - if isinstance(other, List[T]): - if other is self: - seq = other.__copy__() - else: - seq = other - else: - seq = [a for a in other] + seq = [a for a in other] seq_len = seq.__len__() if seq_len != length: @@ -276,7 +268,8 @@ class List: else: buf = _strbuf() buf.append("[") - buf.append(self._get(0).__repr__()) + p = self._get(0).__repr__() + buf.append(p) for i in range(1, n): buf.append(", ") buf.append(self._get(i).__repr__()) diff --git a/stdlib/internal/types/collections/tuple.codon b/stdlib/internal/types/collections/tuple.codon index 27f12019..e34790e6 100644 --- a/stdlib/internal/types/collections/tuple.codon +++ b/stdlib/internal/types/collections/tuple.codon @@ -34,7 +34,7 @@ class DynamicTuple: return DynamicTuple(p, n) - def __new__(): + def __new__() -> DynamicTuple[T]: return DynamicTuple(Ptr[T](), 0) def __len__(self): diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index 15ae9743..65077322 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -11,7 +11,7 @@ class complex: imag: float def __new__() -> complex: - return (0.0, 0.0) + return complex(0.0, 0.0) def __new__(what): # do not overload! (needed to avoid pyobj conversion) @@ -21,7 +21,7 @@ class complex: return what.__complex__() def __new__(real, imag) -> complex: - return (float(real), float(imag)) + return superf(float(real), float(imag)) def __complex__(self) -> complex: return self @@ -340,6 +340,9 @@ class int: @extend class float: + def __complex__(self) -> complex: + return complex(self, 0.0) + def __suffix_j__(x: float) -> complex: return complex(0, x) @@ -348,7 +351,7 @@ f32 = float32 @extend class complex64: def __new__() -> complex64: - return (f32(0.0), f32(0.0)) + return complex64(f32(0.0), f32(0.0)) def __new__(other): if isinstance(other, str): @@ -360,10 +363,10 @@ class complex64: return complex64(real, f32(0.0)) def __new__(other: complex) -> complex64: - return (f32(other.real), f32(other.imag)) + return complex64(f32(other.real), f32(other.imag)) def __new__(real, imag) -> complex64: - return (f32(float(real)), f32(float(imag))) + return superf(f32(float(real)), f32(float(imag))) def __complex__(self) -> complex: return complex(float(self.real), float(self.imag)) @@ -762,3 +765,8 @@ class complex64: declare float @llvm.log.f32(float) %y = call float @llvm.log.f32(float %x) ret float %y + +@extend +class int: + def __complex__(self) -> complex: + return complex(float(self), 0.0) diff --git a/stdlib/internal/types/error.codon b/stdlib/internal/types/error.codon index 995eb04d..c02e6149 100644 --- a/stdlib/internal/types/error.codon +++ b/stdlib/internal/types/error.codon @@ -11,6 +11,7 @@ class BaseException: line: int col: int python_type: cobj + cause: Optional[BaseException] def __init__(self, typename: str, message: str = ""): self.typename = typename @@ -20,6 +21,7 @@ class BaseException: self.line = 0 self.col = 0 self.python_type = BaseException._pytype + self.cause = __internal__.opt_ref_new(T=BaseException) def __str__(self): return self.message @@ -27,11 +29,15 @@ class BaseException: def __repr__(self): return f'{self.typename}({self.message.__repr__()})' + @property + def __cause__(self): + return self.cause + class Exception(Static[BaseException]): _pytype: ClassVar[cobj] = cobj() def __init__(self, typename: str, msg: str = ""): super().__init__(typename, msg) - if (hasattr(self.__class__, "_pytype")): + if hasattr(self.__class__, "_pytype"): self.python_type = self.__class__._pytype class NameError(Static[Exception]): @@ -85,13 +91,6 @@ class CError(Static[Exception]): super().__init__("CError", message) self.python_type = self.__class__._pytype -class PyError(Static[Exception]): - pytype: pyobj - - def __init__(self, message: str, pytype: pyobj = pyobj(cobj(), steal=True)): - super().__init__("PyError", message) - self.pytype = pytype - class TypeError(Static[Exception]): _pytype: ClassVar[cobj] = cobj() def __init__(self, message: str = ""): diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 15e3601e..37887a6d 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -2,7 +2,6 @@ from internal.attributes import commutative from internal.gc import alloc_atomic, free -from internal.types.complex import complex def _float_int_pow(a: F, b: int, F: type) -> F: abs_exp = b.__abs__() @@ -57,9 +56,6 @@ class float: %1 = zext i1 %0 to i8 ret i8 %1 - def __complex__(self) -> complex: - return complex(self, 0.0) - def __pos__(self) -> float: return self diff --git a/stdlib/internal/types/generator.codon b/stdlib/internal/types/generator.codon index 0e59b581..731f4d63 100644 --- a/stdlib/internal/types/generator.codon +++ b/stdlib/internal/types/generator.codon @@ -30,6 +30,9 @@ class Generator: def __new__(ptr: cobj) -> Generator[T]: ret ptr %ptr + def __new__() -> Generator[T]: + raise ValueError("invalid generator") + @pure @llvm def __done__(self) -> bool: diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 8220e013..2525ab60 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -1,7 +1,6 @@ # Copyright (C) 2022-2025 Exaloop Inc. from internal.attributes import commutative, associative, distributive -from internal.types.complex import complex @extend class int: @@ -29,9 +28,6 @@ class int: %tmp = sitofp i64 %self to double ret double %tmp - def __complex__(self) -> complex: - return complex(float(self), 0.0) - def __index__(self) -> int: return self diff --git a/stdlib/internal/types/optional.codon b/stdlib/internal/types/optional.codon index 9f5f0a16..1838bc22 100644 --- a/stdlib/internal/types/optional.codon +++ b/stdlib/internal/types/optional.codon @@ -1,5 +1,13 @@ # Copyright (C) 2022-2025 Exaloop Inc. +@extend +class __internal__: + @pure + @llvm + def opt_tuple_new(T: type) -> Optional[T]: + ret { i1, {=T} } { i1 false, {=T} undef } + + @extend class Optional: def __new__() -> Optional[T]: @@ -22,7 +30,7 @@ class Optional: if isinstance(T, ByVal): return __internal__.opt_tuple_bool(self, T) elif __has_rtti__(T): - return __internal__.opt_ref_bool_rtti(T) + return __internal__.opt_ref_bool_rtti(self, T) else: return __internal__.opt_ref_bool(self, T) @@ -30,7 +38,7 @@ class Optional: if isinstance(T, ByVal): return __internal__.opt_tuple_invert(self, T) elif __has_rtti__(T): - return __internal__.opt_ref_invert_rtti(T) + return __internal__.opt_ref_invert_rtti(self, T) else: return __internal__.opt_ref_invert(self, T) @@ -85,4 +93,4 @@ optional = Optional def unwrap(opt: Optional[T], T: type) -> T: if opt.__has__(): return opt.__val__() - raise ValueError("optional is None") + raise ValueError(f"optional unpack failed: expected {T.__class__.__name__}, got None") diff --git a/stdlib/internal/types/ptr.codon b/stdlib/internal/types/ptr.codon index 8477454c..6c79be66 100644 --- a/stdlib/internal/types/ptr.codon +++ b/stdlib/internal/types/ptr.codon @@ -187,20 +187,12 @@ class Ptr: def __repr__(self) -> str: return self.__format__("") + ptr = Ptr Jar = Ptr[byte] -# Forward declarations -class List: - len: int - arr: Array[T] - T: type - @extend class NoneType: - def __new__() -> NoneType: - return () - def __eq__(self, other: NoneType): return True diff --git a/stdlib/internal/types/range.codon b/stdlib/internal/types/range.codon index 133d9788..8c9fbb17 100644 --- a/stdlib/internal/types/range.codon +++ b/stdlib/internal/types/range.codon @@ -8,16 +8,16 @@ class range: # Magic methods - def __new__(start: int, stop: int, step: int) -> range: - if step == 0: - raise ValueError("range() step argument must not be zero") - return (start, stop, step) - def __new__(start: int, stop: int) -> range: - return (start, stop, 1) + return range(start, stop, 1) def __new__(stop: int) -> range: - return (0, stop, 1) + return range(0, stop, 1) + + def __new__(start: int, stop: int, step: int) -> range: + if step == 0: + raise ValueError("range() step argument must not be zero") + return superf(start, stop, step) def _get(self, idx: int) -> int: return self.start + (idx * self.step) @@ -95,7 +95,6 @@ class range: else: return f"range({self.start}, {self.stop}, {self.step})" -@overload def staticrange(start: Static[int], stop: Static[int], step: Static[int] = 1): return range(start, stop, step) diff --git a/stdlib/internal/types/slice.codon b/stdlib/internal/types/slice.codon index bbd6a695..09099267 100644 --- a/stdlib/internal/types/slice.codon +++ b/stdlib/internal/types/slice.codon @@ -26,7 +26,7 @@ class Slice: T: type = int, U: type = int, V: type = int) -> Slice[T, U, V]: - return (start, stop, step) + return superf(start, stop, step) def adjust_indices(self, length: int) -> Tuple[int, int, int, int]: if not (T is int and U is int and V is int): diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index 53641dde..159d5ddc 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -20,6 +20,8 @@ class str: def __new__(what) -> str: if isinstance(what, Union): return __internal__.union_str(what) + elif isinstance(what, type): + return what.__repr__() elif hasattr(what, "__str__"): return what.__str__() else: diff --git a/stdlib/itertools.codon b/stdlib/itertools.codon index b8d974b2..69d0586e 100644 --- a/stdlib/itertools.codon +++ b/stdlib/itertools.codon @@ -747,12 +747,12 @@ def _permutations_static(pool, r: Static[int]): if i < 0: break -def permutations(pool, r = None): +def permutations(pool, r = None) -> Generator: if isinstance(pool, Tuple) and r is None: return _permutations_static(pool, staticlen(pool)) else: return _permutations_non_static(pool, r) @overload -def permutations(pool, r: Static[int]): +def permutations(pool, r: Static[int]) -> Generator: return _permutations_static(pool, r) diff --git a/stdlib/numpy/dtype.codon b/stdlib/numpy/dtype.codon index 07032ff1..96eea4c2 100644 --- a/stdlib/numpy/dtype.codon +++ b/stdlib/numpy/dtype.codon @@ -318,7 +318,7 @@ class finfo: @property def tiny(self): - return self.smallest_normal() + return self.smallest_normal def _type_name(self): if dtype is float64: diff --git a/stdlib/numpy/fft/__init__.codon b/stdlib/numpy/fft/__init__.codon index 240fec92..beb67bfa 100644 --- a/stdlib/numpy/fft/__init__.codon +++ b/stdlib/numpy/fft/__init__.codon @@ -251,8 +251,8 @@ def ihfft(a, n: Optional[int] = None, out.map(lambda x: x.conjugate(), inplace=True) return out -def _cook_nd_args(a, s = None, axes = None, invreal: Static[int] = False): - shapeless: Static[int] = (s is None) +def _cook_nd_args(a, s = None, axes = None, invreal: Static[bool] = False): + shapeless: Static[bool] = (s is None) if s is None: if axes is None: diff --git a/stdlib/numpy/fft/pocketfft.codon b/stdlib/numpy/fft/pocketfft.codon index 49fb60dd..a6c70b74 100644 --- a/stdlib/numpy/fft/pocketfft.codon +++ b/stdlib/numpy/fft/pocketfft.codon @@ -59,7 +59,7 @@ class arr: T: type def __new__(sz: int) -> arr[T]: - return (Ptr[T](sz), sz) + return arr(Ptr[T](sz), sz) def dealloc(self): free(self.p) @@ -139,7 +139,7 @@ class sincos_2pibyn: for i in range(1, v2.size()): v2[i] = sincos_2pibyn.calc(i*(mask+1), n, ang) - return (n, mask, shift, v1, v2) + return sincos_2pibyn[T](n, mask, shift, v1, v2) def cmplx(self, re, im): if T is float: diff --git a/stdlib/numpy/indexing.codon b/stdlib/numpy/indexing.codon index ee217a99..f453c79d 100644 --- a/stdlib/numpy/indexing.codon +++ b/stdlib/numpy/indexing.codon @@ -227,7 +227,7 @@ def _adv_idx_eliminate_new_and_used(arr, indexes): indexes = _adv_idx_prune_index(indexes) return arr, indexes -def _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arr_shape, saw_array: Static[int] = False): +def _adv_idx_build_for_contig_array(indexes, shape_from_non_arrays, arr_shape, saw_array: Static[bool] = False): if staticlen(indexes) == 0: return () diff --git a/stdlib/numpy/lib/arraysetops.codon b/stdlib/numpy/lib/arraysetops.codon index 861695c6..b86d8b7e 100644 --- a/stdlib/numpy/lib/arraysetops.codon +++ b/stdlib/numpy/lib/arraysetops.codon @@ -7,9 +7,9 @@ from ..ndmath import isnan from ..util import multirange, free, count, normalize_axis_index, coerce, cast def _unique1d(ar, - return_index: Static[int] = False, - return_inverse: Static[int] = False, - return_counts: Static[int] = False, + return_index: Static[bool] = False, + return_inverse: Static[bool] = False, + return_counts: Static[bool] = False, equal_nan: bool = True): n = ar.size @@ -117,9 +117,9 @@ def _unique1d(ar, return ans3 def unique(ar, - return_index: Static[int] = False, - return_inverse: Static[int] = False, - return_counts: Static[int] = False, + return_index: Static[bool] = False, + return_inverse: Static[bool] = False, + return_counts: Static[bool] = False, axis=None, equal_nan: bool = True): @@ -583,7 +583,7 @@ def _intersect1d_indices(ar1, ar2, assume_unique: bool = False): def intersect1d(ar1, ar2, assume_unique: bool = False, - return_indices: Static[int] = False): + return_indices: Static[bool] = False): if return_indices: return _intersect1d_indices(ar1, ar2, assume_unique=assume_unique) else: diff --git a/stdlib/numpy/linalg/linalg.codon b/stdlib/numpy/linalg/linalg.codon index eef93ad2..ebe30c78 100644 --- a/stdlib/numpy/linalg/linalg.codon +++ b/stdlib/numpy/linalg/linalg.codon @@ -111,10 +111,6 @@ class LinearizeData: col_strides: int out_lead_dim: int - def __new__(rows: int, cols: int, row_strides: int, col_strides: int, - out_lead_dim: int) -> LinearizeData: - return (rows, cols, row_strides, col_strides, out_lead_dim) - def __new__(rows: int, cols: int, row_strides: int, col_strides: int): return LinearizeData(rows, cols, row_strides, col_strides, cols) @@ -512,7 +508,7 @@ class EighResult[A, B]: else: compile_error("tuple ('EighResult') index out of range") -def _eigh(a, JOBZ: byte, UPLO: byte, compute_eigenvectors: Static[int]): +def _eigh(a, JOBZ: byte, UPLO: byte, compute_eigenvectors: Static[bool]): a = _asarray(a) B = type(_basetype(a.dtype)) @@ -761,7 +757,7 @@ def solve(a, b): else: return _solve(a, b) -def _inv(a, ignore_errors: Static[int]): +def _inv(a, ignore_errors: Static[bool]): a = _asarray(a) n = _square_rows(a) @@ -1160,7 +1156,7 @@ class EigResult[A, B]: else: compile_error("tuple ('EigResult') index out of range") -def _eig(a, JOBVL: byte, JOBVR: byte, compute_eigenvectors: Static[int]): +def _eig(a, JOBVL: byte, JOBVR: byte, compute_eigenvectors: Static[bool]): a = _asarray(a) B = type(_basetype(a.dtype)) C = type(_complextype(a.dtype)) @@ -1469,7 +1465,7 @@ class SVDResult[A1, A2]: else: compile_error("tuple ('SVDResult') index out of range") -def _svd(a, JOBZ: byte, compute_uv: Static[int]): +def _svd(a, JOBZ: byte, compute_uv: Static[bool]): B = type(_basetype(a.dtype)) m, n = _rows_cols(a) min_m_n = min(m, n) @@ -1544,7 +1540,7 @@ def _svd(a, JOBZ: byte, compute_uv: Static[int]): def svd(a, full_matrices: bool = True, - compute_uv: Static[int] = True, + compute_uv: Static[bool] = True, hermitian: bool = False): a = _asarray(a) @@ -3003,7 +2999,7 @@ def _multi_dot(arrays, order, i, j, out=None): _multi_dot(arrays, order, order[i, j] + 1, j), out=out) -def _multi_dot_matrix_chain_order(arrays, return_costs: Static[int] = False): +def _multi_dot_matrix_chain_order(arrays, return_costs: Static[bool] = False): n = len(arrays) p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]] m = zeros((n, n), dtype=float) @@ -4006,13 +4002,13 @@ def _multi_svd_norm(x: ndarray, row_axis: int, col_axis: int, op): y = moveaxis(x, (row_axis, col_axis), (-2, -1)) return op(svd(y, compute_uv=False), axis=-1) -def _norm_wrap(x, kd_shape, keepdims: Static[int]): +def _norm_wrap(x, kd_shape, keepdims: Static[bool]): if keepdims: return asarray(x).reshape(kd_shape) else: return x -def norm(x, ord = None, axis = None, keepdims: Static[int] = False): +def norm(x, ord = None, axis = None, keepdims: Static[bool] = False): x = asarray(x) if not (x.dtype is float or x.dtype is float32 or x.dtype is float16): diff --git a/stdlib/numpy/misc.codon b/stdlib/numpy/misc.codon index 38de0ae2..25bd367c 100644 --- a/stdlib/numpy/misc.codon +++ b/stdlib/numpy/misc.codon @@ -54,9 +54,6 @@ class DiophantineTerm: a: int ub: int - def __new__(a: int, ub: int) -> DiophantineTerm: - return (a, ub) - def __lt__(self, other: DiophantineTerm): return self.a < other.a diff --git a/stdlib/numpy/ndarray.codon b/stdlib/numpy/ndarray.codon index 32b6db24..bcb19c21 100644 --- a/stdlib/numpy/ndarray.codon +++ b/stdlib/numpy/ndarray.codon @@ -203,29 +203,29 @@ class flatiter[A]: def copy(self): return self.base.flatten() -@tuple(init=False) +@tuple class _UnaryFunctor: op: F F: type def __new__(op: F, F: type) -> _UnaryFunctor[F]: - return (op, ) + return superf(op) def __call__(self, y, x): y[0] = self.op(x[0]) -@tuple(init=False) +@tuple class _InplaceUnaryFunctor: op: F F: type def __new__(op: F, F: type) -> _InplaceUnaryFunctor[F]: - return (op, ) + return superf(op) def __call__(self, x): x[0] = self.op(x[0]) -@tuple(init=False) +@tuple class _BinaryFunctor: op: F F: type @@ -233,23 +233,23 @@ class _BinaryFunctor: R2: type def __new__(op: F, R1: type, R2: type, F: type) -> _BinaryFunctor[F, R1, R2]: - return (op, ) + return superf(op) def __call__(self, z, x, y): z[0] = self.op(util.cast(x[0], R1), util.cast(y[0], R2)) -@tuple(init=False) +@tuple class _InplaceBinaryFunctor: op: F F: type def __new__(op: F, F: type) -> _InplaceBinaryFunctor[F]: - return (op, ) + return superf(op) def __call__(self, x, y): x[0] = self.op(x[0], util.cast(y[0], type(x[0]))) -@tuple(init=False) +@tuple class _RightBinaryFunctor: op: F F: type @@ -257,12 +257,12 @@ class _RightBinaryFunctor: R2: type def __new__(op: F, R1: type, R2: type, F: type) -> _RightBinaryFunctor[F, R1, R2]: - return (op, ) + return superf(op) def __call__(self, z, x, y): z[0] = self.op(util.cast(y[0], R2), util.cast(x[0], R1)) -@tuple(init=False) +@tuple class _ScalarFunctor: op: F y: Y @@ -272,12 +272,12 @@ class _ScalarFunctor: R2: type def __new__(op: F, y: Y, R1: type, R2: type, F: type, Y: type) -> _ScalarFunctor[F, Y, R1, R2]: - return (op, y) + return superf(op, y) def __call__(self, z, x): z[0] = self.op(util.cast(x[0], R1), util.cast(self.y, R2)) -@tuple(init=False) +@tuple class _InplaceScalarFunctor: op: F y: Y @@ -285,13 +285,13 @@ class _InplaceScalarFunctor: Y: type def __new__(op: F, y: Y, F: type, Y: type) -> _InplaceScalarFunctor[F, Y]: - return (op, y) + return superf(op, y) def __call__(self, x): x[0] = self.op(x[0], util.cast(self.y, type(x[0]))) -@tuple(init=False) +@tuple class _RightScalarFunctor: op: F y: Y @@ -301,7 +301,7 @@ class _RightScalarFunctor: R2: type def __new__(op: F, y: Y, R1: type, R2: type, F: type, Y: type) -> _RightScalarFunctor[F, Y, R1, R2]: - return (op, y) + return superf(op, y) def __call__(self, z, x): z[0] = self.op(util.cast(self.y, R2), util.cast(x[0], R1)) @@ -315,7 +315,7 @@ class ndarray[dtype, ndim: Static[int]]: def __new__(shape: Tuple[ndim, int], strides: Tuple[ndim, int], data: Ptr[dtype]) -> ndarray[dtype, ndim]: - return (shape, strides, data) + return __internal__.tuple_cast_unsafe((shape, strides, data), ndarray[dtype, ndim]) def __new__(shape: Tuple[ndim, int], data: Ptr[dtype], fcontig: bool = False): strides = util.strides(shape, fcontig, dtype) @@ -619,8 +619,8 @@ class ndarray[dtype, ndim: Static[int]]: return ndarray(shape, self._data, fcontig=(order == 'F')) def _loop(arrays, func, broadcast: Static[str] = 'all', - check: Static[int] = True, alloc: type = type(()), - optimize_order: Static[int] = True, extra = None): + check: Static[bool] = True, alloc: type = type(()), + optimize_order: Static[bool] = True, extra = None): def call(func, args, extra): if extra is None: return func(*args) @@ -721,7 +721,7 @@ class ndarray[dtype, ndim: Static[int]]: perm, _ = util.sort_by_stride(perm, strides) return perm - def broadcast_shapes(args, check: Static[int]): + def broadcast_shapes(args, check: Static[bool]): def largest(args): if staticlen(args) == 1: return args[0] @@ -757,7 +757,7 @@ class ndarray[dtype, ndim: Static[int]]: return ans - def broadcast_to(x, shape, check: Static[int]): + def broadcast_to(x, shape, check: Static[bool]): N: Static[int] = x.ndim substrides = (0,) * N p = Ptr[int](__ptr__(substrides).as_byte()) @@ -778,7 +778,7 @@ class ndarray[dtype, ndim: Static[int]]: new_strides = (*z, *substrides) return ndarray(shape, new_strides, x.data) - def broadcast_arrays(arrays, check: Static[int]): + def broadcast_arrays(arrays, check: Static[bool]): shape = broadcast_shapes(tuple(arr.shape for arr in arrays), check=check) return tuple(broadcast_to(arr, shape, check=False) for arr in arrays) @@ -833,7 +833,7 @@ class ndarray[dtype, ndim: Static[int]]: strides = ndarray(perm_shape, p).strides return (p, strides) - def broadcast_args(arrays, broadcast: Static[str], check: Static[int]): + def broadcast_args(arrays, broadcast: Static[str], check: Static[bool]): if broadcast == 'none': shape = arrays[0].shape strides = tuple(arr.strides for arr in arrays) @@ -918,7 +918,7 @@ class ndarray[dtype, ndim: Static[int]]: else: return tuple(ndarray(shape0, tup[1], tup[0]) for tup in allocated) - def _contiguous(self, copy: Static[int] = False): + def _contiguous(self, copy: Static[bool] = False): ccontig, _ = self._contig if ccontig: if copy: @@ -938,7 +938,7 @@ class ndarray[dtype, ndim: Static[int]]: i += 1 return p - def _fcontiguous(self, copy: Static[int] = False): + def _fcontiguous(self, copy: Static[bool] = False): _, fcontig = self._contig if fcontig: if copy: @@ -1037,7 +1037,7 @@ class ndarray[dtype, ndim: Static[int]]: else: return [a.tolist() for a in self] - def _ptr_for_index(self, indexes, check: Static[int] = True, broadcast: Static[int] = False): + def _ptr_for_index(self, indexes, check: Static[bool] = True, broadcast: Static[bool] = False): s = self.shape strides = self.strides pshape = Ptr[int](__ptr__(s).as_byte()) @@ -1065,7 +1065,7 @@ class ndarray[dtype, ndim: Static[int]]: return Ptr[dtype](self._data.as_byte() + offset) - def _ptr(self, indexes, broadcast: Static[int] = False): + def _ptr(self, indexes, broadcast: Static[bool] = False): return self._ptr_for_index(indexes, check=False, broadcast=broadcast) def __len__(self): @@ -1298,7 +1298,7 @@ class ndarray[dtype, ndim: Static[int]]: return m, M - def map(self, fn, inplace: Static[int] = False): + def map(self, fn, inplace: Static[bool] = False): if inplace: return self._iop_unary(fn) else: @@ -1394,7 +1394,7 @@ class ndarray[dtype, ndim: Static[int]]: else: return self.map(bswap, inplace=False) - def _ptr_flat(self, idx: int, check: Static[int]): + def _ptr_flat(self, idx: int, check: Static[bool]): if check: n = self.size if idx < -n or idx >= n: @@ -1403,8 +1403,8 @@ class ndarray[dtype, ndim: Static[int]]: idx += n return self._ptr(util.index_to_coords(idx, self.shape)) - def _get_flat(self, idx: int, check: Static[int]): + def _get_flat(self, idx: int, check: Static[bool]): return self._ptr_flat(idx, check=check)[0] - def _set_flat(self, idx: int, val, check: Static[int]): + def _set_flat(self, idx: int, val, check: Static[bool]): self._ptr_flat(idx, check=check)[0] = util.cast(val, dtype) diff --git a/stdlib/numpy/npdatetime.codon b/stdlib/numpy/npdatetime.codon index 8d377dd5..42266e43 100644 --- a/stdlib/numpy/npdatetime.codon +++ b/stdlib/numpy/npdatetime.codon @@ -38,8 +38,9 @@ class _time_t: isdst: i8 def __new__() -> _time_t: - return (i16(0), i16(0), i8(0), i8(0), i8(0), i8(0), i8(0), i8(0), - i8(0)) + return _time_t( + i16(0), i16(0), i8(0), i8(0), i8(0), i8(0), i8(0), i8(0), i8(0) + ) @pure @llvm @@ -737,8 +738,8 @@ def _static_gcd(x: Static[int], y: Static[int]): else: return _static_gcd(y, x % y) -def _meta_gcd(meta1: _Meta, meta2: _Meta, strict1: Static[int], - strict2: Static[int]): +def _meta_gcd(meta1: _Meta, meta2: _Meta, strict1: Static[bool], + strict2: Static[bool]): def incompatible_units(base1: Static[str], base2: Static[str]): compile_error( @@ -867,8 +868,8 @@ def _parse_datetime_type(s: Static[str]): def _promote(T1: type, T2: type): meta1 = _Meta[T1.base, T1.num]() meta2 = _Meta[T2.base, T2.num]() - strict1: Static[int] = isinstance(T1, timedelta64) - strict2: Static[int] = isinstance(T2, timedelta64) + strict1: Static[bool] = isinstance(T1, timedelta64) + strict2: Static[bool] = isinstance(T2, timedelta64) meta_out = _meta_gcd(meta1, meta2, strict1, strict2) if isinstance(T1, datetime64) or isinstance(T2, datetime64): @@ -895,7 +896,7 @@ def _coerce(d1, d2): class DatetimeMetaData: def __new__(base: int, num: int = 1) -> DatetimeMetaData: - return (i32(base), i32(num)) + return DatetimeMetaData(i32(base), i32(num)) def __new__(base: Static[str], num: int = 1) -> DatetimeMetaData: return DatetimeMetaData(base=_base_code(base), num=num) @@ -1012,14 +1013,14 @@ class DatetimeMetaData: class timedelta64: def __new__() -> timedelta64[base, num]: - return (0, ) + return timedelta64[base, num](0) def __new__(value: int, base: Static[str], num: Static[int]) -> timedelta64[base, num]: _validate_base(base) _validate_num(num) - return (value, ) + return superf(value) def __new__(td: timedelta64, base: Static[str], @@ -1036,9 +1037,9 @@ class timedelta64: if (len(s) == 3 and (s.ptr[0] == byte(78) or s.ptr[0] == byte(110)) and (s.ptr[1] == byte(65) or s.ptr[1] == byte(97)) and (s.ptr[2] == byte(116) or s.ptr[2] == byte(84))): - return (_DATETIME_NAT, ) + return timedelta64[base, num](_DATETIME_NAT) else: - return (int(s), ) + return timedelta64[base, num](int(s)) def __new__(value: int, unit: Static[str]): TD = _parse_datetime_type("timedelta64[" + unit + "]") @@ -1434,14 +1435,14 @@ class timedelta64: class datetime64: def __new__() -> datetime64[base, num]: - return (0, ) + return datetime64[base, num](0) def __new__(value: int, base: Static[str], num: Static[int]) -> datetime64[base, num]: _validate_base(base) _validate_num(num) - return (value, ) + return superf(value) def __new__(value: datetime64, base: Static[str], @@ -1457,7 +1458,7 @@ class datetime64: _validate_num(num) meta = DatetimeMetaData(base=base, num=num) dts = datetimestruct(s) - return (dts.to_datetime64(meta), ) + return datetime64[base, num](dts.to_datetime64(meta)) def __new__(value: int, unit: Static[str]): DT = _parse_datetime_type("datetime64[" + unit + "]") diff --git a/stdlib/numpy/npio.codon b/stdlib/numpy/npio.codon index b773bf6c..6db05157 100644 --- a/stdlib/numpy/npio.codon +++ b/stdlib/numpy/npio.codon @@ -525,7 +525,7 @@ class Converters: def __new__(funcs, mask, usecols, dtype: type) -> Converters[dtype, F, M, U]: - return (funcs, mask, usecols) + return superf(funcs, mask, usecols) def __call__(self, field: str, idx: int): usecols = self.usecols @@ -987,7 +987,7 @@ def loadtxt(fname: str, converters=None, skiprows: int = 0, usecols=None, - unpack: Static[int] = False, + unpack: Static[bool] = False, ndmin: Static[int] = 0, max_rows: Optional[int] = None, quotechar: Optional[str] = None): @@ -1963,7 +1963,7 @@ def genfromtxt(fname, replace_space: str = '_', autostrip: bool = False, case_sensitive=True, - unpack: Static[int] = False, + unpack: Static[bool] = False, loose: bool = True, invalid_raise: bool = True, max_rows: Optional[int] = None, @@ -1978,7 +1978,7 @@ def genfromtxt(fname, num_fields: int, block_size: int, dtype: type, - unpack: Static[int] = False): + unpack: Static[bool] = False): if cols is None: if isinstance(dtype, Tuple): if staticlen(dtype) != num_fields: diff --git a/stdlib/numpy/operators.codon b/stdlib/numpy/operators.codon index 2e739e01..9524e73f 100644 --- a/stdlib/numpy/operators.codon +++ b/stdlib/numpy/operators.codon @@ -22,21 +22,21 @@ class _OpWrap: def _fix_scalar(x, A: type): X = type(x) - a_is_int: Static[int] = (A is int or A is byte or isinstance(A, Int) - or isinstance(A, UInt)) - x_is_int: Static[int] = X is int + a_is_int: Static[bool] = (A is int or A is byte or isinstance(A, Int) + or isinstance(A, UInt)) + x_is_int: Static[bool] = X is int - a_is_float: Static[int] = (A is float or A is float32 or A is float16 - or A is bfloat16 or A is float128) - x_is_float: Static[int] = X is float + a_is_float: Static[bool] = (A is float or A is float32 or A is float16 + or A is bfloat16 or A is float128) + x_is_float: Static[bool] = X is float - a_is_complex: Static[int] = (A is complex or A is complex64) - x_is_complex: Static[int] = X is complex + a_is_complex: Static[bool] = (A is complex or A is complex64) + x_is_complex: Static[bool] = X is complex - should_cast: Static[int] = ((x_is_int and - (a_is_int or a_is_float or a_is_complex)) or - (x_is_float and (a_is_float or a_is_complex)) - or (x_is_complex and a_is_complex)) + should_cast: Static[bool] = ((x_is_int and + (a_is_int or a_is_float or a_is_complex)) or + (x_is_float and (a_is_float or a_is_complex)) + or (x_is_complex and a_is_complex)) if (A is float16 or A is float32) and X is complex: return util.cast(x, complex64) diff --git a/stdlib/numpy/random/bitgen.codon b/stdlib/numpy/random/bitgen.codon index 447d79c7..07a154aa 100644 --- a/stdlib/numpy/random/bitgen.codon +++ b/stdlib/numpy/random/bitgen.codon @@ -2117,7 +2117,7 @@ class Generator[G]: if abs(p_sum - 1.) > atol: raise ValueError("probabilities do not sum to 1") - is_scalar: Static[int] = size is None + is_scalar: Static[bool] = size is None if not is_scalar: shape = size size1 = prod(shape) diff --git a/stdlib/numpy/random/mt19937.codon b/stdlib/numpy/random/mt19937.codon index 5b4e39a5..29ee1d9c 100644 --- a/stdlib/numpy/random/mt19937.codon +++ b/stdlib/numpy/random/mt19937.codon @@ -25,7 +25,7 @@ class MT19937: data: Ptr[u32] seed: SeedSequence - def __new__(seed, legacy: Static[int] = False): + def __new__(seed, legacy: Static[bool] = False): if not isinstance(seed, SeedSequence): return MT19937(SeedSequence(seed)) else: diff --git a/stdlib/numpy/reductions.codon b/stdlib/numpy/reductions.codon index ab6f7bec..0535c272 100644 --- a/stdlib/numpy/reductions.codon +++ b/stdlib/numpy/reductions.codon @@ -62,8 +62,8 @@ def _nan_to_back(v: Ptr[T], n: int, T: type): else: return n -def _make_reducer(R, ans_type: type, dtype: type, conv_to_float: Static[int], - bool_to_int: Static[int], **kwargs): +def _make_reducer(R, ans_type: type, dtype: type, conv_to_float: Static[bool], + bool_to_int: Static[bool], **kwargs): if dtype is NoneType: if conv_to_float: ftype = type(_float(ans_type())) @@ -79,7 +79,7 @@ def _make_reducer(R, ans_type: type, dtype: type, conv_to_float: Static[int], else: return R(dtype, **kwargs) -def _cast_elem(e0, dtype: type, conv_to_float: Static[int]): +def _cast_elem(e0, dtype: type, conv_to_float: Static[bool]): if dtype is not NoneType: e1 = util.cast(e0, dtype) else: @@ -212,10 +212,10 @@ def _reduce_all(arr, empty, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), - conv_to_float: Static[int] = False, - bool_to_int: Static[int] = False, + conv_to_float: Static[bool] = False, + bool_to_int: Static[bool] = False, **kwargs): if out is not None: if keepdims: @@ -267,12 +267,12 @@ def _reduce_all(arr, loop_axis = -1 min_abs_stride = 0x7FFFFFFFFFFFFFFF - for i in staticrange(staticlen(arr.ndim)): - stride = strides[i] + for j in staticrange(arr.ndim): + stride = strides[j] if stride: abs_stride = abs(stride) if abs_stride < min_abs_stride: - loop_axis = i + loop_axis = j min_abs_stride = abs_stride if loop_axis == -1: @@ -341,18 +341,18 @@ class _GradualFunctor: k: int kwargs: KW dtype: type - conv_to_float: Static[int] + conv_to_float: Static[bool] R: type KW: type def __new__(redux: R, k: int, dtype: type, - conv_to_float: Static[int], + conv_to_float: Static[bool], kwargs: KW, R: type, KW: type) -> _GradualFunctor[dtype, conv_to_float, R, KW]: - return (redux, k, kwargs) + return superf(redux, k, kwargs) def __call__(self, q, p): e = _cast_elem(p[0], self.dtype, self.conv_to_float) @@ -365,9 +365,9 @@ def _reduce_gradual(arr, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, - conv_to_float: Static[int] = False, - bool_to_int: Static[int] = False, + keepdims: Static[bool] = False, + conv_to_float: Static[bool] = False, + bool_to_int: Static[bool] = False, **kwargs): data = arr.data shape = arr.shape @@ -522,10 +522,10 @@ def _reduce(arr, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), - conv_to_float: Static[int] = False, - bool_to_int: Static[int] = False, + conv_to_float: Static[bool] = False, + bool_to_int: Static[bool] = False, **kwargs): data = arr.data shape = arr.shape @@ -759,7 +759,7 @@ def _reduce_buffered(arr, out=None, overwrite_input: bool = False, force_contig: bool = True, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), **kwargs): data = arr.data @@ -972,7 +972,7 @@ def _reduce_buffered_multi(arr, out=None, overwrite_input: bool = False, force_contig: bool = True, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), **kwargs): data = arr.data @@ -1156,7 +1156,7 @@ class SumRedux: return ans - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): self.total += SumRedux[T]._loop(a, n, stride, S) class NanSumRedux: @@ -1189,7 +1189,7 @@ class NanSumRedux: def gradual_accept(self, curr, item, index: int, **kwargs): return curr if _isnan(item) else curr + util.cast(item, T) - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): ans = T() for i in range(n): @@ -1228,7 +1228,7 @@ class ProdRedux: def gradual_accept(self, curr, item, index: int, **kwargs): return curr * util.cast(item, T) - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): ans = T(1) for i in range(n): @@ -1267,7 +1267,7 @@ class NanProdRedux: def gradual_accept(self, curr, item, index: int, **kwargs): return curr if _isnan(item) else curr * util.cast(item, T) - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): ans = T(1) for i in range(n): @@ -1308,7 +1308,7 @@ class MeanRedux: def gradual_result(self, curr, count: int): return curr / T(count) if count else _nan(T) - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): self.total += SumRedux[T]._loop(a, n, stride, S) class NanMeanRedux: @@ -1339,7 +1339,7 @@ class NanMeanRedux: def done(self): return False - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): ans = T() nan_count = 0 @@ -1407,7 +1407,7 @@ class MinRedux: else: return x if x < m else m - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): if self.m is None: m: T = util.cast(a[0], T) a = _increment_ptr(a, stride) @@ -1478,7 +1478,7 @@ class MaxRedux: else: return x if x > m else m - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): if self.m is None and n > 0: m: T = util.cast(a[0], T) a = _increment_ptr(a, stride) @@ -1529,7 +1529,7 @@ class PTPRedux: def done(self): return False - def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[int], S: type): + def loop(self, a: Ptr[S], n: int, stride: int, partial: Static[bool], S: type): # n must be >0 here or we would've thrown an exception earlier m = util.cast(a[0], T) M = m @@ -1684,7 +1684,7 @@ def sum( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=0, where=util._NoValue(), ): @@ -1707,7 +1707,7 @@ def nansum( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=0, where=util._NoValue(), ): @@ -1730,7 +1730,7 @@ def prod( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=1, where=util._NoValue(), ): @@ -1753,7 +1753,7 @@ def nanprod( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=1, where=util._NoValue(), ): @@ -1776,7 +1776,7 @@ def mean( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1797,7 +1797,7 @@ def nanmean( axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1851,7 +1851,7 @@ def var( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1913,7 +1913,7 @@ def nanvar( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1938,7 +1938,7 @@ def std( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1963,7 +1963,7 @@ def nanstd( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): a = asarray(a) @@ -1984,7 +1984,7 @@ def min( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue(), ): @@ -2012,7 +2012,7 @@ def max( dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue(), ): @@ -2034,7 +2034,7 @@ def max( initial=initial, ) -def ptp(a, axis=None, out=None, keepdims: Static[int] = False): +def ptp(a, axis=None, out=None, keepdims: Static[bool] = False): a = asarray(a) return _reduce( a, @@ -2045,7 +2045,7 @@ def ptp(a, axis=None, out=None, keepdims: Static[int] = False): keepdims=keepdims, ) -def argmin(a, axis=None, out=None, keepdims: Static[int] = False): +def argmin(a, axis=None, out=None, keepdims: Static[bool] = False): a = asarray(a) return _reduce( a, @@ -2056,7 +2056,7 @@ def argmin(a, axis=None, out=None, keepdims: Static[int] = False): keepdims=keepdims, ) -def argmax(a, axis=None, out=None, keepdims: Static[int] = False): +def argmax(a, axis=None, out=None, keepdims: Static[bool] = False): a = asarray(a) return _reduce( a, @@ -2070,7 +2070,7 @@ def argmax(a, axis=None, out=None, keepdims: Static[int] = False): def any(a, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): a = asarray(a) return _reduce( @@ -2086,7 +2086,7 @@ def any(a, def all(a, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): a = asarray(a) return _reduce( @@ -2099,7 +2099,7 @@ def all(a, where=where, ) -def count_nonzero(a, axis=None, keepdims: Static[int] = False): +def count_nonzero(a, axis=None, keepdims: Static[bool] = False): a = asarray(a) return _reduce(a, R=NonZeroRedux.create, @@ -2144,7 +2144,7 @@ def median(a, axis=None, out=None, overwrite_input: bool = False, - keepdims: Static[int] = False): + keepdims: Static[bool] = False): a = asarray(a) return _reduce_buffered(a, _median_reducer, @@ -2165,7 +2165,7 @@ def nanmedian(a, axis=None, out=None, overwrite_input: bool = False, - keepdims: Static[int] = False): + keepdims: Static[bool] = False): a = asarray(a) return _reduce_buffered(a, _nanmedian_reducer, @@ -2538,7 +2538,7 @@ def _quantile_unchecked(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False): + keepdims: Static[bool] = False): # Assumes that q is in [0, 1], and is an ndarray if q.ndim == 0: return _reduce_buffered(a, @@ -2576,7 +2576,7 @@ def quantile(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if interpolation is not None: method = _check_interpolation_as_method(method, interpolation) @@ -2614,7 +2614,7 @@ def _nanquantile_unchecked(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False): + keepdims: Static[bool] = False): # Assumes that q is in [0, 1], and is an ndarray if q.ndim == 0: return _reduce_buffered(a, @@ -2646,7 +2646,7 @@ def nanquantile(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if interpolation is not None: method = _check_interpolation_as_method(method, interpolation) @@ -2676,7 +2676,7 @@ def percentile(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if interpolation is not None: method = _check_interpolation_as_method(method, interpolation) @@ -2696,7 +2696,7 @@ def nanpercentile(a, out=None, overwrite_input: bool = False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if interpolation is not None: method = _check_interpolation_as_method(method, interpolation) @@ -2729,7 +2729,7 @@ class ndarray: axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=0, where=util._NoValue(), ): @@ -2748,7 +2748,7 @@ class ndarray: axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=1, where=util._NoValue(), ): @@ -2767,7 +2767,7 @@ class ndarray: axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return mean(self, @@ -2782,7 +2782,7 @@ class ndarray: axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return nanmean(self, @@ -2798,7 +2798,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return var( @@ -2817,7 +2817,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return nanvar( @@ -2836,7 +2836,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return std( @@ -2855,7 +2855,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue(), ): return nanstd( @@ -2874,7 +2874,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue(), ): @@ -2889,7 +2889,7 @@ class ndarray: where=where, ) - def ptp(self, axis=None, out=None, keepdims: Static[int] = False): + def ptp(self, axis=None, out=None, keepdims: Static[bool] = False): return ptp(self, axis=axis, out=out, keepdims=keepdims) def max( @@ -2898,7 +2898,7 @@ class ndarray: dtype: type = NoneType, out=None, ddof: int = 0, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue(), ): @@ -2913,22 +2913,22 @@ class ndarray: where=where, ) - def argmin(self, axis=None, out=None, keepdims: Static[int] = False): + def argmin(self, axis=None, out=None, keepdims: Static[bool] = False): return argmin(self, axis=axis, out=out, keepdims=keepdims) - def argmax(self, axis=None, out=None, keepdims: Static[int] = False): + def argmax(self, axis=None, out=None, keepdims: Static[bool] = False): return argmax(self, axis=axis, out=out, keepdims=keepdims) def any(self, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): return any(self, axis=axis, out=out, keepdims=keepdims, where=where) def all(self, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): return all(self, axis=axis, out=out, keepdims=keepdims, where=where) diff --git a/stdlib/numpy/routines.codon b/stdlib/numpy/routines.codon index 1c845ac7..2399827b 100644 --- a/stdlib/numpy/routines.codon +++ b/stdlib/numpy/routines.codon @@ -549,7 +549,7 @@ def arange(start: str, stop, step = None, dtype: type = datetime64['D', 1]): return _datetime_arange(start, stop, step, dtype) def linspace(start: float, stop: float, num: int = 50, - endpoint: bool = True, retstep: Static[int] = False, + endpoint: bool = True, retstep: Static[bool] = False, dtype: type = float): if num < 0: raise ValueError(f'Number of samples, {num}, must be non-negative.') @@ -583,8 +583,8 @@ def linspace(start: float, stop: float, num: int = 50, return result def _linlogspace(start: float, stop: float, num: int = 50, base: float = 10.0, - out_sign: int = 1, endpoint: bool = True, retstep: Static[int] = False, - dtype: type = float, log: Static[int] = False): + out_sign: int = 1, endpoint: bool = True, retstep: Static[bool] = False, + dtype: type = float, log: Static[bool] = False): if num < 0: raise ValueError(f'Number of samples, {num}, must be non-negative.') @@ -633,7 +633,7 @@ def _linlogspace(start: float, stop: float, num: int = 50, base: float = 10.0, return result def linspace(start: float, stop: float, num: int = 50, - endpoint: bool = True, retstep: Static[int] = False, + endpoint: bool = True, retstep: Static[bool] = False, dtype: type = float): return _linlogspace(start=start, stop=stop, num=num, endpoint=endpoint, retstep=retstep, @@ -641,7 +641,7 @@ def linspace(start: float, stop: float, num: int = 50, def logspace(start: float, stop: float, num: int = 50, endpoint: bool = True, base: float = 10.0, - retstep: Static[int] = False, + retstep: Static[bool] = False, dtype: type = float): return _linlogspace(start=start, stop=stop, num=num, endpoint=endpoint, retstep=retstep, @@ -799,7 +799,7 @@ def broadcast_arrays(*args): bshape = broadcast_shapes(*shapes) return [broadcast_to(a, bshape) for a in args] -def meshgrid(*xi, copy: bool = True, sparse: Static[int] = False, indexing: Static[str] = 'xy'): +def meshgrid(*xi, copy: bool = True, sparse: Static[bool] = False, indexing: Static[str] = 'xy'): def make_shape(i, ndim: Static[int]): t = (1,) * ndim p = Ptr[int](__ptr__(t).as_byte()) @@ -1965,7 +1965,7 @@ def array_split(ary, indices_or_sections, axis: int = 0): return idx def slice_axis(arr, axis: int, start: int, stop: int): - ndim: Static[int] = staticlen(ary.shape) + ndim: Static[int] = staticlen(arr.shape) dtype = arr.dtype shape = arr.shape @@ -3474,7 +3474,7 @@ def take(a, indices, axis = None, out = None, mode: str = 'raise'): else: return res -def indices(dimensions, dtype: type = int, sparse: Static[int] = False): +def indices(dimensions, dtype: type = int, sparse: Static[bool] = False): if not isinstance(dimensions, Tuple): compile_error("dimensions must be a tuple of integers") @@ -4713,7 +4713,7 @@ def isscalar(element): T is str or T is NoneType) -def _array_get_part(arr: ndarray, imag: Static[int]): +def _array_get_part(arr: ndarray, imag: Static[bool]): if arr.dtype is complex: offset = util.sizeof(float) if imag else 0 data = Ptr[float](arr.data.as_byte() + offset) diff --git a/stdlib/numpy/sorting.codon b/stdlib/numpy/sorting.codon index 43df7d86..ffcdcccf 100644 --- a/stdlib/numpy/sorting.codon +++ b/stdlib/numpy/sorting.codon @@ -1550,7 +1550,7 @@ def _store_pivot(pivot: int, kth: int, pivots: Ptr[int], npiv: Ptr[int]): pivots[npiv[0]] = pivot npiv[0] += 1 -def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, arg: Static[int], T: type): +def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, arg: Static[bool], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1569,7 +1569,7 @@ def _median3_swap(v: Ptr[T], tosort: Ptr[int], low: int, mid: int, high: int, ar sortee.swap(mid, low + 1) -def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[int], T: type): +def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[bool], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1600,7 +1600,7 @@ def _median5(v: Ptr[T], tosort: Ptr[int], arg: Static[int], T: type): else: return 2 -def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int, arg: Static[int], T: type): +def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int, arg: Static[bool], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1626,7 +1626,7 @@ def _unguarded_partition(v: Ptr[T], tosort: Ptr[int], pivot: T, ll: int, hh: int return ll, hh -def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], isel, T: type): +def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[bool], isel, T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1648,7 +1648,7 @@ def _median_of_median5(v: Ptr[T], tosort: Ptr[int], num: int, pivots: Ptr[int], return nmed // 2 -def _dumb_select(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, arg: Static[int], T: type): +def _dumb_select(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, arg: Static[bool], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1672,7 +1672,7 @@ def _msb(unum: u64): unum >>= u64(1) return depth_limit -def _introselect(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[int], T: type): +def _introselect(v: Ptr[T], tosort: Ptr[int], num: int, kth: int, pivots: Ptr[int], npiv: Ptr[int], arg: Static[bool], T: type): if arg: idx = ArgIdx(tosort) sortee = ArgSortee(tosort) @@ -1814,7 +1814,7 @@ def _fix_kth(kth, n: int): return pkth, nkth -def _partition_helper(a: ndarray, kth, axis, kind: str, force_compact: Static[int] = False): +def _partition_helper(a: ndarray, kth, axis, kind: str, force_compact: Static[bool] = False): if axis is None: _partition_helper(a.flatten(), kth=kth, axis=-1, kind=kind, force_compact=force_compact) return @@ -2055,7 +2055,7 @@ def _sort_buffered(a: ndarray, axis: int, sorter): util.free(buf) -def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[int] = False): +def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[bool] = False): axis = util.normalize_axis_index(axis, a.ndim) if force_compact: @@ -2066,7 +2066,7 @@ def _sort_dispatch(a: ndarray, axis: int, sorter, force_compact: Static[int] = F else: _sort_buffered(a, axis, sorter) -def _sort(a: ndarray, axis: int, kind: Optional[str], force_compact: Static[int] = False): +def _sort(a: ndarray, axis: int, kind: Optional[str], force_compact: Static[bool] = False): if kind is None or kind == 'quicksort' or kind == 'quick': _sort_dispatch(a, axis, quicksort, force_compact) elif kind == 'mergesort' or kind == 'merge' or kind == 'stable': diff --git a/stdlib/numpy/statistics.codon b/stdlib/numpy/statistics.codon index ce19361a..653c920c 100644 --- a/stdlib/numpy/statistics.codon +++ b/stdlib/numpy/statistics.codon @@ -11,8 +11,8 @@ from .sorting import sort, argsort def average(a, axis=None, weights=None, - returned: Static[int] = False, - keepdims: Static[int] = False): + returned: Static[bool] = False, + keepdims: Static[bool] = False): def result_type(a_dtype: type, w_dtype: type): common_dtype = type(util.coerce(a_dtype, w_dtype)) @@ -909,7 +909,7 @@ def _histogram_fast(a, bins=10, range=None): def histogram(a, bins=10, range=None, - density: Static[int] = False, + density: Static[bool] = False, weights=None): def return_zeros(size, weights): @@ -918,7 +918,7 @@ def histogram(a, else: return zeros(size, dtype=weights.dtype) - def histogram_result(n, bin_edges, density: Static[int] = False): + def histogram_result(n, bin_edges, density: Static[bool] = False): if density: db = array(_diff(bin_edges), float) return (n / db / n.sum(), bin_edges) @@ -942,11 +942,11 @@ def histogram(a, BLOCK = 65536 # The fast path uses bincount, but that only works for certain types of weight - simple_weights1: Static[int] = weights is None + simple_weights1: Static[bool] = weights is None if isinstance(weights, ndarray): - simple_weights2: Static[int] = (weights.dtype is float - or weights.dtype is complex - or weights.dtype is complex64) + simple_weights2: Static[bool] = (weights.dtype is float + or weights.dtype is complex + or weights.dtype is complex64) if uniform_bins is not None and (simple_weights1 or simple_weights2): # Fast algorithm for equal bins diff --git a/stdlib/numpy/ufunc.codon b/stdlib/numpy/ufunc.codon index f4dde392..11477f80 100644 --- a/stdlib/numpy/ufunc.codon +++ b/stdlib/numpy/ufunc.codon @@ -125,21 +125,21 @@ def _apply_vectorized_loop_binary(arr1, arr2, out, func: Static[str]): def _fix_scalar(x, A: type): X = type(x) - a_is_int: Static[int] = (A is int or A is byte or isinstance(A, Int) - or isinstance(A, UInt)) - x_is_int: Static[int] = X is bool or X is int + a_is_int: Static[bool] = (A is int or A is byte or isinstance(A, Int) + or isinstance(A, UInt)) + x_is_int: Static[bool] = X is bool or X is int - a_is_float: Static[int] = (A is float or A is float32 or A is float16 - or A is bfloat16 or A is float128) - x_is_float: Static[int] = X is float + a_is_float: Static[bool] = (A is float or A is float32 or A is float16 + or A is bfloat16 or A is float128) + x_is_float: Static[bool] = X is float - a_is_complex: Static[int] = (A is complex or A is complex64) - x_is_complex: Static[int] = X is complex + a_is_complex: Static[bool] = (A is complex or A is complex64) + x_is_complex: Static[bool] = X is complex - should_cast: Static[int] = ((x_is_int and - (a_is_int or a_is_float or a_is_complex)) or - (x_is_float and (a_is_float or a_is_complex)) - or (x_is_complex and a_is_complex)) + should_cast: Static[bool] = ((x_is_int and + (a_is_int or a_is_float or a_is_complex)) or + (x_is_float and (a_is_float or a_is_complex)) + or (x_is_complex and a_is_complex)) if (A is float16 or A is float32) and X is complex: return util.cast(x, complex64) @@ -161,10 +161,10 @@ def decide_types(x, y, dtype: type): X = type(routines.asarray(x).data[0]) Y = type(routines.asarray(y).data[0]) - x_scalar: Static[int] = (isinstance(x, bool) or isinstance(x, int) - or isinstance(x, float) or isinstance(x, complex)) - y_scalar: Static[int] = (isinstance(y, bool) or isinstance(y, int) - or isinstance(y, float) or isinstance(y, complex)) + x_scalar: Static[bool] = (isinstance(x, bool) or isinstance(x, int) + or isinstance(x, float) or isinstance(x, complex)) + y_scalar: Static[bool] = (isinstance(y, bool) or isinstance(y, int) + or isinstance(y, float) or isinstance(y, complex)) if x_scalar and y_scalar: return t1(util.coerce(X, Y)) @@ -185,10 +185,10 @@ def decide_types_copysign(x, y, dtype: type): XF = type(util.to_float(util.zero(X))) YF = type(util.to_float(util.zero(Y))) - x_scalar: Static[int] = (isinstance(x, bool) or isinstance(x, int) - or isinstance(x, float) or isinstance(x, complex)) - y_scalar: Static[int] = (isinstance(y, bool) or isinstance(y, int) - or isinstance(y, float) or isinstance(y, complex)) + x_scalar: Static[bool] = (isinstance(x, bool) or isinstance(x, int) + or isinstance(x, float) or isinstance(x, complex)) + y_scalar: Static[bool] = (isinstance(y, bool) or isinstance(y, int) + or isinstance(y, float) or isinstance(y, complex)) if dtype is float16 or dtype is float32 or dtype is float: return t2(dtype, dtype) @@ -216,9 +216,9 @@ def decide_types_ldexp(x, y, dtype: type): or isinstance(Y, UInt)): compile_error("ldexp 2nd argument must be of integral type") - x_scalar: Static[int] = (isinstance(x, bool) or isinstance(x, int) - or isinstance(x, float) or isinstance(x, complex)) - y_scalar: Static[int] = isinstance(y, int) + x_scalar: Static[bool] = (isinstance(x, bool) or isinstance(x, int) + or isinstance(x, float) or isinstance(x, complex)) + y_scalar: Static[bool] = isinstance(y, int) if dtype is float16 or dtype is float32 or dtype is float: return t2(dtype, int) @@ -237,7 +237,7 @@ class _UnaryFunctor: UF: type def __new__(ufunc: UF, dtype: type, UF: type) -> _UnaryFunctor[dtype, UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, y, x): y[0] = self.ufunc._f(x[0], dtype=self.dtype, dtype_out=type(y[0])) @@ -250,7 +250,7 @@ class _UnaryWhereFunctor: def __new__(ufunc: UF, dtype: type, UF: type) -> _UnaryWhereFunctor[dtype, UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, y, x, w): if w[0]: @@ -262,7 +262,7 @@ class _Unary2Functor: UF: type def __new__(ufunc: UF, UF: type) -> _Unary2Functor[UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, y1, y2, x): e1, e2 = self.ufunc._op(x[0]) @@ -275,7 +275,7 @@ class _Unary2WhereFunctor: UF: type def __new__(ufunc: UF, UF: type) -> _Unary2WhereFunctor[UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, y1, y2, x, w): if w[0]: @@ -293,7 +293,7 @@ class _BinaryFunctor: def __new__(ufunc: UF, CT1: type, CT2: type, dtype: type, UF: type) -> _BinaryFunctor[CT1, CT2, dtype, UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, z, x, y): z[0] = self.ufunc._f(util.cast(x[0], CT1), @@ -313,7 +313,7 @@ class _BinaryScalar1Functor: def __new__(ufunc: UF, x: X, CT1: type, CT2: type, dtype: type, UF: type, X: type) -> _BinaryScalar1Functor[CT1, CT2, dtype, UF, X]: - return (ufunc, x) + return superf(ufunc, x) def __call__(self, z, y): z[0] = self.ufunc._f(util.cast(self.x, CT1), @@ -333,7 +333,7 @@ class _BinaryScalar2Functor: def __new__(ufunc: UF, y: Y, CT1: type, CT2: type, dtype: type, UF: type, Y: type) -> _BinaryScalar2Functor[CT1, CT2, dtype, UF, Y]: - return (ufunc, y) + return superf(ufunc, y) def __call__(self, z, x): z[0] = self.ufunc._f(util.cast(x[0], CT1), @@ -351,7 +351,7 @@ class _BinaryWhereFunctor: def __new__(ufunc: UF, CT1: type, CT2: type, dtype: type, UF: type) -> _BinaryWhereFunctor[CT1, CT2, dtype, UF]: - return (ufunc, ) + return superf(ufunc) def __call__(self, z, x, y, w): if w[0]: @@ -367,7 +367,7 @@ class UnaryUFunc: F: type def __new__(op: F, name: Static[str], F: type) -> UnaryUFunc[name, F]: - return (op, ) + return superf(op) @property def nin(self): @@ -455,7 +455,7 @@ class UnaryUFunc2: F: type def __new__(op: F, name: Static[str], F: type) -> UnaryUFunc2[name, F]: - return (op, ) + return superf(op) @property def nin(self): @@ -543,7 +543,7 @@ class BinaryUFunc: identity: I = None, F: type, I: type) -> BinaryUFunc[name, F, I]: - return (op, identity) + return superf(op, identity) @property def nin(self): @@ -645,7 +645,7 @@ class BinaryUFunc: def _reduce_all(self, array, dtype: type = NoneType, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=True): if isinstance(where, bool): @@ -712,7 +712,7 @@ class BinaryUFunc: axis=0, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=True): if not isinstance(array, ndarray): diff --git a/stdlib/openmp.codon b/stdlib/openmp.codon index 10946a39..4ec4e06f 100644 --- a/stdlib/openmp.codon +++ b/stdlib/openmp.codon @@ -585,7 +585,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args): for i in iterable: priv_fixed, shared_fixed = _fix_privates_and_shareds(i, priv, shared) _spawn_and_run_task( - loc_ref, gtid, _routine_stub(P=P, S=S, ...).__raw__(), priv_fixed, shared_fixed + loc_ref, gtid, _routine_stub(P=P, S=S, ...).F.T.__raw__(), priv_fixed, shared_fixed ) finally: _taskgroup_end(loc_ref, gtid) @@ -887,8 +887,8 @@ def for_par( num_threads: int = -1, chunk_size: int = -1, schedule: Static[str] = "static", - ordered: Static[int] = False, + ordered: Static[bool] = False, collapse: Static[int] = 0, - gpu: Static[int] = False, + gpu: Static[bool] = False, ): pass diff --git a/stdlib/operator.codon b/stdlib/operator.codon index d332a001..e67d720d 100644 --- a/stdlib/operator.codon +++ b/stdlib/operator.codon @@ -175,7 +175,6 @@ def attrgetter(attr: Static[str]): return getattr(obj, attr) return getter -@overload def itemgetter(*items): if staticlen(items) == 1: item = items[0] diff --git a/stdlib/os/__init__.codon b/stdlib/os/__init__.codon index ff88df5b..1d2331f8 100644 --- a/stdlib/os/__init__.codon +++ b/stdlib/os/__init__.codon @@ -12,7 +12,7 @@ class EnvMap: _map: Dict[str, str] def __new__() -> EnvMap: - return (Dict[str, str](),) + return EnvMap(Dict[str, str]()) def _init_if_needed(self): if len(self._map) == 0: diff --git a/stdlib/threading.codon b/stdlib/threading.codon index 09c7e89e..ed8b892a 100644 --- a/stdlib/threading.codon +++ b/stdlib/threading.codon @@ -5,7 +5,7 @@ class Lock: p: cobj def __new__() -> Lock: - return (_C.seq_lock_new(),) + return Lock(_C.seq_lock_new(),) def acquire(self, block: bool = True, timeout: float = -1.0) -> bool: if timeout >= 0.0 and not block: @@ -26,7 +26,7 @@ class RLock: p: cobj def __new__() -> RLock: - return (_C.seq_rlock_new(),) + return RLock(_C.seq_rlock_new(),) def acquire(self, block: bool = True, timeout: float = -1.0) -> bool: if timeout >= 0.0 and not block: diff --git a/stdlib/time.codon b/stdlib/time.codon index 40ce05d1..8623a6d0 100644 --- a/stdlib/time.codon +++ b/stdlib/time.codon @@ -94,15 +94,15 @@ class struct_time: return x def __new__( - year: int, - mon: int, - mday: int, - hour: int, - min: int, - sec: int, - wday: int, - yday: int, - isdst: int, + year: int = 0, + mon: int = 0, + mday: int = 0, + hour: int = 0, + min: int = 0, + sec: int = 0, + wday: int = 0, + yday: int = 0, + isdst: int = 0, ) -> struct_time: return struct_time( i16(year - 1900), diff --git a/test/core/arguments.codon b/test/core/arguments.codon index 7f6c63d1..00839f95 100644 --- a/test/core/arguments.codon +++ b/test/core/arguments.codon @@ -140,13 +140,13 @@ class C: b: tuple[int,int,int] def __new__(y: float) -> C: - return (y, (0, 0, 0)) + return C(y, (0, 0, 0)) def __new__(y: float, foo: int) -> C: - return (y, (foo, 1, 0)) + return C(y, (foo, 1, 0)) def __new__(x: list[float], a: int, b: int, c: int) -> C: - return (x[0], (a, b, c)) + return C(x[0], (a, b, c)) def __new__(a: int, b: int): return (a, b) diff --git a/test/core/bltin.codon b/test/core/bltin.codon index 2d652a33..9ac8d7a9 100644 --- a/test/core/bltin.codon +++ b/test/core/bltin.codon @@ -1095,13 +1095,15 @@ test_wide_int_str(Int[200]) test_wide_int_str(Int[256]) test_wide_int_str(Int[512]) test_wide_int_str(Int[1024]) -test_wide_int_str(Int[2048]) -test_wide_int_str(Int[4096]) test_wide_uint_str(UInt[128]) test_wide_uint_str(UInt[200]) test_wide_uint_str(UInt[256]) test_wide_uint_str(UInt[512]) test_wide_uint_str(UInt[1024]) -test_wide_uint_str(UInt[2048]) -test_wide_uint_str(UInt[4096]) + +# These take ages [80+ sec] on LLVM 17 to generate +# test_wide_int_str(Int[2048]) +# test_wide_int_str(Int[4096]) +# test_wide_uint_str(UInt[2048]) +# test_wide_uint_str(UInt[4096]) diff --git a/test/core/containers.codon b/test/core/containers.codon index f85d1177..93878a21 100644 --- a/test/core/containers.codon +++ b/test/core/containers.codon @@ -1062,7 +1062,7 @@ def test_counter(): assert exp == got assert repr(Counter('abcabc')) == "Counter({'a': 2, 'b': 2, 'c': 2})" -test_counter() +test_counter() # this call doubles compile time! @test def test_defaultdict(): diff --git a/test/core/llvmops.codon b/test/core/llvmops.codon deleted file mode 100644 index 85054d10..00000000 --- a/test/core/llvmops.codon +++ /dev/null @@ -1,185 +0,0 @@ -from core.llvm import * - -@test -def test_int_llvm_ops(): - assert add_int(42, 99) == 141 - assert add_int(-10, 10) == 0 - assert sub_int(12, 6) == 6 - assert sub_int(5, -5) == 10 - assert mul_int(22, 33) == 726 - assert mul_int(-3, 3) == -9 - assert div_int(10, 2) == 5 - assert div_int(10, 3) == 3 - assert div_int(10, -3) == -3 - assert mod_int(10, 2) == 0 - assert mod_int(10, 3) == 1 - - assert eq_int(42, 42) - assert not eq_int(10, -10) - assert ne_int(0, 1) - assert not ne_int(-3, -3) - - assert lt_int(2, 3) - assert not lt_int(3, 2) - assert not lt_int(3, 3) - - assert not gt_int(2, 3) - assert gt_int(3, 2) - assert not gt_int(3, 3) - - assert le_int(2, 3) - assert not le_int(3, 2) - assert le_int(3, 3) - - assert not ge_int(2, 3) - assert ge_int(3, 2) - assert ge_int(3, 3) - - assert inv_int(0b1010) == -11 - assert and_int(0b1010, 0b1101) == 0b1000 - assert or_int(0b1010, 0b1101) == 0b1111 - assert xor_int(0b1010, 0b1101) == 0b0111 - assert shr_int(0b1010, 3) == 0b1 - assert shl_int(0b1010, 3) == 0b1010000 - - assert bitreverse_int(0b0111110111010110001001001000001010001000110001010110001101001101) == 0b1011001011000110101000110001000101000001001001000110101110111110 - assert bitreverse_int(0) == 0 - assert bitreverse_int(0xffffffffffffffff) == 0xffffffffffffffff - assert ctpop_int(0x7dd6248288c5634d) == 29 - assert ctpop_int(0) == 0 - assert ctpop_int(0xffffffffffffffff) == 64 - assert bswap_int(0x7dd6248288c5634d) == 0x4d63c5888224d67d - assert bswap_int(0) == 0 - assert bswap_int(0xffffffffffffffff) == 0xffffffffffffffff - assert ctlz_int(0b0001110111010110001001001000001010001000110001010110001101001101) == 3 - assert ctlz_int(0b0011110111010110001001001000001010001000110001010110001101001101) == 2 - assert ctlz_int(0b0111110111010110001001001000001010001000110001010110001101001101) == 1 - assert ctlz_int(0b1111110111010110001001001000001010001000110001010110001101001101) == 0 - assert ctlz_int(0) == 64 - assert cttz_int(0b0001110111010110001001001000001010001000110001010110001101001000) == 3 - assert cttz_int(0b0001110111010110001001001000001010001000110001010110001101001100) == 2 - assert cttz_int(0b0001110111010110001001001000001010001000110001010110001101001110) == 1 - assert cttz_int(0b0001110111010110001001001000001010001000110001010110001101001111) == 0 - assert cttz_int(0) == 64 - -@test -def test_float_llvm_ops(): - def approx_eq(a: float, b: float, thresh: float = 1e-10): - return -thresh <= a - b <= thresh - PI = 3.1415926535897932384626433832795028841971693993751058209749445923078164062 - E = 2.718281828459045235360287471352662497757247093699959574966967627724076630353 - - assert add_float(42., 99.) == 141. - assert add_float(-10., 10.) == 0. - assert sub_float(12., 6.) == 6. - assert sub_float(5., -5.) == 10. - assert mul_float(22., 33.) == 726. - assert mul_float(-3., 3.) == -9. - assert div_float(10., 2.) == 5. - assert div_float(10., 4.) == 2.5 - assert div_float(10., -2.5) == -4. - assert mod_float(10., 2.) == 0. - assert mod_float(10., 3.) == 1. - - assert eq_float(42., 42.) - assert not eq_float(10., -10.) - assert ne_float(0., 1.) - assert not ne_float(-3., -3.) - - assert lt_float(2., 3.) - assert not lt_float(3., 2.) - assert not lt_float(3., 3.) - - assert not gt_float(2., 3.) - assert gt_float(3., 2.) - assert not gt_float(3., 3.) - - assert le_float(2., 3.) - assert not le_float(3., 2.) - assert le_float(3., 3.) - - assert not ge_float(2., 3.) - assert ge_float(3., 2.) - assert ge_float(3., 3.) - - assert pow_float(10., 2.) == 100. - assert sqrt_float(100.) == 10. - assert sin_float(0.0) == 0. - assert sin_float(PI/2) == 1. - assert approx_eq(sin_float(PI), 0.) - assert cos_float(0.0) == 1. - assert approx_eq(cos_float(PI/2), 0.) - assert cos_float(PI) == -1. - assert exp_float(0.) == 1. - assert exp_float(1.) == E - assert exp2_float(0.) == 1. - assert exp2_float(1.) == 2. - assert log_float(1.) == 0. - assert log_float(E) == 1. - assert log10_float(1.) == 0. - assert log10_float(10.) == 1. - assert log2_float(1.) == 0. - assert log2_float(2.) == 1. - assert abs_float(1.5) == 1.5 - assert abs_float(-1.5) == 1.5 - assert pow_float(1., 0.) == 1. - assert pow_float(3., 2.) == 9. - assert pow_float(2., -2.) == 0.25 - assert pow_float(-2., 2.) == 4. - assert min_float(1., 1.) == 1. - assert min_float(-1., 1.) == -1. - assert min_float(3., 2.) == 2. - assert max_float(1., 1.) == 1. - assert max_float(-1., 1.) == 1. - assert max_float(3., 2.) == 3. - assert copysign_float(100., 1.234) == 100. - assert copysign_float(100., -1.234) == -100. - assert copysign_float(-100., 1.234) == 100. - assert copysign_float(-100., -1.234) == -100. - assert fma_float(2., 3., 4.) == 10. - - assert floor_float(1.5) == 1. - assert ceil_float(1.5) == 2. - assert trunc_float(-1.5) == -1. - assert rint_float(1.8) == 2. - assert rint_float(1.3) == 1. - assert nearbyint_float(2.3) == 2. - assert nearbyint_float(-3.8) == -4. - assert round_float(2.3) == 2. - assert round_float(-2.3) == -2. - -@test -def test_conversion_llvm_ops(): - assert int_to_float(42) == 42.0 - assert int_to_float(-100) == -100.0 - assert float_to_int(3.14) == 3 - assert float_to_int(-3.14) == -3 - -@test -def test_str_llvm_ops(): - N = 10 - p = ptr[byte](N) - for i in range(N): - p[i] = byte(i + 1) - - q = ptr[byte](N) - for i in range(N): - q[i] = byte(0) - - memcpy(q, p, N) - for i in range(10): - assert q[i] == byte(i + 1) - - memmove(p + 1, p, N - 1) - assert p[1] == p[0] - for i in range(1, N): - assert p[i] == byte(i) - - memset(p, byte(42), N) - for i in range(N): - assert p[i] == byte(42) - -test_int_llvm_ops() -test_float_llvm_ops() -test_conversion_llvm_ops() -test_str_llvm_ops() diff --git a/test/core/serialization.codon b/test/core/serialization.codon index 3e6fc835..9841ecb1 100644 --- a/test/core/serialization.codon +++ b/test/core/serialization.codon @@ -40,7 +40,7 @@ def test_pickle[T](x: T): @test def test_non_atomic_list_pickle[T](x: list[list[T]]): import gzip - copy = [copy(a) for a in x] + ncopy = [copy(a) for a in x] path = 'build/testjar.bin' jar = gzip.open(path, 'wb') pickle.dump(x, jar) @@ -53,12 +53,12 @@ def test_non_atomic_list_pickle[T](x: list[list[T]]): y = pickle.load(jar, list[list[T]]) jar.close() - assert y == copy + assert y == ncopy @test def test_non_atomic_dict_pickle[T](x: dict[str, list[T]]): import gzip - copy = {k: copy(v) for k,v in x.items()} + ncopy = {k: copy(v) for k,v in x.items()} path = 'build/testjar.bin' jar = gzip.open(path, 'wb') pickle.dump(x, jar) @@ -71,12 +71,12 @@ def test_non_atomic_dict_pickle[T](x: dict[str, list[T]]): y = pickle.load(jar, dict[str, list[T]]) jar.close() - assert y == copy + assert y == ncopy @test def test_non_atomic_set_pickle(x: set[A]): import gzip - copy = {copy(a) for a in x} + ncopy = {copy(a) for a in x} path = 'build/testjar.bin' jar = gzip.open(path, 'wb') pickle.dump(x, jar) @@ -89,7 +89,7 @@ def test_non_atomic_set_pickle(x: set[A]): y = pickle.load(jar, set[A]) jar.close() - assert y == copy + assert y == ncopy test_pickle(42) test_pickle(3.14) diff --git a/test/core/sort.codon b/test/core/sort.codon index 795369d1..8a745519 100644 --- a/test/core/sort.codon +++ b/test/core/sort.codon @@ -1,5 +1,12 @@ +from algorithms.pdqsort import pdq_sort +from algorithms.insertionsort import insertion_sort +from algorithms.heapsort import heap_sort +from algorithms.timsort import tim_sort + + MANUAL_TEST = False + def print_test[T](a: list[T], b: list[T]): if not MANUAL_TEST: print(a) @@ -22,17 +29,17 @@ def print_test[T](a: list[T], b: list[T]): ### Comparison Functions ### -def compare_less(x: int, y: int) -> bool: - return x < y +def compare_less(x: int): + return x -def compare_greater(x: int, y: int) -> bool: - return x > y +def compare_greater(x: int): + return -x -def compare_string(x: str, y: str) -> bool: - return x < y +def compare_string(x): + return x -def compare_dict(x: dict[str,int], y: dict[str,int]) -> bool: - return x["key"] < y["key"] +def compare_dict(x: dict[str,int]): + return x["key"] ### Basic Sort Tests ### print_test(insertion_sort(list[int](), compare_less), list[int]()) # EXPECT: [] diff --git a/test/main.cpp b/test/main.cpp index 26ffe455..0d70451b 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -251,18 +251,8 @@ class SeqTest : public testing::TestWithParam< string getFilename(const string &basename) { return string(TEST_DIR) + "/" + basename; } - int runInChildProcess() { - assert(pipe(out_pipe) != -1); - pid = fork(); - GC_atfork_prepare(); - assert(pid != -1); - - if (pid == 0) { - GC_atfork_child(); - dup2(out_pipe[1], STDOUT_FILENO); - close(out_pipe[0]); - close(out_pipe[1]); - + int runInChildProcess(bool avoidFork = false) { + auto fn = [this]() { auto file = getFilename(get<0>(GetParam())); bool debug = get<1>(GetParam()); auto code = get<3>(GetParam()); @@ -279,7 +269,7 @@ class SeqTest : public testing::TestWithParam< ? compiler->parseFile(file, testFlags) : compiler->parseCode(file, code, startLine, testFlags), [](const error::ParserErrorInfo &e) { - for (auto &group : e) { + for (auto &group : e.getErrors()) { for (auto &msg : group) { getLogger().level = 0; printf("%s\n", msg.getMessage().c_str()); @@ -288,7 +278,6 @@ class SeqTest : public testing::TestWithParam< fflush(stdout); exit(EXIT_FAILURE); }); - auto *pm = compiler->getPassManager(); pm->registerPass(std::make_unique()); pm->registerPass(std::make_unique()); @@ -300,11 +289,27 @@ class SeqTest : public testing::TestWithParam< ir::analyze::dataflow::DominatorAnalysis::KEY}); pm->registerPass(std::make_unique(capKey), /*insertBefore=*/"", {capKey}); - llvm::cantFail(compiler->compile()); seq_exc_init(0); compiler->getLLVMVisitor()->run({file}); fflush(stdout); + }; + // if (true) { + // fn(); + // return 0; + // } + + assert(pipe(out_pipe) != -1); + pid = fork(); + GC_atfork_prepare(); + assert(pid != -1); + + if (pid == 0) { + GC_atfork_child(); + dup2(out_pipe[1], STDOUT_FILENO); + close(out_pipe[0]); + close(out_pipe[1]); + fn(); exit(EXIT_SUCCESS); } else { GC_atfork_parent(); @@ -348,6 +353,8 @@ TEST_P(SeqTest, Run) { status = runInChildProcess(); else status = runInChildProcess(); + if (!WIFEXITED(status)) + std::cerr << result() << std::endl; ASSERT_TRUE(WIFEXITED(status)); string output = result(); @@ -365,7 +372,7 @@ TEST_P(SeqTest, Run) { vector results = splitLines(output); for (unsigned i = 0; i < min(results.size(), expects.first.size()); i++) if (expects.second) - EXPECT_EQ(results[i], expects.first[i]); + EXPECT_EQ(results[i].substr(0, expects.first[i].size()), expects.first[i]); else EXPECT_EQ(results[i], expects.first[i]); EXPECT_EQ(results.size(), expects.first.size()); @@ -383,9 +390,10 @@ auto getTypeTests(const vector &files) { int line = 0; while (getline(fin, l)) { if (l.substr(0, 3) == "#%%") { - if (line) + if (line && testName != "__ignore__") { cases.emplace_back(make_tuple(f, true, to_string(line) + "_" + testName, code, codeLine, barebones, false)); + } auto t = ast::split(l.substr(4), ','); barebones = (t.size() > 1 && t[1] == "barebones"); testName = t[0]; @@ -397,9 +405,10 @@ auto getTypeTests(const vector &files) { } line++; } - if (line) + if (line && testName != "__ignore__") { cases.emplace_back(make_tuple(f, true, to_string(line) + "_" + testName, code, codeLine, barebones, false)); + } } return cases; } @@ -408,12 +417,23 @@ auto getTypeTests(const vector &files) { INSTANTIATE_TEST_SUITE_P( TypeTests, SeqTest, testing::ValuesIn(getTypeTests({ - "parser/simplify_expr.codon", - "parser/simplify_stmt.codon", - "parser/typecheck_expr.codon", - "parser/typecheck_stmt.codon", - "parser/types.codon", - "parser/llvm.codon" + "parser/typecheck/test_access.codon", + "parser/typecheck/test_assign.codon", + "parser/typecheck/test_basic.codon", + "parser/typecheck/test_call.codon", + "parser/typecheck/test_class.codon", + "parser/typecheck/test_collections.codon", + "parser/typecheck/test_cond.codon", + "parser/typecheck/test_ctx.codon", + "parser/typecheck/test_error.codon", + "parser/typecheck/test_function.codon", + "parser/typecheck/test_import.codon", + "parser/typecheck/test_infer.codon", + "parser/typecheck/test_loops.codon", + "parser/typecheck/test_op.codon", + "parser/typecheck/test_parser.codon", + "parser/typecheck/test_python.codon", + "parser/typecheck/test_typecheck.codon" })), getTypeTestNameFromParam); @@ -465,6 +485,7 @@ INSTANTIATE_TEST_SUITE_P( StdlibTests, SeqTest, testing::Combine( testing::Values( + "stdlib/llvm_test.codon", "stdlib/str_test.codon", "stdlib/re_test.codon", "stdlib/math_test.codon", @@ -513,41 +534,41 @@ INSTANTIATE_TEST_SUITE_P( ), getTestNameFromParam); -INSTANTIATE_TEST_SUITE_P( - NumPyTests, SeqTest, - testing::Combine( - testing::Values( - "numpy/random_tests/test_mt19937.codon", - "numpy/random_tests/test_pcg64.codon", - "numpy/random_tests/test_philox.codon", - "numpy/random_tests/test_sfc64.codon", - "numpy/test_dtype.codon", - "numpy/test_fft.codon", - "numpy/test_functional.codon", - // "numpy/test_fusion.codon", // TODO: uses a lot of RAM - "numpy/test_indexing.codon", - "numpy/test_io.codon", - "numpy/test_lib.codon", - "numpy/test_linalg.codon", - "numpy/test_loops.codon", - // "numpy/test_misc.codon", // TODO: takes forever in debug mode - "numpy/test_ndmath.codon", - "numpy/test_npdatetime.codon", - "numpy/test_pybridge.codon", - "numpy/test_reductions.codon", - "numpy/test_routines.codon", - "numpy/test_sorting.codon", - "numpy/test_statistics.codon", - "numpy/test_window.codon" - ), - testing::Values(true, false), - testing::Values(""), - testing::Values(""), - testing::Values(0), - testing::Values(false), - testing::Values(false) - ), - getTestNameFromParam); +// INSTANTIATE_TEST_SUITE_P( +// NumPyTests, SeqTest, +// testing::Combine( +// testing::Values( +// "numpy/random_tests/test_mt19937.codon", +// "numpy/random_tests/test_pcg64.codon", +// "numpy/random_tests/test_philox.codon", +// "numpy/random_tests/test_sfc64.codon", +// "numpy/test_dtype.codon", +// "numpy/test_fft.codon", +// "numpy/test_functional.codon", +// // "numpy/test_fusion.codon", // TODO: uses a lot of RAM +// "numpy/test_indexing.codon", +// "numpy/test_io.codon", +// "numpy/test_lib.codon", +// "numpy/test_linalg.codon", +// "numpy/test_loops.codon", +// // "numpy/test_misc.codon", // TODO: takes forever in debug mode +// "numpy/test_ndmath.codon", +// "numpy/test_npdatetime.codon", +// "numpy/test_pybridge.codon", +// "numpy/test_reductions.codon", +// "numpy/test_routines.codon", +// "numpy/test_sorting.codon", +// "numpy/test_statistics.codon", +// "numpy/test_window.codon" +// ), +// testing::Values(true, false), +// testing::Values(""), +// testing::Values(""), +// testing::Values(0), +// testing::Values(false), +// testing::Values(false) +// ), +// getTestNameFromParam); // clang-format on diff --git a/test/numpy/test_fft.codon b/test/numpy/test_fft.codon index 1da5fc10..2d37e87e 100644 --- a/test/numpy/test_fft.codon +++ b/test/numpy/test_fft.codon @@ -449,7 +449,7 @@ def test_fftn_out_argument(dtype: type, transpose: bool, axes): assert np.array_equal(result2, expected2) @test -def test_fftn_out_and_s_interaction(fft, rfftn: Static[int]): +def test_fftn_out_and_s_interaction(fft, rfftn: Static[bool]): # With s, shape varies, so generally one cannot pass in out. gen = rnd.default_rng(seed=20) if rfftn: diff --git a/test/numpy/test_lib.codon b/test/numpy/test_lib.codon index eaafe9d5..c6647787 100644 --- a/test/numpy/test_lib.codon +++ b/test/numpy/test_lib.codon @@ -29,9 +29,9 @@ test_sliding_window_view(np.array([[0, 1, 2, 3], [10, 11, 12, 13], @test def test_unique(ar, expected, - return_index: Static[int] = False, - return_inverse: Static[int] = False, - return_counts: Static[int] = False, + return_index: Static[bool] = False, + return_inverse: Static[bool] = False, + return_counts: Static[bool] = False, axis=None, equal_nan: bool = True): if return_index or return_counts or return_inverse: @@ -97,7 +97,7 @@ def test_intersect1d(ar1, ar2, expected, assume_unique: bool = False, - return_indices: Static[int] = False): + return_indices: Static[bool] = False): if return_indices: u = np.intersect1d(ar1, ar2, assume_unique, return_indices) for i in range(len(expected)): diff --git a/test/numpy/test_linalg.codon b/test/numpy/test_linalg.codon index 7e896515..2d5a5cb0 100644 --- a/test/numpy/test_linalg.codon +++ b/test/numpy/test_linalg.codon @@ -2029,10 +2029,10 @@ def TestDet_test_zero(): assert (isinstance(alg.det([[0.0]]), double)) assert (alg.det([[0.0j]]) == 0.0) assert (isinstance(alg.det([[0.0j]]), cdouble)) - assert (alg.slogdet([[0.0]]) == (0.0, -inf)) + assert (tuple(alg.slogdet([[0.0]])) == (0.0, -inf)) assert (isinstance(alg.slogdet([[0.0]])[0], double)) assert (isinstance(alg.slogdet([[0.0]])[1], double)) - assert (alg.slogdet([[0.0j]]) == (0.0j, -inf)) + assert (tuple(alg.slogdet([[0.0j]])) == (0.0j, -inf)) assert (isinstance(alg.slogdet([[0.0j]])[0], cdouble)) assert (isinstance(alg.slogdet([[0.0j]])[1], double)) @@ -2288,7 +2288,7 @@ test_qr( @test def test_svd(a, full_matrices: bool = True, - compute_uv: Static[int] = True, + compute_uv: Static[bool] = True, hermitian: bool = False, expected=0): if hermitian: diff --git a/test/numpy/test_npdatetime.codon b/test/numpy/test_npdatetime.codon index 794a6dd0..2a46aec5 100644 --- a/test/numpy/test_npdatetime.codon +++ b/test/numpy/test_npdatetime.codon @@ -2004,7 +2004,7 @@ def test_datetime_arange(): 22, np.timedelta64(2, 'D'), dtype=np.datetime64['D', 1]) - assert array_equal(a.dtype, np.datetime64['D', 1]) + assert isinstance(a.dtype, np.datetime64['D', 1]) assert array_equal( a, np.array('1969-12-19', dtype=np.datetime64['D', 1]) + @@ -2021,7 +2021,7 @@ test_datetime_arange() @test def test_timedelta_arange(): a = np.arange(3, 10, dtype=np.timedelta64['generic', 1]) - assert array_equal(a.dtype, np.timedelta64['generic', 1]) + assert isinstance(a.dtype, np.timedelta64['generic', 1]) assert array_equal( a, np.array(0, dtype=np.timedelta64['generic', 1]) + np.arange(3, 10)) @@ -2029,7 +2029,7 @@ def test_timedelta_arange(): 10, 2, dtype=np.timedelta64['generic', 1]) - assert array_equal(a.dtype, np.timedelta64['s', 1]) + assert isinstance(a.dtype, np.timedelta64['s', 1]) assert array_equal( a, np.array(0, dtype=np.timedelta64['s', 1]) + np.arange(3, 10, 2)) diff --git a/test/numpy/test_reductions.codon b/test/numpy/test_reductions.codon index 4265f791..3e35f653 100644 --- a/test/numpy/test_reductions.codon +++ b/test/numpy/test_reductions.codon @@ -513,7 +513,7 @@ def test_sum(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=0, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): @@ -562,7 +562,7 @@ def test_prod(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=1, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): @@ -609,7 +609,7 @@ def test_mean(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float) or isinstance( expected, complex): @@ -658,7 +658,7 @@ def test_nanmean(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): assert np.nanmean(a, @@ -707,7 +707,7 @@ def test_var(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float) or isinstance( expected, complex): @@ -753,7 +753,7 @@ def test_nanvar(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): assert np.nanvar(a, @@ -801,7 +801,7 @@ def test_std(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float) or isinstance( expected, complex): @@ -852,7 +852,7 @@ def test_nanstd(a, axis=None, dtype: type = NoneType, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): assert np.nanstd(a, @@ -901,7 +901,7 @@ def test_min(a, expected, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): @@ -947,7 +947,7 @@ def test_max(a, expected, axis=None, out=None, - keepdims: Static[int] = False, + keepdims: Static[bool] = False, initial=util._NoValue(), where=util._NoValue()): if isinstance(expected, int) or isinstance(expected, float): @@ -987,7 +987,7 @@ test_max(np.array([[1 + 2j, 3 + 4j, 5 + 6j], [7 + 8j, 9 + 10j, 11 + 12j]]), axis=1) @test -def test_ptp(a, expected, axis=None, out=None, keepdims: Static[int] = False): +def test_ptp(a, expected, axis=None, out=None, keepdims: Static[bool] = False): if isinstance(expected, int) or isinstance(expected, float) or isinstance( expected, complex): assert np.ptp(a, axis=axis, out=out, keepdims=keepdims) == expected @@ -1021,7 +1021,7 @@ def test_argmin(a, expected, axis=None, out=None, - keepdims: Static[int] = False): + keepdims: Static[bool] = False): if isinstance(expected, int): assert np.argmin(a, axis=axis, out=out, keepdims=keepdims) == expected else: @@ -1048,7 +1048,7 @@ def test_argmax(a, expected, axis=None, out=None, - keepdims: Static[int] = False): + keepdims: Static[bool] = False): if isinstance(expected, int): assert np.argmax(a, axis=axis, out=out, keepdims=keepdims) == expected else: @@ -1096,7 +1096,7 @@ def test_all(): test_all() @test -def test_count_nonzero(a, expected, axis=None, keepdims: Static[int] = False): +def test_count_nonzero(a, expected, axis=None, keepdims: Static[bool] = False): if isinstance(expected, int) or isinstance(expected, float): assert np.count_nonzero(a, axis=axis, keepdims=keepdims) == expected else: @@ -1122,7 +1122,7 @@ def test_median(a, axis=None, out=None, overwrite_input: bool = False, - keepdims: Static[int] = False): + keepdims: Static[bool] = False): if isinstance(expected, int) or isinstance(expected, float) or isinstance( expected, complex): if np.isnan(expected): @@ -1285,7 +1285,7 @@ def test_quantile(a, out=None, overwrite_input=False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if isinstance(expected, float): assert np.quantile(a, @@ -1344,7 +1344,7 @@ def test_percentile(a, out=None, overwrite_input=False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if isinstance(expected, float): assert np.percentile(a, @@ -1521,7 +1521,7 @@ def test_nanpercentile(a, out=None, overwrite_input=False, method: str = "linear", - keepdims: Static[int] = False, + keepdims: Static[bool] = False, interpolation=None): if isinstance(expected, float): assert np.nanpercentile(a, diff --git a/test/numpy/test_routines.codon b/test/numpy/test_routines.codon index 8411ef83..046ccbc0 100644 --- a/test/numpy/test_routines.codon +++ b/test/numpy/test_routines.codon @@ -140,7 +140,7 @@ def test_linspace(start, expected, num: int = 50, endpoint: bool = True, - retstep: Static[int] = False, + retstep: Static[bool] = False, dtype: type = float): assert (np.linspace(start, stop, num, endpoint, retstep, dtype) == expected).all() @@ -159,7 +159,7 @@ def test_logspace(start, num: int = 50, endpoint: bool = True, base: float = 10.0, - retstep: Static[int] = False, + retstep: Static[bool] = False, dtype: type = float): assert (round( np.logspace(start, stop, num, endpoint, base, retstep, dtype), @@ -959,7 +959,7 @@ test_rot90( def test_indices(dimensions, expected, dtype: type = int, - sparse: Static[int] = False): + sparse: Static[bool] = False): assert (np.indices(dimensions, dtype, sparse) == expected).all() test_indices((2, 3), np.array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, @@ -1359,7 +1359,7 @@ test_extract( np.mod(np.arange(12).reshape((3, 4)), 3) == 0, np.array([0, 3, 6, 9])) @test -def test_count_nonzero(a, expected, axis=None, keepdims: Static[int] = False): +def test_count_nonzero(a, expected, axis=None, keepdims: Static[bool] = False): x = np.count_nonzero(a, axis=axis, keepdims=keepdims) if isinstance(x, int): assert x == expected diff --git a/test/numpy/test_statistics.codon b/test/numpy/test_statistics.codon index 69a5270f..ce6ada2d 100644 --- a/test/numpy/test_statistics.codon +++ b/test/numpy/test_statistics.codon @@ -5,8 +5,8 @@ def test_average(a, expected, axis=None, weights=None, - returned: Static[int] = False, - keepdims: Static[int] = False): + returned: Static[bool] = False, + keepdims: Static[bool] = False): if isinstance(expected, int) or isinstance(expected, float): assert np.average(a, axis=axis, @@ -542,7 +542,7 @@ def test_histogram(a, expected_edges, bins=10, range=None, - density: Static[int] = False, + density: Static[bool] = False, weights=None): hist, bin_edges = np.histogram(a, bins=bins, diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index a79523ac..e69de29b 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -1,552 +0,0 @@ -#%% none,barebones -@extend -class Optional: - def __repr__(self): - return 'OPTIONAL: ' + ('-' if self is None else self.__val__().__repr__()) - def __str__(self): - return 'OPTIONAL: ' + ('-' if self is None else self.__val__().__repr__()) - -a = None -print a #: OPTIONAL: - -if True: - a = 5 -print a #: OPTIONAL: 5 - -#%% bool,barebones -print True, False #: True False - -#%% int,barebones -print 0b0000_1111 #: 15 -print 0B101 #: 5 -print 3 #: 3 -print 18_446_744_073_709_551_000 #: -616 -print 0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111 #: -1 -print 0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111u #: 18446744073709551615 -print 18_446_744_073_709_551_000u #: 18446744073709551000 -print 65i7 #: -63 -print -1u7 #: 127 - -@extend -class int: - def __suffix_test__(s): - return 'TEST: ' + str(s) -print 123_456test #: TEST: 123456 - -#%% int_error,barebones -print 1844674407_3709551999 #! integer '18446744073709551999' cannot fit into 64-bit integer - -#%% float,barebones -print 5.15 #: 5.15 -print 2e2 #: 200 -print 2.e-2 #: 0.02 -print 1_000.0 #: 1000 -print 1_000e9 #: 1e+12 - -#%% float_suffix,barebones -@extend -class float: - def __suffix_zoo__(x): - return str(x) + '_zoo' - -print 1.2e-1zoo #: 0.12_zoo - -#%% string,barebones -print 'kthxbai', "kthxbai" #: kthxbai kthxbai -print """hi -hello""", '''hai -hallo''' -#: hi -#: hello hai -#: hallo - -#%% fstring,barebones -a, b = 1, 2 -print f"string {a}" #: string 1 -print F"{b} string" #: 2 string -print f"str {a+b} end" #: str 3 end -print f"str {a+b=}" #: str a+b=3 -c = f'and this is {a} followed by {b}' -print c, f'{b}{a}', f'. {1+a=} .. {b} ...' #: and this is 1 followed by 2 21 . 1+a=2 .. 2 ... - -#%% fstring_error,barebones -f"a{b + 3}}" #! single '}' is not allowed in f-string - -#%% fstring_error_2,barebones -f"a{{b + 3}" #! expecting '}' in f-string - -#%% prefix_str,barebones -@extend -class str: - def __prefix_pfx__[N: Static[int]](s: str): - return 'PFX ' + s -print pfx'HELLO' #: PFX HELLO - -@extend -class str: - def __prefix_pxf__(s: str, N: Static[int]): - return 'PXF ' + s + " " + str(N) -print pxf'HELLO' #: PXF HELLO 5 - -#%% raw_str,barebones -print 'a\\b' #: a\b -print r'a\tb' #: a\tb -print R'\n\r\t\\' #: \n\r\t\\ - -#%% id_fstring_error,barebones -f"a{b + 3}" #! name 'b' is not defined - -#%% id_access,barebones -def foo(): - a = 5 - def bar(): - print a - bar() #: 5 - a = 4 - bar() #: 5 -foo() - -z = {} -def fox(): - a = 5 - def goo(): - z['x'] = 'y' - print a - return goo -fox()() -print z -#: 5 -#: {'x': 'y'} - - -#%% star_err,barebones -a = (1, 2, 3) -z = *a #! unexpected star expression - -#%% list,barebones -a = [4, 5, 6] -print a #: [4, 5, 6] -b = [1, 2, 3, *a] -print b #: [1, 2, 3, 4, 5, 6] - -#%% set,barebones -gs = {1.12} -print gs #: {1.12} -fs = {1, 2, 3, 1, 2, 3} -gs.add(1.12) -gs.add(1.13) -print fs, gs #: {1, 2, 3} {1.12, 1.13} -print {*fs, 5, *fs} #: {1, 2, 3, 5} - -#%% dict,barebones -gd = {1: 'jedan', 2: 'dva', 2: 'two', 3: 'tri'} -fd = {} -fd['jedan'] = 1 -fd['dva'] = 2 -print gd, fd #: {1: 'jedan', 2: 'two', 3: 'tri'} {'jedan': 1, 'dva': 2} - -#%% comprehension,barebones -l = [(i, j, f'i{i}/{j}') - for i in range(50) if i % 2 == 0 if i % 3 == 0 - for j in range(2) if j == 1] -print l #: [(0, 1, 'i0/1'), (6, 1, 'i6/1'), (12, 1, 'i12/1'), (18, 1, 'i18/1'), (24, 1, 'i24/1'), (30, 1, 'i30/1'), (36, 1, 'i36/1'), (42, 1, 'i42/1'), (48, 1, 'i48/1')] - -s = {i%3 for i in range(20)} -print s #: {0, 1, 2} - -d = {i: j for i in range(10) if i < 1 for j in range(10)} -print d #: {0: 9} - -x = {t: lambda x: x * t for t in range(5)} -print(x[3](10)) #: 30 - -#%% comprehension_opt,barebones -@extend -class List: - def __init__(self, cap: int): - print 'optimize', cap - self.arr = Array[T](cap) - self.len = 0 -def foo(): - yield 0 - yield 1 - yield 2 -print [i for i in range(3)] #: optimize 3 -#: [0, 1, 2] -print [i for i in foo()] #: [0, 1, 2] -print [i for i in range(3) if i%2 == 0] #: [0, 2] -print [i + j for i in range(1) for j in range(1)] #: [0] -print {i for i in range(3)} #: {0, 1, 2} - -#%% comprehension_opt_clone -import sys -z = [i for i in sys.argv] - -#%% generator,barebones -z = 3 -g = (e for e in range(20) if e % z == 1) -print str(g)[:13] #: = a >= -5) #: True False - -#%% if,barebones -c = 5 -a = 1 if c < 5 else 2 -b = -(1 if c else 2) -print a, b #: 2 -1 - -#%% unary,barebones -a, b = False, 1 -print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1 - -#%% binary,barebones -x, y = 1, 0 -c = [1, 2, 3] - -print x and y, x or y #: False True -print x in c, x not in c #: True False -print c is c, c is not c #: True False - -z: Optional[int] = None -print z is None, None is z, None is not z, None is None #: True True False True - -#%% chain_binary,barebones -def foo(): - print 'foo' - return 15 -a = b = c = foo() #: foo -print a, b, c #: 15 15 15 - -x = y = [] -x.append(1) -print x, y #: [1] [1] - -print 1 <= foo() <= 10 #: foo -#: False -print 15 >= foo()+1 < 30 > 20 > foo() -#: foo -#: False -print 15 >= foo()-1 < 30 > 20 > foo() -#: foo -#: foo -#: True - -print True == (b == 15) #: True - -#%% pipe_error,barebones -def b(a, b, c, d): - pass -1 |> b(1, ..., 2, ...) #! multiple ellipsis expressions - -#%% index_normal,barebones -t: tuple[int, int] = (1, 2) -print t #: (1, 2) - -tt: Tuple[int] = (1, ) -print tt #: (1,) - -def foo(i: int) -> int: - return i + 1 -f: Callable[[int], int] = foo -print f(1) #: 2 -fx: function[[int], int] = foo -print fx(2) #: 3 -fxx: Function[[int], int] = foo -print fxx(3) #: 4 - -#%% index_special,barebones -class Foo: - def __getitem__(self, foo): - print foo -f = Foo() -f[0,0] #: (0, 0) -f[0,:] #: (0, slice(None, None, None)) -f[:,:] #: (slice(None, None, None), slice(None, None, None)) -f[:,0] #: (slice(None, None, None), 0) - -#%% index_error,barebones -Ptr[9.99] #! expected type expression - -#%% index_error_b,barebones -Ptr['s'] #! ''s'' does not match expected type 'T' - -#%% index_error_static,barebones -Ptr[1] #! '1' does not match expected type 'T' - -#%% index_error_2,barebones -Ptr[int, 's'] #! Ptr takes 1 generics (2 given) - -#%% index_error_3,barebones -Ptr[1, 's'] #! Ptr takes 1 generics (2 given) - -#%% call_ptr,barebones -v = 5 -p = __ptr__(v) -print p[0] #: 5 - -#%% call_ptr_error,barebones -__ptr__(1) #! __ptr__() only takes identifiers as arguments - -#%% call_ptr_error_3,barebones -v = 1 -__ptr__(v, 1) #! __ptr__() takes 1 arguments (2 given) - -#%% call_array,barebones -a = __array__[int](2) -a[0] = a[1] = 5 -print a[0], a[1] #: 5 5 - -#%% call_array_error,barebones -a = __array__[int](2, 3) #! '__array__[int]' object has no method '__new__' with arguments (int, int) - -#%% call_err_1,barebones -seq_print(1, name="56", 2) #! positional argument follows keyword argument - -#%% call_err_2,barebones -x = (1, 2) -seq_print(1, name=*x) #! syntax error, unexpected '*' - -#%% call_err_3,barebones -x = (1, 2) -seq_print(1, name=**x) #! syntax error, unexpected '*' - -#%% call_collections -from collections import namedtuple as nt - -ee = nt('Foo', ['x', 'y']) -f = ee(1, 2) -print f #: (x: 1, y: 2) - -ee = nt('FooX', [('x', str), 'y']) -fd = ee('s', 2) -print fd #: (x: 's', y: 2) - -#%% call_partial_functools -from functools import partial -def foo(x, y, z): - print x,y,z -f1 = partial(foo, 1, z=3) -f1(2) #: 1 2 3 -f2 = partial(foo, y=2) -f2(1, 2) #: 1 2 2 - -#%% lambda,barebones -l = lambda a, b: a + b -print l(1, 2) #: 3 - -e = 5 -lp = lambda x: x + e -print lp(1) #: 6 - -e = 7 -print lp(2) #: 9 - -def foo[T](a: T, l: Callable[[T], T]): - return l(a) -print foo(4, lp) #: 11 - -def foox(a, l): - return l(a) -print foox(4, lp) #: 11 - -#%% nested_lambda,barebones -def foo(): - print list(a*a for a in range(3)) -foo() #: [0, 1, 4] - -#%% walrus,barebones -def foo(x): - return x * x -if x := foo(4): - pass -if (x := foo(4)) and False: - print 'Nope' -print x #: 16 - -a = [y := foo(1), y+1, y+2] -print a #: [1, 2, 3] - -print {y: b for y in [1,2,3] if (b := (y - 1))} #: {2: 1, 3: 2} -print list(b for y in [1,2,3] if (b := (y // 3))) #: [1] - -#%% walrus_update,barebones -def foo(x): - return x * x -x = 5 -if x := foo(4): - pass -print x #: 16 - -#%% walrus_cond_1,barebones -def foo(x): - return x * x -if False or (x := foo(4)): - pass -print(x) #: 16 - -y = (z := foo(5)) if True else 0 -print(z) #: 25 - -#%% walrus_err,barebones -def foo(x): - return x * x -if False and (x := foo(4)): - pass -try: - print(x) -except NameError: - print("Error") #: Error - -t = True -y = 0 if t else (z := foo(4)) -try: - print(z) -except NameError: - print("Error") #: Error - -#%% range_err,barebones -1 ... 3 #! unexpected range expression - -#%% callable_error,barebones -def foo(x: Callable[[]]): pass #! Callable takes 2 generics (1 given) - -#%% unpack_specials,barebones -x, = 1, -print x #: 1 - -a = (2, 3) -b = (1, *a[1:]) -print a, b #: (2, 3) (1, 3) - -#%% nonlocal,barebones -def goo(ww): - z = 0 - def foo(x): - f = 10 - def bar(y): - nonlocal z - f = x + y - z += y - print('goo.foo.bar', f, z) - bar(5) - print('goo.foo', f) - return bar - b = foo(10) - print('goo', z) - return b -b = goo('s') -# goo.foo.bar 15 5 -# goo.foo 10 -# goo 5 -b(11) -# goo.foo.bar 21 16 -b(12) -# goo.foo.bar 22 28 -b = goo(1) # test another instantiation -# goo.foo.bar 15 5 -# goo.foo 10 -# goo 5 -b(11) -# goo.foo.bar 21 16 -b(13) -# goo.foo.bar 23 29 - -#%% nonlocal_error,barebones -def goo(): - z = 0 - def foo(): - z += 1 -goo() #! local variable 'z' referenced before assignment - -#%% new_scoping,barebones -try: - if True and (x := (True or (y := 1 + 2))): - pass - try: - print(x) #: True - print(y) - except NameError: - print("Error") #: Error - print(x) #: True - if len("s") > 0: - print(x) #: True - print(y) - print(y) # TODO: test for __used__ usage - print(y) # (right now manual inspection is needed) -except NameError as e: - print(e.message) #: variable 'y' not yet defined - -t = True -y = 0 if t else (xx := 1) -try: - print(xx) -except NameError: - print("Error") #: Error - -#%% new_scoping_weird,barebones -def foo(): - if len("s") == 3: - x = 3 - def bar(y): - print(x+y) - x=5 - return bar -try: - f = foo() - f(5) -except NameError: - print('error') #: error - # TODO: Python works here. - # Need to capture these vars conditionally? - -#%% new_scoping_loops_try,barebones -for i in range(10): - pass -print(i) #: 9 - -j = 6 -for j in range(0): - pass -print(j) #: 6 - -for j in range(1): - pass -print(j) #: 0 - -z = 6 -for z in []: - pass -print(z) #: 6 - -for z in [1, 2]: - pass -print(z) #: 2 - -try: - raise ValueError("hi") -except ValueError as e: - pass -print(e.message) #: hi - -try: - pass -except ValueError as f: - pass -try: - print(f.message) -except NameError: - print('error') #: error diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index 2b60d2fb..e69de29b 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -1,1328 +0,0 @@ -#%% pass,barebones -pass - -#%% continue_error,barebones -continue #! 'continue' outside loop - -#%% break_error,barebones -break #! 'break' outside loop - -#%% assign,barebones -a = 1 -print a #: 1 -a = 2 -print a #: 2 - -x, y = 1, 2 -print x, y #: 1 2 -(x, y) = (3, 4) -print x, y #: 3 4 -x, y = (1, 2) -print x, y #: 1 2 -(x, y) = 3, 4 -print x, y #: 3 4 -(x, y) = [3, 4] -print x, y #: 3 4 -[x, y] = [1, 2] -print x, y #: 1 2 -[x, y] = (4, 3) -print x, y #: 4 3 - -l = list(iter(range(10))) -[a, b, *lx, c, d] = l -print a, b, lx, c, d #: 0 1 [2, 3, 4, 5, 6, 7] 8 9 -a, b, *lx = l -print a, b, lx #: 0 1 [2, 3, 4, 5, 6, 7, 8, 9] -*lx, a, b = l -print lx, a, b #: [0, 1, 2, 3, 4, 5, 6, 7] 8 9 -*xz, a, b = (1, 2, 3, 4, 5) -print xz, a, b #: (1, 2, 3) 4 5 -(*ex,) = [1, 2, 3] -print ex #: [1, 2, 3] - -#%% assign_str,barebones -sa, sb = 'XY' -print sa, sb #: X Y -(sa, sb), sc = 'XY', 'Z' -print sa, sb, sc #: X Y Z -sa, *la = 'X' -print sa, la, 1 #: X 1 -sa, *la = 'XYZ' -print sa, la #: X YZ -(xa,xb), *xc, xd = [1,2],'this' -print xa, xb, xc, xd #: 1 2 () this -(a, b), (sc, *sl) = [1,2], 'this' -print a, b, sc, sl #: 1 2 t his - -#%% assign_index_dot,barebones -class Foo: - a: int - def __setitem__(self, i: int, t: int): - self.a += i * t -f = Foo() -f.a = 5 -print f.a #: 5 -f[3] = 5 -print f.a #: 20 -f[1] = -8 -print f.a #: 12 - - -def foo(): - print('foo') - return 0 -v = [0] -v[foo()] += 1 -#: foo -print(v) -#: [1] - -#%% assign_err_1,barebones -a, *b, c, *d = 1,2,3,4,5 #! multiple starred expressions in assignment - -#%% assign_err_2,barebones -a = [1, 2, 3] -a[1]: int = 3 #! syntax error, unexpected ':' - -#%% assign_err_3,barebones -a = 5 -a.x: int = 3 #! syntax error, unexpected ':' - -#%% assign_err_4,barebones -*x = range(5) #! cannot assign to given expression - -#%% assign_err_5,barebones -try: - (sa, sb), sc = 'XYZ' -except IndexError: - print "assign failed" #: assign failed - -#%% assign_comprehension,barebones -g = ((b, a, c) for a, *b, c in ['ABC','DEEEEF','FHGIJ']) -x, *q, y = list(g) # TODO: auto-unroll as in Python -print x, y, q #: ('B', 'A', 'C') ('HGI', 'F', 'J') [('EEEE', 'D', 'F')] - -#%% assign_shadow,barebones -a = 5 -print a #: 5 -a : str = 's' -print a #: s - -#%% assign_err_must_exist,barebones -a = 1 -def foo(): - a += 2 #! local variable 'a' referenced before assignment - -#%% assign_rename,barebones -y = int -z = y(5) -print z #: 5 - -def foo(x): return x + 1 -x = foo -print x(1) #: 2 - -#%% assign_err_6,barebones -x = bar #! name 'bar' is not defined - -#%% assign_err_7,barebones -foo() += bar #! cannot assign to given expression - -#%% assign_update_eq,barebones -a = 5 -a += 3 -print a #: 8 -a -= 1 -print a #: 7 - -class Foo: - a: int - def __add__(self, i: int): - print 'add!' - return Foo(self.a + i) - def __iadd__(self, i: int): - print 'iadd!' - self.a += i - return self - def __str__(self): - return str(self.a) -f = Foo(3) -print f + 2 #: add! -#: 5 -f += 6 #: iadd! -print f #: 9 - -#%% del,barebones -a = 5 -del a -print a #! name 'a' is not defined - -#%% del_index,barebones -y = [1, 2] -del y[0] -print y #: [2] - -#%% del_error,barebones -a = [1] -del a.ptr #! cannot delete given expression - -#%% assert,barebones -assert True -assert True, "blah" - -try: - assert False -except AssertionError as e: - print e.message[:15], e.message[-24:] #: Assert failed ( simplify_stmt.codon:174) - -try: - assert False, f"hehe {1}" -except AssertionError as e: - print e.message[:23], e.message[-24:] #: Assert failed: hehe 1 ( simplify_stmt.codon:179) - -#%% print,barebones -print 1, -print 1, 2 #: 1 1 2 - -print 1, 2 #: 1 2 -print(3, "4", sep="-", end=" !\n") #: 3-4 ! - -print(1, 2) #: 1 2 -print (1, 2) #: (1, 2) - -def foo(i, j): - return i + j -print 3 |> foo(1) #: 4 - -#%% return_fail,barebones -return #! 'return' outside function - -#%% yield_fail,barebones -yield 5 #! 'yield' outside function - -#%% yield_fail_2,barebones -(yield) #! 'yield' outside function - -#%% while_else,barebones -a = 1 -while a: - print a #: 1 - a -= 1 -else: - print 'else' #: else -a = 1 -while a: - print a #: 1 - a -= 1 -else not break: - print 'else' #: else -while True: - print 'infinite' #: infinite - break -else: - print 'nope' - -#%% for_assignment,barebones -l = [[1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]] -for a, *m, b in l: - print a + b, len(m) -#: 5 2 -#: 14 3 -#: 21 0 - -#%% for_else,barebones -for i in [1]: - print i #: 1 -else: - print 'else' #: else -for i in [1]: - print i #: 1 -else not break: - print 'else' #: else -for i in [1]: - print i #: 1 - break -else: - print 'nope' - -best = 4 -for s in [3, 4, 5]: - for i in [s]: - if s >= best: - print('b:', best) - break - else: - print('s:', s) - best = s -#: s: 3 -#: b: 3 -#: b: 3 - -#%% match -def foo(x): - match x: - case 1: - print 'int' - case 2 ... 10: - print 'range' - case 'ACGT': - print 'string' - case (a, 1): - print 'tuple_wild', a - case []: - print 'list' - case [[]]: - print 'list list' - case [1, 2]: - print 'list 2' - case [1, z, ...] if z < 5: - print 'list 3', z - case [1, _, ..., zz] | (1, zz): - print 'list 4', zz - case (1 ... 10, s := ('ACGT', 1 ... 4)): - print 'complex', s - case _: - print 'else' -foo(1) #: int -foo(5) #: range -foo('ACGT') #: string -foo((9, 1)) #: tuple_wild 9 -foo(List[int]()) #: list -foo([List[int]()]) #: list list -foo([1, 2]) #: list 2 -foo([1, 3]) #: list 3 3 -foo([1, 5]) #: else -foo([1, 5, 10]) #: list 4 10 -foo((1, 33)) #: list 4 33 -foo((9, ('ACGT', 3))) #: complex ('ACGT', 3) -foo(range(10)) #: else - -for op in 'MI=DXSN': - match op: - case 'M' | '=' | 'X': - print('case 1') - case 'I' or 'S': - print('case 2') - case _: - print('case 3') -#: case 1 -#: case 2 -#: case 1 -#: case 3 -#: case 1 -#: case 2 -#: case 3 - -#%% match_err_1,barebones -match [1, 2]: - case [1, ..., 2, ..., 3]: pass -#! multiple ellipses in a pattern - -#%% global,barebones -a = 1 -def foo(): - global a - a += 1 -print a, -foo() -print a #: 1 2 - -#%% global_err,barebones -a = 1 -global a #! 'global' outside function - -#%% global_err_2,barebones -def foo(): - global b #! name 'b' is not defined - -#%% global_err_3,barebones -def foo(): - b = 1 - def bar(): - global b #! no binding for global 'b' found - -#%% global_err_4,barebones -a = 1 -def foo(): - a += 1 -foo() #! local variable 'a' referenced before assignment - -#%% global_ref,barebones -a = [1] -def foo(): - a.append(2) -foo() -print a #: [1, 2] - -#%% yield_from,barebones -def foo(): - yield from range(3) - yield from range(10, 13) - yield -1 -print list(foo()) #: [0, 1, 2, 10, 11, 12, -1] - -#%% with,barebones -class Foo: - i: int - def __enter__(self: Foo): - print '> foo! ' + str(self.i) - def __exit__(self: Foo): - print '< foo! ' + str(self.i) - def foo(self: Foo): - print 'woof' -class Bar: - s: str - def __enter__(self: Bar): - print '> bar! ' + self.s - def __exit__(self: Bar): - print '< bar! ' + self.s - def bar(self: Bar): - print 'meow' -with Foo(0) as f: -#: > foo! 0 - f.foo() #: woof -#: < foo! 0 -with Foo(1) as f, Bar('s') as b: -#: > foo! 1 -#: > bar! s - f.foo() #: woof - b.bar() #: meow -#: < bar! s -#: < foo! 1 -with Foo(2), Bar('t') as q: -#: > foo! 2 -#: > bar! t - print 'eeh' #: eeh - q.bar() #: meow -#: < bar! t -#: < foo! 2 - -#%% import_c,barebones -from C import sqrt(float) -> float -print sqrt(4.0) #: 2 - -from C import puts(cobj) -puts("hello".ptr) #: hello - -from C import atoi(cobj) -> int as s2i -print s2i("11".ptr) #: 11 - -@C -def log(x: float) -> float: - pass -print log(5.5) #: 1.70475 - -from C import seq_flags: Int[32] as e -# debug | standalone == 5 -print e #: 5 - -#%% import_c_shadow_error,barebones -# Issue #45 -from C import sqrt(float) -> float as foo -sqrt(100.0) #! name 'sqrt' is not defined - - -#%% import_c_dylib,barebones -from internal.dlopen import dlext -RT = "./libcodonrt." + dlext() -if RT[-3:] == ".so": - RT = "build/" + RT[2:] -from C import RT.seq_str_int(int, str, Ptr[bool]) -> str as sp -p = False -print sp(65, "", __ptr__(p)) #: 65 - -#%% import_c_dylib_error,barebones -from C import "".seq_print(str) as sp -sp("hi!") #! syntax error, unexpected '"' - -#%% import,barebones -zoo, _zoo = 1, 1 -print zoo, _zoo, __name__ #: 1 1 __main__ - -import a #: a -a.foo() #: a.foo - -from a import foo, bar as b -foo() #: a.foo -b() #: a.bar - -print str(a)[:9], str(a)[-18:] #: - -import a.b -print a.b.c #: a.b.c -a.b.har() #: a.b.har a.b.__init__ a.b.c - -print a.b.A.B.b_foo().__add__(1) #: a.b.A.B.b_foo() -#: 2 - -print str(a.b)[:9], str(a.b)[-20:] #: -print Int[a.b.stt].__class__.__name__ #: Int[5] - -from a.b import * -har() #: a.b.har a.b.__init__ a.b.c -a.b.har() #: a.b.har a.b.__init__ a.b.c -fx() #: a.foo -print(stt, Int[stt].__class__.__name__) #: 5 Int[5] - -from a import * -print zoo, _zoo, __name__ #: 5 1 __main__ - -f = Foo(Ptr[B]()) -print f.__class__.__name__, f.t.__class__.__name__ #: Foo Ptr[B] - -a.ha() #: B - -print par #: x - -#%% import_subimport,barebones -import a as xa #: a - -xa.foo() #: a.foo -#: a.sub -xa.sub.foo() #: a.sub.foo - -#%% import_order,barebones -def foo(): - import a - a.foo() -def bar(): - import a - a.bar() - -bar() #: a -#: a.bar -foo() #: a.foo - -#%% import_class -import sys -print str(sys)[:20] #: int: - a{={=}} -#! invalid LLVM code - -#%% function_llvm_err_4,barebones -a = 5 -@llvm -def foo() -> int: - a{=a -#! invalid LLVM code - -#%% function_self,barebones -class Foo: - def foo(self): - return 'F' -f = Foo() -print f.foo() #: F - -#%% function_self_err,barebones -class Foo: - def foo(self): - return 'F' -Foo.foo(1) #! 'Foo' object has no method 'foo' with arguments (int) - -#%% function_nested,barebones -def foo(v): - value = v - def bar(): - return value - return bar -baz = foo(2) -print baz() #: 2 - -def f(x): - a=1 - def g(y): - return a+y - return g(x) -print f(5) #: 6 - -#%% nested_generic_static,barebones -def foo(): - N: Static[int] = 5 - Z: Static[int] = 15 - T = Int[Z] - def bar(): - x = __array__[T](N) - print(x.__class__.__name__) - return bar -foo()() #: Array[Int[15]] - -def f[T](): - def g(): - return T() - return g() -print f(int) #: 0 - -#%% class_err_1,barebones -@extend -@foo -class Foo: - pass -#! cannot combine '@extend' with other attributes or decorators - -#%% class_err_1b,barebones -size_t = i32 -@extend -class size_t: - pass -#! class name 'size_t' is not defined - -#%% class_err_2,barebones -def foo(): - @extend - class Foo: - pass -#! class extension must be a top-level statement - -#%% class_nested,barebones -class Foo: - foo: int - class Bar: - bar: int - b: Optional[Foo.Bar] - c: Optional[int] - class Moo: - # TODO: allow nested class reference to the upclass - # x: Foo.Bar - x: int -y = Foo(1) -z = Foo.Bar(2, None, 4) -m = Foo.Bar.Moo(5) -print y.foo #: 1 -print z.bar, z.b.__bool__(), z.c, m.x #: 2 False 4 5 - -#%% class_nested_2,barebones -@tuple -class Foo: - @tuple - class Bar: - x: int - x: int - b: Bar - c: Foo.Bar -f = Foo(5, Foo.Bar(6), Foo.Bar(7)) -print(f) #: (x: 5, b: (x: 6), c: (x: 7)) - -#%% class_nested_err,barebones -class Foo: - class Bar: - b: Ptr[Bar] -#! name 'Bar' is not defined - -#%% class_err_4,barebones -@extend -class Foo: - pass -#! class name 'Foo' is not defined - -#%% class_err_5,barebones -class Foo[T, U]: - pass -@extend -class Foo[T]: - pass -#! class extensions cannot define data attributes and generics or inherit other classes - -#%% class_err_7,barebones -class Foo: - a: int - a: int -#! duplicate data attribute 'a' in class definition - -#%% class_err_tuple_no_recursive,barebones -@tuple -class Foo: - a: Foo -#! name 'Foo' is not defined - -#%% class_err_8,barebones -class Foo: - while True: pass -#! unexpected expression in class definition - -#%% class_err_9,barebones -class F[T: Static[float]]: - pass -#! expected 'int' or 'str' (only integers and strings can be static) - -#%% class_err_10,barebones -def foo[T](): - class A: - x: T -#! name 'T' cannot be captured - -#%% class_err_11,barebones -def foo(x): - class A: - def bar(): - print x -#! name 'x' cannot be captured - -#%% class_err_12,barebones -def foo(x): - T = type(x) - class A: - def bar(): - print T() -#! name 'T' cannot be captured - -#%% recursive_class,barebones -class Node[T]: - data: T - children: List[Node[T]] - def __init__(self, data: T): - self.data = data - self.children = [] -print Node(2).data #: 2 - -class Node2: - data: int - children: List[Node2] - def __init__(self, data: int): - self.data = data - self.children = [] -print Node2(3).data #: 3 - -#%% class_auto_init,barebones -class X[T]: - a: int = 4 - b: int - c: T - d: str = 'oops' - def __str__(self): - return f'X({self.a},{self.b},{self.c},{self.d})' -x = X[float]() -print x #: X(4,0,0,oops) -y = X(c='darius',a=5) -print y #: X(5,0,darius,oops) - -#%% magic,barebones -@tuple -class Foo: - x: int - y: int -a, b = Foo(1, 2), Foo(1, 3) -print a, b #: (x: 1, y: 2) (x: 1, y: 3) -print a.__len__() #: 2 -print a.__hash__(), b.__hash__() #: 175247769363 175247769360 -print a == a, a == b #: True False -print a != a, a != b #: False True -print a < a, a < b, b < a #: False True False -print a <= a, a <= b, b <= a #: True True False -print a > a, a > b, b > a #: False False True -print a >= a, a >= b, b >= a #: True False True -print a.__getitem__(1) #: 2 -print list(a.__iter__()) #: [1, 2] - -#%% magic_class,barebones -@dataclass(eq=True, order=True) -class Foo: - x: int - y: int - def __str__(self): return f'{self.x}_{self.y}' -a, b = Foo(1, 2), Foo(1, 3) -print a, b #: 1_2 1_3 -print a == a, a == b #: True False -print a != a, a != b #: False True -print a < a, a < b, b < a #: False True False -print a <= a, a <= b, b <= a #: True True False -print a > a, a > b, b > a #: False False True -print a >= a, a >= b, b >= a #: True False True - -# Right magic test -class X: - x: int -class Y: - y: int - def __eq__(self, o: X): return self.y == o.x - def __ne__(self, o: X): return self.y != o.x - def __le__(self, o: X): return self.y <= o.x - def __lt__(self, o: X): return self.y < o.x - def __ge__(self, o: X): return self.y >= o.x - def __gt__(self, o: X): return self.y > o.x - def __add__(self, o: X): return self.y + o.x + 1 - def __radd__(self, o: X): return self.y + o.x + 2 -print Y(1) == X(1), Y(1) != X(1) #: True False -print X(1) == Y(1), X(1) != Y(1) #: True False -print Y(1) <= X(2), Y(1) < X(2) #: True True -print X(1) <= Y(2), X(1) < Y(2) #: True True -print Y(1) >= X(2), Y(1) > X(2) #: False False -print X(1) >= Y(2), X(1) > Y(2) #: False False -print X(1) + Y(2) #: 5 -print Y(1) + X(2) #: 4 - - -class A: - def __radd__(self, n: int): - return 0 -def f(): - print('f') - return 1 -def g(): - print('g') - return A() -f() + g() -#: f -#: g - -#%% magic_2,barebones -@tuple -class Foo: - pass -a, b = Foo(), Foo() -print a, b #: () () -print a.__len__() #: 0 -print a.__hash__(), b.__hash__() #: 0 0 -print a == a, a == b #: True True -print a != a, a != b #: False False -print a < a, a < b, b < a #: False False False -print a <= a, a <= b, b <= a #: True True True -print a > a, a > b, b > a #: False False False -print a >= a, a >= b, b >= a #: True True True - -# TODO: pickle / to_py / from_py - -#%% magic_contains,barebones -sponge = (1, 'z', 1.55, 'q', 48556) -print 1.1 in sponge #: False -print 'q' in sponge #: True -print True in sponge #: False - -bob = (1, 2, 3) -print 1.1 in sponge #: False -print 1 in sponge #: True -print 0 in sponge #: False - -#%% magic_err_2,barebones -@tuple -class Foo: - pass -try: - print Foo().__getitem__(1) -except IndexError: - print 'error' #: error - -#%% magic_empty_tuple,barebones -@tuple -class Foo: - pass -print list(Foo().__iter__()) #: [] - -#%% magic_err_4,barebones -@tuple(eq=False) -class Foo: - x: int -Foo(1).__eq__(Foo(1)) #! 'Foo' object has no attribute '__eq__' - -#%% magic_err_5,barebones -@tuple(pickle=False) -class Foo: - x: int -p = Ptr[byte]() -Foo(1).__pickle__(p) #! 'Foo' object has no attribute '__pickle__' - -#%% magic_err_6,barebones -@tuple(container=False) -class Foo: - x: int -Foo(1).__getitem__(0) #! 'Foo' object has no attribute '__getitem__' - -#%% magic_err_7,barebones -@tuple(python=False) -class Foo: - x: int -p = Ptr[byte]() -Foo(1).__to_py__(p) #! 'Foo' object has no attribute '__to_py__' - -#%% python -from python import os -print os.name #: posix - -from python import datetime -z = datetime.datetime.utcfromtimestamp(0) -print z #: 1970-01-01 00:00:00 - -#%% python_numpy -from python import numpy as np -a = np.arange(9).reshape(3, 3) -print a -#: [[0 1 2] -#: [3 4 5] -#: [6 7 8]] -print a.dtype.name #: int64 -print np.transpose(a) -#: [[0 3 6] -#: [1 4 7] -#: [2 5 8]] -n = np.array([[1, 2], [3, 4]]) -print n[0], n[0][0] + 1 #: [1 2] 2 - -a = np.array([1,2,3]) -print(a + 1) #: [2 3 4] -print(a - 1) #: [0 1 2] -print(1 - a) #: [ 0 -1 -2] - -#%% python_import_fn -from python import re.split(str, str) -> List[str] as rs -print rs(r'\W+', 'Words, words, words.') #: ['Words', 'words', 'words', ''] - -#%% python_import_fn_2 -from python import os.system(str) -> int -system("echo 'hello!'") #: hello! - -#%% python_pydef -@python -def test_pydef(n) -> str: - return ''.join(map(str,range(n))) -print test_pydef(5) #: 01234 - -#%% python_pydef_nested -def foo(): - @python - def pyfoo(): - return 1 - print pyfoo() #: 1 - if True: - @python - def pyfoo2(): - return 2 - print pyfoo2() #: 2 - pass - @python - def pyfoo3(): - if 1: - return 3 - return str(pyfoo3()) -print foo() #: 3 - -#%% python_pyobj -@python -def foofn() -> Dict[pyobj, pyobj]: - return {"str": "hai", "int": 1} - -foo = foofn() -print(sorted(foo.items(), key=lambda x: str(x)), foo.__class__.__name__) -#: [('int', 1), ('str', 'hai')] Dict[pyobj,pyobj] -foo["codon"] = 5.15 -print(sorted(foo.items(), key=lambda x: str(x)), foo["codon"].__class__.__name__, foo.__class__.__name__) -#: [('codon', 5.15), ('int', 1), ('str', 'hai')] pyobj Dict[pyobj,pyobj] - -a = {1: "s", 2: "t"} -a[3] = foo["str"] -print(sorted(a.items())) #: [(1, 's'), (2, 't'), (3, 'hai')] - - -#%% python_isinstance -import python - -@python -def foo(): - return 1 - -z = foo() -print(z.__class__.__name__) #: pyobj - -print isinstance(z, pyobj) #: True -print isinstance(z, int) #: False -print isinstance(z, python.int) #: True -print isinstance(z, python.ValueError) #: False - -print isinstance(z, (int, str, python.int)) #: True -print isinstance(z, (int, str, python.AttributeError)) #: False - -try: - foo().x -except python.ValueError: - pass -except python.AttributeError as e: - print('caught', e, e.__class__.__name__) #: caught 'int' object has no attribute 'x' pyobj - - -#%% python_exceptions -import python - -@python -def foo(): - return 1 - -try: - foo().x -except python.AttributeError as f: - print 'py.Att', f #: py.Att 'int' object has no attribute 'x' -except ValueError: - print 'Val' -except PyError as e: - print 'PyError', e -try: - foo().x -except python.ValueError as f: - print 'py.Att', f -except ValueError: - print 'Val' -except PyError as e: - print 'PyError', e #: PyError 'int' object has no attribute 'x' -try: - raise ValueError("ho") -except python.ValueError as f: - print 'py.Att', f -except ValueError: - print 'Val' #: Val -except PyError as e: - print 'PyError', e - - -#%% typeof_definition_error,barebones -a = 1 -class X: - b: type(a) #! cannot use type() in type signatures - -#%% typeof_definition_error_2,barebones -def foo(a)->type(a): pass #! cannot use type() in type signatures - -#%% typeof_definition_error_3,barebones -a=1 -b: type(a) = 1 #! cannot use type() in type signatures - -#%% assign_underscore,barebones -_ = 5 -_ = 's' - -#%% inherit_class_4,barebones -class defdict[K,V](Static[Dict[K,V]]): - fx: Function[[],V] - def __init__(self, d: Dict[K,V], fx: Function[[], V]): - self.__init__() - for k,v in d.items(): self[k] = v - self.fx = fx - def __getitem__(self, key: K) -> V: - if key in self: - return self.values[self.keys.index(key)] - else: - self[key] = self.fx() - return self[key] -z = defdict({'ha':1}, lambda: -1) -print z -print z['he'] -print z -#: {'ha': 1} -#: -1 -#: {'ha': 1, 'he': -1} - -class Foo: - x: int - def foo(self): - return f'foo {self.x}' -class Bar[T]: - y: T - def bar(self): - return f'bar {self.y}/{self.y.__class__.__name__}' -class FooBarBaz[T](Static[Foo], Static[Bar[T]]): - def baz(self): - return f'baz! {self.foo()} {self.bar()}' -print FooBarBaz[str]().foo() #: foo 0 -print FooBarBaz[float]().bar() #: bar 0/float -print FooBarBaz[str]().baz() #: baz! foo 0 bar /str - -#%% inherit_class_err_5,barebones -class defdict(Static[Dict[str,float]]): - def __init__(self, d: Dict[str, float]): - self.__init__(d.items()) -z = defdict() -z[1.1] #! 'defdict' object has no method '__getitem__' with arguments (defdict, float) - -#%% inherit_tuple,barebones -class Foo: - a: int - b: str - def __init__(self, a: int): - self.a, self.b = a, 'yoo' -@tuple -class FooTup(Static[Foo]): pass - -f = Foo(5) -print f.a, f.b #: 5 yoo -fp = FooTup(6, 's') -print fp #: (a: 6, b: 's') - -#%% inherit_class_err_1,barebones -class defdict(Static[Array[int]]): - pass #! reference classes cannot inherit tuple classes - -#%% inherit_class_err_2,barebones -@tuple -class defdict(Static[int]): - pass #! internal classes cannot inherit other classes - -#%% inherit_class_err_3,barebones -class defdict(Static[Dict[int, float, float]]): - pass #! Dict takes 2 generics (3 given) - -#%% inherit_class_err_4,barebones -class Foo: - x: int -class Bar: - x: float -class FooBar(Static[Foo], Static[Bar]): - pass -# right now works as we rename other fields - -#%% keyword_prefix,barebones -def foo(return_, pass_, yield_, break_, continue_, print_, assert_): - return_.append(1) - pass_.append(2) - yield_.append(3) - break_.append(4) - continue_.append(5) - print_.append(6) - assert_.append(7) - return return_, pass_, yield_, break_, continue_, print_, assert_ -print foo([1], [1], [1], [1], [1], [1], [1]) -#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7]) - - -#%% class_deduce,barebones -@deduce -class Foo: - def __init__(self, x): - self.x = [x] - self.y = 1, x - -f = Foo(1) -print(f.x, f.y, f.__class__.__name__) #: [1] (1, 1) Foo[List[int],Tuple[int,int]] - -f: Foo = Foo('s') -print(f.x, f.y, f.__class__.__name__) #: ['s'] (1, 's') Foo[List[str],Tuple[int,str]] - -@deduce -class Bar: - def __init__(self, y): - self.y = Foo(y) - -b = Bar(3.1) -print(b.y.x, b.__class__.__name__) #: [3.1] Bar[Foo[List[float],Tuple[int,float]]] - -#%% multi_error,barebones -a = 55 -print z #! name 'z' is not defined -print(a, q, w) #! name 'q' is not defined -print quit #! name 'quit' is not defined - -#%% class_var,barebones -class Foo: - cx = 15 - x: int = 10 - cy: ClassVar[str] = "ho" - class Bar: - bx = 1.1 -print(Foo.cx) #: 15 -f = Foo() -print(Foo.cy, f.cy) #: ho ho -print(Foo.Bar.bx) #: 1.1 - -Foo.cx = 10 -print(Foo.cx) #: 10 - -def x(): - class Foo: - i = 0 - f = Foo() - def __init__(self): - Foo.i += 1 - def __repr__(self): - return 'heh-cls' - Foo(), Foo(), Foo() - print Foo.f, Foo.i #: heh-cls 4 - return Foo() -f = x() -print f.f, f.i #: heh-cls 5 - -@tuple -class Fot: - f = Fot() - def __repr__(self): - return 'heh-tup' -print Fot.f #: heh-tup - - -#%% dot_access_error,barebones -class Foo: - x: int = 1 -Foo.x #! 'Foo' object has no attribute 'x' - -#%% scoping_same_name,barebones -def match(pattern: str, string: str, flags: int = 0): - pass - -def match(match): - if True: - match = 0 - match - -match(1) - -#%% loop_domination,barebones -for i in range(2): - try: dat = 1 - except: pass - print(dat) -#: 1 -#: 1 - -def comprehension_test(x): - for n in range(3): - print('>', n) - l = ['1', '2', str(x)] - x = [n for n in l] - print(x, n) -comprehension_test(5) -#: > 0 -#: > 1 -#: > 2 -#: ['1', '2', '5'] 2 - - -#%% block_unroll,barebones -# Ensure that block unrolling is done in RAII manner on error -def foo(): - while True: - def magic(a: x): - return - print b -foo() -#! name 'x' is not defined -#! name 'b' is not defined - -#%% capture_recursive,barebones -def f(x: int) -> int: - z = 2 * x - def g(y: int) -> int: - if y == 0: - return 1 - else: - return g(y - 1) * z - return g(4) -print(f(3)) #: 1296 - -#%% class_setter,barebones -class Foo: - _x: int - - @property - def x(self): - print('getter') - return self._x - - @x.setter - def x(self, v): - print('setter') - self._x = v - -f = Foo(1) -print(f.x) -#: getter -#: 1 - -f.x = 99 -print(f.x) -print(f._x) -#: setter -#: getter -#: 99 -#: 99 diff --git a/test/parser/a/__init__.codon b/test/parser/typecheck/a/__init__.codon similarity index 100% rename from test/parser/a/__init__.codon rename to test/parser/typecheck/a/__init__.codon diff --git a/test/parser/a/b/__init__.codon b/test/parser/typecheck/a/b/__init__.codon similarity index 100% rename from test/parser/a/b/__init__.codon rename to test/parser/typecheck/a/b/__init__.codon diff --git a/test/parser/a/b/rec1.codon b/test/parser/typecheck/a/b/rec1.codon similarity index 100% rename from test/parser/a/b/rec1.codon rename to test/parser/typecheck/a/b/rec1.codon diff --git a/test/parser/a/b/rec1_err.codon b/test/parser/typecheck/a/b/rec1_err.codon similarity index 100% rename from test/parser/a/b/rec1_err.codon rename to test/parser/typecheck/a/b/rec1_err.codon diff --git a/test/parser/a/b/rec2.codon b/test/parser/typecheck/a/b/rec2.codon similarity index 100% rename from test/parser/a/b/rec2.codon rename to test/parser/typecheck/a/b/rec2.codon diff --git a/test/parser/a/b/rec2_err.codon b/test/parser/typecheck/a/b/rec2_err.codon similarity index 100% rename from test/parser/a/b/rec2_err.codon rename to test/parser/typecheck/a/b/rec2_err.codon diff --git a/test/parser/a/sub/__init__.codon b/test/parser/typecheck/a/sub/__init__.codon similarity index 100% rename from test/parser/a/sub/__init__.codon rename to test/parser/typecheck/a/sub/__init__.codon diff --git a/test/parser/typecheck/test_access.codon b/test/parser/typecheck/test_access.codon new file mode 100644 index 00000000..b37fa61e --- /dev/null +++ b/test/parser/typecheck/test_access.codon @@ -0,0 +1,580 @@ +#%% __ignore__ +from typing import Optional + +#%% id_fstring_error,barebones +f"a{b + 3}" #! name 'b' is not defined + +#%% id_access,barebones +def foo(): + a = 5 + def bar(): + print(a) + bar() #: 5 + a = 4 + bar() #: 5 + ## TODO: should be 4, needs Cell pointer +foo() + +d = {} +def foo(): + a = 5 + def goo(): + d['x'] = 'y' + print(a) + return goo +foo()() +print(d) +#: 5 +#: {'x': 'y'} + +#%% nonlocal,barebones +def goo(ww): + z = 0 + def foo(x): + f = 10 + def bar(y): + nonlocal z + f = x + y + z += y + print('goo.foo.bar', f, z) + bar(5) + print('goo.foo', f) + return bar + b = foo(10) + print('goo', z) + return b +b = goo('s') +# goo.foo.bar 15 5 +# goo.foo 10 +# goo 5 +b(11) +# goo.foo.bar 21 16 +b(12) +# goo.foo.bar 22 28 +b = goo(1) # test another instantiation +# goo.foo.bar 15 5 +# goo.foo 10 +# goo 5 +b(11) +# goo.foo.bar 21 16 +b(13) +# goo.foo.bar 23 29 + +#%% nonlocal_error,barebones +def goo(): + z = 0 + def foo(): + z += 1 + foo() +goo() #! local variable 'z' referenced before assignment + +#%% new_scoping,barebones +try: + if True and (x := (True or (y := 1 + 2))): + pass + try: + print(x) #: True + print(y) + except NameError: + print("Error") #: Error + print(x) #: True + if len("s") > 0: + print(x) #: True + print(y) + print(y) # TODO: test for __used__ usage + print(y) # (right now manual inspection is needed) +except NameError as e: + print(e) #: name 'y' is not defined + +t = True +y = 0 if t else (xx := 1) +try: + print(xx) +except NameError: + print("Error") #: Error + +def foo(): + if len("s") == 3: + x = 3 + x = 5 ## TODO: MOVE AFTER THE FN ONCE REF-CAPTURES ARE IMPLEMENTED + def bar(y): + print(x + y) + return bar +f = foo() +f(5) #: 10 + +# This should compile. +def rad4f(ido: int, l1: int, cxai): + def CC(a: int, b: int, c: int): + return cxai[a+ido*(b+l1*c)] + for k in range(l1): + # Make sure that cxai[0] assignment does not mark + # cxai as "adding" variable + # See scoping.cpp:visit(DotExpr*) (or IndexExpr*) + tr1, cxai[0] = 1, 1 +rad4f(1, 2, [1, 2]) + +#%% new_scoping_loops_try,barebones +for i in range(10): + pass +print(i) #: 9 + +j = 6 +for j in range(0): + pass +print(j) #: 6 + +for j in range(1): + pass +print(j) #: 0 + +z = 6 +for z in []: + pass +print(z) #: 6 + +for z in [1, 2]: + pass +print(z) #: 2 + +try: + raise ValueError("hi") +except ValueError as e: + ee = e +print(ee) #: hi + +#%% new_scoping_loops_try_error,barebones +try: + pass +except ValueError as f: + pass +try: + print(f.message) #! no module named 'f' +except NameError: + print('error') + +#%% dot_access_error_NOPY,barebones +class Foo: + x: int = 1 +Foo.x #! 'Foo' object has no attribute 'x' + +#%% scoping_same_name,barebones +def match(pattern: str, string: str, flags: int = 0): + pass + +def match(match): + if True: + match = 0 + match + +match(1) + +#%% dot_case_1,barebones +a = [] +print(a[0].loop()) #! 'int' object has no attribute 'loop' +a.append(5) + +#%% dot_case_2_NOPY,barebones +a = Optional(0) +print(a.__bool__()) #: False +print(a.__add__(1)) #: 1 + +#%% dot_case_4_NOPY,barebones +a = [5] +print(a.len) #: 1 + +#%% dot_case_4_err,barebones +a = [5] +a.foo #! 'List[int]' object has no attribute 'foo' + +#%% dot_case_6_NOPY,barebones +# Did heavy changes to this testcase because +# of the automatic optional wraps/unwraps and promotions +class Foo: + def bar(self, a): + print('generic', a, a.__class__.__name__) + def bar(self, a: Optional[float]): + print('optional', a) + def bar(self, a: int): + print('normal', a) +f = Foo() +f.bar(1) #: normal 1 +f.bar(1.1) #: optional 1.1 +f.bar(Optional('s')) #: generic s Optional[str] +# Check static caching +f.bar(Optional('t')) #: generic t Optional[str] +f.bar('hehe') #: generic hehe str + + +#%% dot_case_6b_NOPY,barebones +class Foo: + def bar(self, a, b): + print('1', a, b) + def bar(self, a, b: str): + print('2', a, b) + def bar(self, a: str, b): + print('3', a, b) +f = Foo() +# Take the newest highest scoring method +f.bar('s', 't') #: 3 s t +f.bar(1, 't') #: 2 1 t +f.bar('s', 1) #: 3 s 1 +f.bar(1, 2) #: 1 1 2 + +#%% dot,barebones +class Foo: + def clsmethod(): + print('foo') + def method(self, a): + print(a) +Foo.clsmethod() #: foo +Foo.method(Foo(), 1) #: 1 +m1 = Foo.method +m1(Foo(), 's') #: s +m2 = Foo().method +m2(1.1) #: 1.1 + +#%% dot_error_static,barebones +class Foo: + def clsmethod(): + print('foo') + def method(self, a): + print(a) +Foo().clsmethod() #! clsmethod() takes 0 arguments (1 given) + +#%% nested_class_error,barebones +class X: + def foo(self, x): + return x + class Y: + def bar(self, x): + return x +y = X.Y() +y.foo(1) #! 'X.Y' object has no attribute 'foo' + +#%% nested_deep_class_NOPY,barebones +class A[T]: + a: T + class B[U]: + b: U + class C[V]: + c: V + def foo[W](t: V, u: V, v: V, w: W): + return (t, u, v, w) + +print(A.B.C[bool].foo(W=str, ...).__fn_name__) #: foo[str;bool,bool,bool,str] +print(A.B.C.foo(1,1,1,True)) #: (1, 1, 1, True) +print(A.B.C.foo('x', 'x', 'x', 'x')) #: ('x', 'x', 'x', 'x') +print(A.B.C.foo('x', 'x', 'x', 'x')) #: ('x', 'x', 'x', 'x') +print(A.B.C.foo('x', 'x', 'x', 'x')) #: ('x', 'x', 'x', 'x') + +x = A.B.C[bool](False) +print(x.__class__.__name__) #: A.B.C[bool] + +#%% nested_deep_class_error_NOPY,barebones +class A[T]: + a: T + class B[U]: + b: U + class C[V]: + c: V + def foo[W](t: V, u: V, v: V, w: W): + return (t, u, v, w) + +print A.B.C[str].foo(1,1,1,True) #! 'int' does not match expected type 'str' + +#%% nested_deep_class_error_2_NOPY,barebones +class A[T]: + a: T + class B[U]: + b: U + class C[V]: + c: V + def foo[W](t: V, u: V, v: V, w: W): + return (t, u, v, w) +print A.B[int].C[float].foo(1,1,1,True) #! 'A.B[int]' object has no attribute 'C' + +#%% nested_class_function,barebones +def f(x): + def g(y): + return y + a = g(1) + b = g('s') + c = g(x) + return a, b, c +print f(1.1).__class__.__name__ #: Tuple[int,str,float] +print f(False).__class__.__name__ #: Tuple[int,str,bool] + +class A[T]: + a: T + class B[U]: + b: U + class C[V]: + c: V + def f(x): + def g(y): + return y + a = g(1) + b = g('s') + c = g(x) + return a, b, c +print A.B.C.f(1.1).__class__.__name__ #: Tuple[int,str,float] +print A.B.C[Optional[int]].f(False).__class__.__name__ #: Tuple[int,str,bool] + +#%% rec_class_1_NOPY,barebones +class A: + y: A + def __init__(self): pass # necessary to prevent recursive instantiation! +x = A() +print(x.__class__.__name__, x.y.__class__.__name__) #: A A + +#%% rec_class_2_NOPY,barebones +class A[T]: + a: T + b: A[T] + c: A[str] + def __init__(self): pass +a = A[int]() +print a.__class__.__name__, a.b.__class__.__name__, a.c.__class__.__name__, a.b.b.__class__.__name__, a.b.c.__class__.__name__ +#: A[int] A[int] A[str] A[int] A[str] +print a.c.b.__class__.__name__, a.c.c.__class__.__name__, a.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.__class__.__name__ +#: A[str] A[str] A[int] + +#%% rec_class_3_NOPY,barebones +class X: + x: int + rec: X + def __init__(self): pass + def foo(x: X, y: int): + return y + class Y: + y: int = 0 + def bar(self, y): + print y + return self.y +x, y = X(), X.Y() +print x.__class__.__name__, y.__class__.__name__ +#: X X.Y +print X.foo(x, 4), x.foo(5) +#: 4 5 +print y.bar(1), y.bar('s'), X.Y.bar(y, True) +#: 1 +#: s +#: True +#: 0 0 0 + +#%% rec_class_4_NOPY,barebones +class A[T]: + a: T + b: A[T] + c: A[str] + def __init__(self): pass +class B[T]: + a: T + b: A[T] + c: B[T] + def __init__(self): pass + class Nest1[U]: + n: U + class Nest2[T, U]: + m: T + n: U +b = B[float]() +print b.__class__.__name__, b.a.__class__.__name__, b.b.__class__.__name__, b.c.__class__.__name__, b.c.b.c.a.__class__.__name__ +#: B[float] float A[float] B[float] str + +n1 = B.Nest1[int](0) +print n1.n, n1.__class__.__name__, n1.n.__class__.__name__ #: 0 B.Nest1[int] int + +n1: B.Nest2 = B.Nest2[float, int](0, 0) +print (n1.m, n1.n), n1.__class__.__name__, n1.m.__class__.__name__, n1.n.__class__.__name__ #: (0, 0) B.Nest2[float,int] float int + +#%% class_fn_access_NOPY,barebones +class X[T]: + def foo[U](self, x: T, y: U): + return (x+x, y+y) +y = X[X[int]]() +print y.__class__.__name__ #: X[X[int]] +print X[float].foo(U=int, ...).__fn_name__ #: foo[int;X[float],float,int] +print X[int]().foo(1, 's') #: (2, 'ss') + +#%% class_partial_access_NOPY,barebones +class X[T]: + def foo[U](self, x, y: U): + return (x+x, y+y) +y = X[X[int]]() +# TODO: should this even be the case? +# print y.foo(U=float,...).__class__.__name__ -> X.foo[X[X[int]],...,...] +print y.foo(1, 2.2, float) #: (2, 4.4) + +#%% fn_overloads_NOPY,barebones +def foo(x): + return 1, x + +print(foo('')) #: (1, '') + +@overload +def foo(x, y): + def foo(x, y): + return f'{x}_{y}' + return 2, foo(x, y) + +@overload +def foo(x): + if x == '': + return 3, 0 + return 3, 1 + foo(x[1:])[1] + +print foo('hi') #: (3, 2) +print foo('hi', 1) #: (2, 'hi_1') + +def fox(a: int, b: int, c: int, dtype: type = int): + print('fox 1:', a, b, c) + +@overload +def fox(a: int, b: int, dtype: type = int): + print('fox 2:', a, b, dtype.__class__.__name__) + +fox(1, 2, float) +#: fox 2: 1 2 float +fox(1, 2) +#: fox 2: 1 2 int +fox(1, 2, 3) +#: fox 1: 1 2 3 + +# Test whether recursive self references override overloads (they shouldn't) + +def arange(start: int, stop: int, step: int): + return (start, stop, step) + +@overload +def arange(stop: int): + return arange(0, stop, 1) + +print(arange(0, 1, 2)) +#: (0, 1, 2) +print(arange(12)) +#: (0, 12, 1) + + +#%% fn_shadow,barebones +def foo(x): + return 1, x +print(foo('hi')) #: (1, 'hi') + +def foo(x): + return 2, x +print(foo('hi')) #: (2, 'hi') + +#%% fn_overloads_error_NOPY,barebones +def foo(x): + return 1, x +@overload +def foo(x, y): + return 2, x, y +foo('hooooooooy!', 1, 2) +#! no function 'foo' with arguments (str, int, int) + +#%% fn_overloads_dispatch +import math +print(math.sqrt(4.0)) #: 2 + +#%% generator_capture_nonglobal,barebones +# Issue #49 +def foo(iter): + print(iter.__class__.__name__, list(iter)) + +for x in range(2): + foo(1 for _ in range(x)) +#: Generator[int] [] +#: Generator[int] [1] +for x in range(2): + for y in range(x): + foo('z' for _ in range(y)) +#: Generator[str] [] + +#%% nonlocal_capture_loop,barebones +# Issue #51 +def kernel(fn): + def wrapper(*args, grid, block): + print(grid, block, fn(*args)) + return wrapper +def test_mandelbrot(): + MAX = 10 # maximum Mandelbrot iterations + N = 2 # width and height of image + pixels = [0 for _ in range(N)] + def scale(x, a, b): + return a + (x/N)*(b - a) + @kernel + def k(pixels): + i = 0 + while i < MAX: i += 1 # this is needed for test to make sense + return (MAX, N, pixels, scale(N, -2, 0.4)) + k(pixels, grid=(N*N)//1024, block=1024) +test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4) + +#%% id_shadow_overload_call,barebones +def foo(): + def bar(): + return -1 + def xo(): + return bar() + @overload # w/o this this fails because xo cannot capture bar + def bar(a): + return a + bar(1) +foo() + +#%% domination_nested,barebones +def correlate(a, b, mode = 'valid'): + if mode == 'valid': + if isinstance(a, List): + xret = '1' + else: + xret = '2' + for i in a: + for j in b: + xret += 'z' + elif mode == 'same': + if isinstance(a, List): + xret = '3' + else: + xret = '4' + for i in a: + for j in b: + xret += 'z' + elif mode == 'full': + if isinstance(a, List): + xret = '5' + else: + xret = '6' + for i in a: + for j in b: + xret += 'z' + else: + raise ValueError(f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})") + return xret +print(correlate([1], [2], 'full')) #: 5z + +def foo(x, y): + a = 5 + if isinstance(a, int): + if staticlen(y) == 0: + a = 0 + elif staticlen(y) == 1: + a = 1 + else: + for i in range(10): + a = 40 + return a + return a +print foo(5, (1, 2, 3)) #: 40 + +#%% nontype_name,barebones +# Fix #357 +class Foo: + def goo(self): + print(self.__name__) +Foo().goo() +#! 'Foo' object has no attribute '__name__' +#! during the realization of goo diff --git a/test/parser/typecheck/test_assign.codon b/test/parser/typecheck/test_assign.codon new file mode 100644 index 00000000..c710e67a --- /dev/null +++ b/test/parser/typecheck/test_assign.codon @@ -0,0 +1,404 @@ +#%% __ignore__ +from typing import Optional, List, Dict, Generator +from dataclasses import dataclass + +#%% basic,barebones +a = 5 +b: float = 6.16 +c: Optional[str] = None +print(a, b, c) #: 5 6.16 None + +#%% walrus,barebones +def foo(x): + return x * x +if x := foo(3): + pass +if (x := foo(4)) and False: + print('Nope') +if False and (x := foo(5)): + print('Nope') +print(x) #: 16 + +a = [y := foo(1), y+1, y+2] +print(a) #: [1, 2, 3] + +print({y: b for y in [1,2,3] if (b := (y - 1))}) #: {2: 1, 3: 2} +print(list(b for y in [1,2,3] if (b := (y // 3)))) #: [1] + +#%% walrus_update,barebones +def foo(x): + return x * x +x = 5 +if x := foo(4): + pass +print(x) #: 16 + +#%% walrus_cond_1,barebones +def foo(x): + return x * x +if False or (x := foo(4)): + pass +print(x) #: 16 + +y = (z := foo(5)) if True else 0 +print(z) #: 25 + +#%% walrus_err,barebones +def foo(x): + return x * x +if False and (x := foo(4)): + pass +try: + print(x) +except NameError: + print("Error") #: Error + +t = True +y = 0 if t else (z := foo(4)) +try: + print(z) +except NameError: + print("Error") #: Error + +#%% unpack_specials,barebones +x, = 1, +print(x) #: 1 + +a = (2, 3) +b = (1, *a[1:]) +print(a, b) #: (2, 3) (1, 3) + +#%% assign,barebones +a = 1 +print(a) #: 1 +a = 2 +print(a) #: 2 + +x, y = 1, 2 +print(x, y) #: 1 2 +(x, y) = (3, 4) +print(x, y) #: 3 4 +x, y = (1, 2) +print(x, y) #: 1 2 +(x, y) = 3, 4 +print(x, y) #: 3 4 +(x, y) = [3, 4] +print(x, y) #: 3 4 +[x, y] = [1, 2] +print(x, y) #: 1 2 +[x, y] = (4, 3) +print(x, y) #: 4 3 + +l = list(iter(range(10))) +[a, b, *lx, c, d] = l +print(a, b, lx, c, d) #: 0 1 [2, 3, 4, 5, 6, 7] 8 9 +a, b, *lx = l +print(a, b, lx) #: 0 1 [2, 3, 4, 5, 6, 7, 8, 9] +*lx, a, b = l +print(lx, a, b) #: [0, 1, 2, 3, 4, 5, 6, 7] 8 9 +*xz, a, b = (1, 2, 3, 4, 5) +print(xz, a, b) #: (1, 2, 3) 4 5 +(*ex,) = [1, 2, 3] +print(ex) #: [1, 2, 3] + +#%% assign_str,barebones +sa, sb = 'XY' +print(sa, sb) #: X Y +(sa, sb), sc = 'XY', 'Z' +print(sa, sb, sc) #: X Y Z +sa, *la = 'X' +print(sa, la, 1) #: X 1 +sa, *la = 'XYZ' +print(sa, la) #: X YZ +(xa,xb), *xc, xd = [1,2],'this' +print(xa, xb, xc, xd) #: 1 2 () this +(a, b), (sc, *sl) = [1,2], 'this' +print(a, b, sc, sl) #: 1 2 t his + +#%% assign_index_dot,barebones +class Foo: + a: int = 0 + def __setitem__(self, i: int, t: int): + self.a += i * t +f = Foo() +f.a = 5 +print(f.a) #: 5 +f[3] = 5 +print(f.a) #: 20 +f[1] = -8 +print(f.a) #: 12 + +def foo(): + print('foo') + return 0 +v = [0] +v[foo()] += 1 +#: foo +print(v) +#: [1] + +#%% assign_err_1,barebones +a, *b, c, *d = 1,2,3,4,5 #! multiple starred expressions in assignment + +#%% assign_err_2_NOPY,barebones +a = [1, 2, 3] +a[1]: int = 3 #! syntax error, unexpected ':' + +#%% assign_err_3,barebones +a = 5 +a.x: int = 3 #! syntax error, unexpected ':' + +#%% assign_err_4,barebones +*x = range(5) #! cannot assign to given expression + +#%% assign_err_5_NOPY,barebones +# TODO in Python, this is a ValueError +try: + (sa, sb), sc = 'XYZ' +except IndexError: + print("assign failed") #: assign failed + +#%% assign_comprehension,barebones +g = ((b, a, c) for a, *b, c in ['ABC','DEEEEF','FHGIJ']) +x, *q, y = list(g) # TODO: auto-unroll as in Python +print(x, y, q) #: ('B', 'A', 'C') ('HGI', 'F', 'J') [('EEEE', 'D', 'F')] + +#%% assign_shadow,barebones +a = 5 +print(a) #: 5 +a : str = 's' +print(a) #: s + +#%% assign_err_must_exist,barebones +a = 1 +def foo(): + a += 2 #! local variable 'a' referenced before assignment +foo() + +#%% assign_rename,barebones +y = int +z = y(5) +print(z) #: 5 + +def foo(x): return x + 1 +x = foo +print(x(1)) #: 2 + +#%% assign_err_6,barebones +x = bar #! name 'bar' is not defined + +#%% assign_err_7,barebones +foo() += bar #! cannot assign to given expression + +#%% assign_update_eq,barebones +a = 5 +a += 3 +print(a) #: 8 +a -= 1 +print(a) #: 7 + +@dataclass +class Foo: + a: int + def __add__(self, i: int): + print('add!') + return Foo(self.a + i) + def __iadd__(self, i: int): + print('iadd!') + self.a += i + return self + def __str__(self): + return str(self.a) +f = Foo(3) +print(f + 2) #: add! +#: 5 +f += 6 #: iadd! +print(f) #: 9 + +#%% del,barebones +a = 5 +del a +print(a) #! name 'a' is not defined + +#%% del_index,barebones +y = [1, 2] +del y[0] +print(y) #: [2] + +#%% del_error,barebones +a = [1] +del a.ptr #! cannot delete given expression + +#%% assign_underscore,barebones +_ = 5 +_ = 's' + +#%% assign_optional_NOPY,barebones +a = None +print(a) #: None +a = 5 +print(a) #: 5 + +b: Optional[float] = Optional[float](6.5) +c: Optional[float] = 5.5 +print(b, c) #: 6.5 5.5 + +#%% assign_type_alias,barebones +I = int +print(I(5)) #: 5 + +L = dict[int, str] +l = L() +print(l) #: {} +l[5] = 'haha' +print(l) #: {5: 'haha'} + +#%% assign_type_annotation,barebones +a: List[int] = [] +print(a) #: [] + +#%% assign_type_err,barebones +a = 5 +if 1: + a = 3.3 #! 'float' does not match expected type 'int' +a + +#%% assign_atomic_NOPY,barebones +i = 1 +f = 1.1 + +@llvm +def xchg(d: Ptr[int], b: int) -> None: + %tmp = atomicrmw xchg i64* %d, i64 %b seq_cst + ret {} {} +@llvm +def aadd(d: Ptr[int], b: int) -> int: + %tmp = atomicrmw add i64* %d, i64 %b seq_cst + ret i64 %tmp +@llvm +def amin(d: Ptr[int], b: int) -> int: + %tmp = atomicrmw min i64* %d, i64 %b seq_cst + ret i64 %tmp +@llvm +def amax(d: Ptr[int], b: int) -> int: + %tmp = atomicrmw max i64* %d, i64 %b seq_cst + ret i64 %tmp +def min(a, b): return a if a < b else b +def max(a, b): return a if a > b else b + +@extend +class int: + def __atomic_xchg__(self: Ptr[int], i: int): + print('atomic:', self[0], '<-', i) + xchg(self, i) + def __atomic_add__(self: Ptr[int], i: int): + print('atomic:', self[0], '+=', i) + return aadd(self, i) + def __atomic_min__(self: Ptr[int], b: int): + print('atomic:', self[0], '?=', b) + return amax(self, b) + +@atomic +def foo(x): + global i, f + + i += 1 #: atomic: 1 += 1 + print(i) #: 2 + i //= 2 #: atomic: 2 <- 1 + print(i) #: 1 + i = 3 #: atomic: 1 <- 3 + print(i) #: 3 + i = min(i, 10) #: atomic: 3 ?= 10 + print(i) #: 10 + i = max(20, i) #: atomic: 10 <- 20 + print(i) #: 20 + + f += 1.1 + f = 3.3 + f = max(f, 5.5) +foo(1) +print(i, f) #: 20 5.5 + +#%% assign_atomic_real_NOPY +i = 1 +f = 1.1 +@atomic +def foo(x): + global i, f + + i += 1 + print(i) #: 2 + i //= 2 + print(i) #: 1 + i = 3 + print(i) #: 3 + i = min(i, 10) + print(i) #: 3 + i = max(i, 10) + print(i) #: 10 + + f += 1.1 + f = 3.3 + f = max(f, 5.5) +foo(1) +print(i, f) #: 10 5.5 + +#%% assign_member_NOPY,barebones +class Foo: + x: Optional[int] = None +f = Foo() +print(f.x) #: None +f.x = 5 +print(f.x) #: 5 + +fo = Optional(Foo()) +fo.x = 6 +print(fo.x) #: 6 + +#%% assign_member_err_1_NOPY,barebones +class Foo: + x: Optional[int] = None +Foo().y = 5 #! 'Foo' object has no attribute 'y' + +#%% assign_member_err_2_NOPY,barebones +@tuple +class Foo: + x: Optional[int] = None +Foo().x = 5 #! cannot modify tuple attributes + +#%% assign_wrappers_NOPY,barebones +a = 1.5 +print(a) #: 1.5 +if 1: + a = 1 +print(a, a.__class__.__name__) #: 1 float + +a: Optional[int] = None +if 1: + a = 5 +print(a.__class__.__name__, a) #: Optional[int] 5 + +b = 5 +c: Optional[int] = 6 +if 1: + b = c +print(b.__class__.__name__, c.__class__.__name__, b, c) #: int Optional[int] 6 6 + +z: Generator[int] = [1, 2] +print(z.__class__.__name__) #: Generator[int] + +zx: float = 1 +print(zx.__class__.__name__, zx) #: float 1 + +def test(v: Optional[int]): + v: int = v if v is not None else 3 + print(v.__class__.__name__) +test(5) #: int +test(None) #: int + +# %% diff --git a/test/parser/typecheck/test_basic.codon b/test/parser/typecheck/test_basic.codon new file mode 100644 index 00000000..c7212772 --- /dev/null +++ b/test/parser/typecheck/test_basic.codon @@ -0,0 +1,120 @@ +#%% none,barebones +a = None +print(a.__class__.__name__, a) #: Optional[int] None +if True: a = 5 # wrap with `if`` to avoid shadowing +print(a.__class__.__name__, a) #: Optional[int] 5 + +#%% none_unbound,barebones +a = None +print(a.__class__.__name__, a) #: Optional[NoneType] None + +#%% bool,barebones +print(True, False) #: True False +a = True +print(a.__class__.__name__, a) #: bool True + +#%% int,barebones +i = 15 +print(i.__class__.__name__, i) #: int 15 +print(0b0000_1111) #: 15 +print(0B101) #: 5 +print(3) #: 3 +print(18_446_744_073_709_551_000) #: -616 +print(0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111) #: -1 +print(0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111u) #: 18446744073709551615 +print(18_446_744_073_709_551_000u) #: 18446744073709551000 +print(65i7) #: -63 +print(-1u7) #: 127 + +#%% int_suffix,barebones +@extend +class int: + def __suffix_test__(s): + return 'TEST: ' + str(s) +print(123_456test) #: TEST: 123456 + +#%% int_large,barebones +print(1844674407_3709551999) #: 383 +print(1844674407_3709551999i256) #: 18446744073709551999 + +#%% float,barebones +f = 1.11 +print(f.__class__.__name__, f) #: float 1.11 +print(5.15) #: 5.15 +print(2e2) #: 200 +print(2.e-2) #: 0.02 +print 1_000.0 #: 1000 +print 1_000e9 #: 1e+12 + +#%% float_suffix,barebones +@extend +class float: + def __suffix_zoo__(x): + return str(x) + '_zoo' +print(1.2e-1zoo) #: 0.12_zoo + +#%% string,barebones +a = 'hi' +print(a.__class__.__name__, a) #: str hi +print('kthxbai', "kthxbai") #: kthxbai kthxbai +print("""hi +hello""", '''hai +hallo''') +#: hi +#: hello hai +#: hallo + +#%% fstring,barebones +a, b = 1, 2 +print(f"string {a}") #: string 1 +print(F"{b} string") #: 2 string +print(f"str {a+b} end") #: str 3 end +print(f"str {a+b=}") #: str a+b=3 +c = f'and this is {a} followed by {b}' +print(c, f'{b}{a}', f'. {1+a=} .. {b} ...') +#: and this is 1 followed by 2 21 . 1+a=2 .. 2 ... + +#%% fstring_error_1,barebones +f"a{1 + 3}}" #! single '}' is not allowed in f-string + +#%% fstring_error_2,barebones +f"a{{1 + 3}" #! expecting '}' in f-string + +#%% string_prefix,barebones +@extend +class str: + def __prefix_pfx__(s: str, N: Static[int]): + return 'PFX ' + s +print(pfx'HELLO') #: PFX HELLO + +@extend +class str: + def __prefix_pxf__(s: str, N: Static[int]): + return 'PXF ' + s + " " + str(N) +print(pxf'HELLO') #: PXF HELLO 5 + +#%% string_raw,barebones +print('a\\b') #: a\b +print(r'a\tb') #: a\tb +print(R'\n\r\t\\') #: \n\r\t\\ + +#%% string_format +a = 'xyz' +print(f"{a:>10}") +#: xyz +print(f"{a!r:>10}") +#: 'xyz' +print(f"{a=!r:>10}") +#: a= 'xyz' +print(f"{a=}") +#: a=xyz +print(f"{a=:>10}") +#: a= xyz +print(f"{a!r}") +#: 'xyz' +print(f'{1000000=:,}') +#: 1000000=1,000,000 +print(f"{'':=<30}") +#: ============================== +print(f'{1000000:,}') +#: 1,000,000 diff --git a/test/parser/typecheck/test_call.codon b/test/parser/typecheck/test_call.codon new file mode 100644 index 00000000..377c35b0 --- /dev/null +++ b/test/parser/typecheck/test_call.codon @@ -0,0 +1,1010 @@ + +#%% call_ptr,barebones +v = 5 +p = __ptr__(v) +print p[0] #: 5 + +#%% call_ptr_error,barebones +__ptr__(1) #! __ptr__() only takes identifiers as arguments + +#%% call_ptr_error_3,barebones +v = 1 +__ptr__(v, 1) #! __ptr__() takes 1 arguments (2 given) + +#%% call_array,barebones +a = __array__[int](2) +a[0] = a[1] = 5 +print a[0], a[1] #: 5 5 + +#%% call_array_error,barebones +a = __array__[int](2, 3) #! __new__() takes 1 arguments (2 given) + +#%% call_err_1,barebones +seq_print(1, name="56", 2) #! positional argument follows keyword argument + +#%% call_err_2,barebones +x = (1, 2) +seq_print(1, name=*x) #! syntax error, unexpected '*' + +#%% call_err_3,barebones +x = (1, 2) +seq_print(1, name=**x) #! syntax error, unexpected '*' + +#%% call_collections +from collections import namedtuple as nt + +ee = nt('Foo', ('x', 'y')) +f = ee(1, 2) +print f #: (x: 1, y: 2) + +ee = nt('FooX', (('x', str), 'y')) +fd = ee('s', 2) +print fd #: (x: 's', y: 2) + +#%% call_partial_functools +from functools import partial +def foo(x, y, z): + print x,y,z +f1 = partial(foo, 1, z=3) +f1(2) #: 1 2 3 +f2 = partial(foo, y=2) +f2(1, 2) #: 1 2 2 + +#%% call,barebones +def foo(a, b, c='hi'): + print 'foo', a, b, c + return 1 +class Foo: + def __init__(self): + print 'Foo.__init__' + def foo(self, a): + print 'Foo.foo', a + return 's' + def bar[T](self, a: T): + print 'Foo.bar', a + return a.__class__.__name__ + def __call__(self, y): + print 'Foo.__call__' + return foo(2, y) + +foo(1, 2.2, True) #: foo 1 2.2 True +foo(1, 2.2) #: foo 1 2.2 hi +foo(b=2.2, a=1) #: foo 1 2.2 hi +foo(b=2.2, c=12u, a=1) #: foo 1 2.2 12 + +f = Foo() #: Foo.__init__ +print f.foo(a=5) #: Foo.foo 5 +#: s +print f.bar(a=1, T=int) #: Foo.bar 1 +#: int +print Foo.bar(Foo(), 1.1, T=float) #: Foo.__init__ +#: Foo.bar 1.1 +#: float +print Foo.bar(Foo(), 's') #: Foo.__init__ +#: Foo.bar s +#: str +print f('hahaha') #: Foo.__call__ +#: foo 2 hahaha hi +#: 1 + +@tuple +class Moo: + moo: int + def __new__(i: int) -> Moo: + print 'Moo.__new__' + return superf(i) +print Moo(1) #: Moo.__new__ +#: (moo: 1) + +#%% call_err_6,barebones +seq_print_full(1, name="56", name=2) #! keyword argument repeated: name + +#%% call_partial,barebones +def foo(i, j, k): + return i + j + k +print foo(1.1, 2.2, 3.3) #: 6.6 +p = foo(6, ...) +print p.__class__.__name__ #: foo[int,...,...] +print p(2, 1) #: 9 +print p(k=3, j=6) #: 15 +q = p(k=1, ...) +print q(3) #: 10 +qq = q(2, ...) +print qq() #: 9 +# +add_two = foo(3, k=-1, ...) +print add_two(42) #: 44 +print 3 |> foo(1, 2) #: 6 +print 42 |> add_two #: 44 +# +def moo(a, b, c=3): + print a, b, c +m = moo(b=2, ...) +print m.__class__.__name__ #: moo[...,int,...] +m('s', 1.1) #: s 2 1.1 +# # +n = m(c=2.2, ...) +print n.__class__.__name__ #: moo[...,int,float] +n('x') #: x 2 2.2 +print n('y').__class__.__name__ #: NoneType + +def ff(a, b, c): + return a, b, c +print ff(1.1, 2, True).__class__.__name__ #: Tuple[float,int,bool] +print ff(1.1, ...)(2, True).__class__.__name__ #: Tuple[float,int,bool] +y = ff(1.1, ...)(c=True, ...) +print y.__class__.__name__ #: ff[float,...,bool] +print ff(1.1, ...)(2, ...)(True).__class__.__name__ #: Tuple[float,int,bool] +print y('hei').__class__.__name__ #: Tuple[float,str,bool] +z = ff(1.1, ...)(c='s', ...) +print z.__class__.__name__ #: ff[float,...,str] + +def fx(*args, **kw): + print(args, kw) +f1 = fx(1, x=1, ...) +f2 = f1(2, y=2, ...) +f3 = f2(3, z=3, ...) +f3() +#: (1, 2, 3) (x: 1, y: 2, z: 3) + +#%% call_arguments_partial,barebones +def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T): + print R.__class__.__name__, T.__class__.__name__ + print a.__class__.__name__[:8], b.__class__.__name__ + for i in b: + print a(i) + print c, c.__class__.__name__ + print d, d.__class__.__name__ + +l = [1, 2, 3] +doo(b=l, d=Optional(5), c=l[0], a=lambda x: x+1) +#: int int +#: %_lambda Generator[int] +#: 2 +#: 3 +#: 4 +#: 1 Optional[int] +#: 5 int + +l = [1] +def adder(a, b): return a+b +doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...)) +#: int int +#: adder[.. Generator[int] +#: 5 +#: 1 Optional[int] +#: 5 int + +#%% call_partial_star,barebones +def foo(x, *args, **kwargs): + print x, args, kwargs +p = foo(...) +p(1, z=5) #: 1 () (z: 5) +p('s', zh=65) #: s () (zh: 65) +q = p(zh=43, ...) +q(1) #: 1 () (zh: 43) +r = q(5, 38, ...) +r() #: 5 (38,) (zh: 43) +r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1) + +#%% call_args_kwargs_type,barebones +def foo(*args: float, **kwargs: int): + print(args, kwargs, args.__class__.__name__) + +foo(1, f=1) #: (1,) (f: 1) Tuple[float] +foo(1, 2.1, 3, z=2) #: (1, 2.1, 3) (z: 2) Tuple[float,float,float] + +def sum(x: Generator[int]): + a = 0 + for i in x: + a += i + return a + +def sum_gens(*x: Generator[int]) -> int: + a = 0 + for i in x: + a += sum(i) + return a +print sum_gens([1, 2, 3]) #: 6 +print sum_gens({1, 2, 3}) #: 6 +print sum_gens(iter([1, 2, 3])) #: 6 + +#%% call_kwargs,barebones +def kwhatever(**kwargs): + print 'k', kwargs +def whatever(*args): + print 'a', args +def foo(a, b, c=1, *args, **kwargs): + print a, b, c, args, kwargs + whatever(a, b, *args, c) + kwhatever(x=1, **kwargs) +foo(1, 2, 3, 4, 5, arg1='s', kwa=2) +#: 1 2 3 (4, 5) (arg1: 's', kwa: 2) +#: a (1, 2, 4, 5, 3) +#: k (arg1: 's', kwa: 2, x: 1) +foo(1, 2) +#: 1 2 1 () () +#: a (1, 2, 1) +#: k (x: 1) +foo(1, 2, 3) +#: 1 2 3 () () +#: a (1, 2, 3) +#: k (x: 1) +foo(1, 2, 3, 4) +#: 1 2 3 (4,) () +#: a (1, 2, 4, 3) +#: k (x: 1) +foo(1, 2, zamboni=3) +#: 1 2 1 () (zamboni: 3) +#: a (1, 2, 1) +#: k (x: 1, zamboni: 3) + +#%% call_unpack,barebones +def foo(*args, **kwargs): + print args, kwargs + +@tuple +class Foo: + x: int = 5 + y: bool = True + +t = (1, 's') +f = Foo(6) +foo(*t, **f) #: (1, 's') (x: 6, y: True) +foo(*(1,2)) #: (1, 2) () +foo(3, f) #: (3, (x: 6, y: True)) () +foo(k = 3, **f) #: () (k: 3, x: 6, y: True) + +#%% call_partial_args_kwargs,barebones +def foo(*args): + print(args) +a = foo(1, 2, ...) +b = a(3, 4, ...) +c = b(5, ...) +c('zooooo') +#: (1, 2, 3, 4, 5, 'zooooo') + +def fox(*args, **kwargs): + print(args, kwargs) +xa = fox(1, 2, x=5, ...) +xb = xa(3, 4, q=6, ...) +xc = xb(5, ...) +xd = xc(z=5.1, ...) +xd('zooooo', w='lele') +#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele') + +class Foo: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a): + return f'{self}:generic' + def foo(self, a: float): + return f'{self}:float' + def foo(self, a: int): + return f'{self}:int' +f = Foo(4) + +def pacman(x, f): + print f(x, '5') + print f(x, 2.1) + print f(x, 4) +pacman(f, Foo.foo) +#: #4:generic +#: #4:float +#: #4:int + +def macman(f): + print f('5') + print f(2.1) + print f(4) +macman(f.foo) +#: #4:generic +#: #4:float +#: #4:int + +class Fox: + i: int + def __str__(self): + return f'#{self.i}' + def foo(self, a, b): + return f'{self}:generic b={b}' + def foo(self, a: float, c): + return f'{self}:float, c={c}' + def foo(self, a: int): + return f'{self}:int' + def foo(self, a: int, z, q): + return f'{self}:int z={z} q={q}' +ff = Fox(5) +def maxman(f): + print f('5', b=1) + print f(2.1, 3) + print f(4) + print f(5, 1, q=3) +maxman(ff.foo) +#: #5:generic b=1 +#: #5:float, c=3 +#: #5:int +#: #5:int z=1 q=3 + + +#%% call_static,barebones +print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool) +#: True True False +print isinstance((1, 2), Tuple), isinstance((1, 2), Tuple[int, int]), isinstance((1, 2), Tuple[float, int]) +#: True True False +print isinstance([1, 2], List), isinstance([1, 2], List[int]), isinstance([1, 2], List[float]) +#: True True False +print isinstance({1, 2}, List), isinstance({1, 2}, Set[float]) +#: False False +print isinstance(Optional(5), Optional[int]), isinstance(Optional(), Optional) +#: True True +print isinstance(Optional(), Optional[int]), isinstance(Optional('s'), Optional[int]) +#: False False +print isinstance(None, Optional), isinstance(None, Optional[int]) +#: True False +print isinstance(None, Optional[NoneType]) +#: True +print isinstance({1, 2}, List) +#: False + +print staticlen((1, 2, 3)), staticlen((1, )), staticlen('hehe') +#: 3 1 4 + +print hasattr([1, 2], "__getitem__") +#: True +print hasattr(type([1, 2]), "__getitem__") +#: True +print hasattr(int, "__getitem__") +#: False +print hasattr([1, 2], "__getitem__") #: True +print hasattr([1, 2], "__getitem__", int) #: True +print hasattr([1, 2], "__getitem__", str) #: False +print hasattr([1, 2], "__getitem__", idx=int) #: True +print hasattr([1, 2], "__getitem__", idx=str) #: False + + +#%% isinstance_inheritance,barebones +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a +class Side: + def __init__(self): + pass +class BX[T,U](Static[AX[T]], Static[Side]): + b: U + def __init__(self, a: T, b: U): + super().__init__(a) + self.b = b +class CX[T,U](Static[BX[T,U]]): + c: int + def __init__(self, a: T, b: U): + super().__init__(a, b) + self.c = 1 +c = CX('a', False) +print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side) +#: True True True True +print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int]) +#: True False False + +#%% staticlen_err,barebones +print staticlen([1, 2]) #! expected tuple type + +#%% compile_error,barebones +compile_error("woo-hoo") #! woo-hoo + +#%% stack_alloc,barebones +a = __array__[int](2) +print a.__class__.__name__ #: Array[int] + +#%% typeof,barebones +a = 5 +z = [] +z.append(6) +print z.__class__.__name__, z, type(1.1).__class__.__name__ #: List[int] [6] float + +#%% ptr,barebones +v = 5 +c = __ptr__(v) +print c.__class__.__name__ #: Ptr[int] + +#%% tuple_fn,barebones +@tuple +class unpackable_plain: + a: int + b: str + +u = unpackable_plain(1, 'str') +a, b = tuple(u) +print a, b #: 1 str + +@tuple +class unpackable_gen: + a: int + b: T + T: type + +u2 = unpackable_gen(1, 'str') +a2, b2 = tuple(u2) +print a2,b2 #: 1 str + +class plain: + a: int + b: str + +c = plain(3, 'heh') +z = tuple(c) +print z, z.__class__.__name__ #: (3, 'heh') Tuple[int,str] + +#%% super,barebones +class A[T]: + a: T + def __init__(self, t: T): + self.a = t + def foo(self): + return f'A:{self.a}' +class B(Static[A[str]]): + b: int + def __init__(self): + super().__init__('s') + self.b = 6 + def baz(self): + return f'{super().foo()}::{self.b}' +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + +class AX[T]: + a: T + def __init__(self, a: T): + self.a = a + def foo(self): + return f'[AX:{self.a}]' +class BX[T,U](Static[AX[T]]): + b: U + def __init__(self, a: T, b: U): + print super().__class__.__name__ + super().__init__(a) + self.b = b + def foo(self): + return f'[BX:{super().foo()}:{self.b}]' +class CX[T,U](Static[BX[T,U]]): + c: int + def __init__(self, a: T, b: U): + print super().__class__.__name__ + super().__init__(a, b) + self.c = 1 + def foo(self): + return f'CX:{super().foo()}:{self.c}' +c = CX('a', False) +print c.__class__.__name__, c.foo() +#: BX[str,bool] +#: AX[str] +#: CX[str,bool] CX:[BX:[AX:a]:False]:1 + +#%% super_vtable_2 +class Base: + def test(self): + print('base.test') +class A(Base): + def test(self): + super().test() + Base.test(self) + print('a.test') +a = A() +a.test() +def moo(x: Base): + x.test() +moo(a) +Base.test(a) +#: base.test +#: base.test +#: a.test +#: base.test +#: base.test +#: a.test +#: base.test + +#%% super_tuple,barebones +@tuple +class A[T]: + a: T + x: int + def __new__(a: T) -> A[T]: + return A[T](a, 1) + def foo(self): + return f'A:{self.a}' +@tuple +class B(Static[A[str]]): + b: int + def __new__() -> B: + return B(*(A('s')), 6) + def baz(self): + return f'{super().foo()}::{self.b}' + +b = B() +print b.foo() #: A:s +print b.baz() #: A:s::6 + + +#%% super_error,barebones +class A: + def __init__(self): + super().__init__() +a = A() +#! no super methods found +#! during the realization of __init__(self: A) + +#%% super_error_2,barebones +super().foo(1) #! no super methods found + +#%% superf,barebones +class Foo: + def foo(a): + # superf(a) + print 'foo-1', a + def foo(a: int): + superf(a) + print 'foo-2', a + def foo(a: str): + superf(a) + print 'foo-3', a + def foo(a): + superf(a) + print 'foo-4', a +Foo.foo(1) +#: foo-1 1 +#: foo-2 1 +#: foo-4 1 + +class Bear: + def woof(x): + return f'bear woof {x}' +@extend +class Bear: + def woof(x): + return superf(x) + f' bear w--f {x}' +print Bear.woof('!') +#: bear woof ! bear w--f ! + +class PolarBear(Static[Bear]): + def woof(): + return 'polar ' + superf('@') +print PolarBear.woof() +#: polar bear woof @ bear w--f @ + +#%% superf_error,barebones +class Foo: + def foo(a): + superf(a) + print 'foo-1', a +Foo.foo(1) +#! no superf methods found +#! during the realization of foo(a: int) + +#%% static_getitem +print Int[staticlen("ee")].__class__.__name__ #: Int[2] + +y = [1, 2] +print getattr(y, "len") #: 2 +print y.len #: 2 +getattr(y, 'append')(1) +print y #: [1, 2, 1] + +@extend +class Dict: + def __getitem2__(self, attr: Static[str]): + if hasattr(self, attr): + return getattr(self, attr) + else: + return self[attr] + def __getitem1__(self, attr: Static[int]): + return self[attr] + +d = {'s': 3.19} +print d.__getitem2__('_upper_bound') #: 3 +print d.__getitem2__('s') #: 3.19 +e = {1: 3.33} +print e.__getitem1__(1) #: 3.33 + +#%% forward,barebones +def foo(f, x): + f(x, type(x)) + print f.__class__.__name__ +def bar[T](x): + print x, T.__class__.__name__ +foo(bar, 1) +#: 1 int +#: bar[...] +foo(bar(...), 's') +#: s str +#: bar[...] +z = bar +z('s', int) +#: s int +z(1, T=str) +#: 1 str + +zz = bar(T=int,...) +zz(1) +#: 1 int + +#%% forward_error,barebones +def foo(f, x): + f(x, type(x)) + print f.__class__.__name__ +def bar[T](x): + print x, T.__class__.__name__ +foo(bar(T=int,...), 1) +#! bar() takes 2 arguments (2 given) +#! during the realization of foo(f: bar[...], x: int) +# TODO fix this error message + +#%% sort_partial +def foo(x, y): + return y**x +print sorted([1,2,3,4,5], key=foo(y=2, ...)) +print sorted([1,2,3,4,5], key=foo(y=-2, ...)) +#: [1, 2, 3, 4, 5] +#: [5, 3, 1, 2, 4] + +#%% type_loc,barebones +a = 1 +T = type(a) +print T.__class__.__name__ #: int + +#%% methodcaller,barebones +def foo(): + def bar(a, b): + print 'bar', a, b + return bar +foo()(1, 2) #: bar 1 2 + +def methodcaller(foo: Static[str]): + def caller(foo: Static[str], obj, *args, **kwargs): + if isinstance(getattr(obj, foo)(*args, **kwargs), None): + getattr(obj, foo)(*args, **kwargs) + else: + return getattr(obj, foo)(*args, **kwargs) + return caller(foo=foo, ...) +v = [1] +methodcaller('append')(v, 42) +print v #: [1, 42] +print methodcaller('index')(v, 42) #: 1 + +#%% constructor_passing +class A: + s: str + def __init__(self, x): + self.s = str(x)[::-1] + def __lt__(self, o): return self.s < o.s + def __eq__(self, o): return self.s == o.s + def __ge__(self, o): return self.s >= o.s +foo = [1,2,11,30] +print(sorted(foo, key=str)) +#: [1, 11, 2, 30] +print(sorted(foo, key=A)) +#: [30, 1, 11, 2] + +@tuple +class AT: + s: str + def __new__(i: int) -> AT: return AT(str(i)) +print(sorted(foo, key=AT)) +#: [1, 11, 2, 30] + +#%% polymorphism,barebones +class A: + a: int + def foo(self, a: int): return (f'A({self.a})', a) + def bar(self): return 'A.bar' + def aaz(self): return 'A.aaz' +class B(A): + b: int + def foo(self, a): return (f'B({self.a},{self.b})', a + self.b) + def bar(self): return 'B.bar' + def baz(self): return 'B.baz' +class M[T]: + m: T + def moo(self): return (f'M_{T.__class__.__name__}', self.m) +class X(B,M[int]): + def foo(self, a): return (f'X({self.a},{self.b},{self.m})', a + self.b + self.m) + def bar(self): return 'X.bar' + +def foo(i): + x = i.foo(1) + y = i.bar() + z = i.aaz() + print(*x, y, z) +a = A(1) +l = [a, B(2,3), X(2,3,-1)] +for i in l: foo(i) +#: A(1) 1 A.bar A.aaz +#: B(2,3) 4 B.bar A.aaz +#: X(2,3,-1) 3 X.bar A.aaz + +def moo(m: M): + print(m.moo()) +moo(M[float](5.5)) +moo(X(1,2,3)) +#: ('M_float', 5.5) +#: ('M_int', 3) + + +class A[T]: + def __init__(self): + print("init A", T.__class__.__name__) +class Ho: + def __init__(self): + print("init Ho") +# TODO: this throws and error: B[U](U) +class B[U](A[U], Ho): + def __init__(self): + super().__init__() + print("init B", U.__class__.__name__) +B[Ho]() +#: init A Ho +#: init B Ho + + +class Vehicle: + def drive(self): + return "I'm driving a vehicle" + +class Car(Vehicle): + def drive(self): + return "I'm driving a car" + +class Truck(Vehicle): + def drive(self): + return "I'm driving a truck" + +class SUV(Car, Truck): + def drive(self): + return "I'm driving an SUV" + +suv = SUV() +def moo(s): + print(s.drive()) +moo(suv) +moo(Truck()) +moo(Car()) +moo(Vehicle()) +#: I'm driving an SUV +#: I'm driving a truck +#: I'm driving a car +#: I'm driving a vehicle + + +#%% polymorphism_error_1,barebones +class M[T]: + m: T +class X(M[int]): + pass +l = [M[float](1.1), X(2)] +#! 'X' does not match expected type 'M[float]' + +#%% polymorphism_2 +class Expr: + def __init__(self): + pass + def eval(self): + raise ValueError('invalid expr') + return 0.0 + def __str__(self): + return "Expr" +class Const(Expr): + x: float + def __init__(self, x): + self.x=x + def __str__(self): + return f"{self.x}" + def eval(self): + return self.x +class Add(Expr): + lhs: Expr + rhs: Expr + def __init__(self, lhs, rhs): + self.lhs=lhs + self.rhs=rhs + # print(f'ctr: {self}') + def eval(self): + return self.lhs.eval()+self.rhs.eval() + def __str__(self): + return f"({self.lhs}) + ({self.rhs})" +class Mul(Expr): + lhs: Expr + rhs: Expr + def __init__(self, lhs, rhs): + self.lhs=lhs + self.rhs=rhs + def eval(self): + return self.lhs.eval()*self.rhs.eval() + def __str__(self): + return f"({self.lhs}) * ({self.rhs})" + +c1 = Const(5) +c2 = Const(4) +m = Add(c1, c2) +c3 = Const(2) +a : Expr = Mul(m, c3) +print(f'{a} = {a.eval()}') +#: ((5) + (4)) * (2) = 18 + +from random import random, seed +seed(137) +def random_expr(depth) -> Expr: + if depth<=0: + return Const(int(random()*42.0)) + else: + lhs=random_expr(depth-1) + rhs=random_expr(depth-1) + ctorid = int(random()*3) + if ctorid==0: + return Mul(lhs,rhs) + else: + return Add(lhs,rhs) +for i in range(11): + print(random_expr(i).eval()) +#: 17 +#: 71 +#: 1760 +#: 118440 +#: 94442 +#: 8.02435e+15 +#: 1.07463e+13 +#: 1.43017e+19 +#: 2.40292e+34 +#: 6.1307e+28 +#: 5.16611e+49 + +#%% polymorphism_3 +import operator + +class Expr: + def eval(self): + return 0 + +class Const(Expr): + value: int + + def __init__(self, value): + self.value = value + + def eval(self): + return self.value + +class BinOp(Expr): + lhs: Expr + rhs: Expr + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def eval_from_fn(self, fn): + return fn(self.lhs.eval(), self.rhs.eval()) + +class Add(BinOp): + def eval(self): + return self.eval_from_fn(operator.add) + +class Sub(BinOp): + def eval(self): + return self.eval_from_fn(operator.sub) + +class Mul(BinOp): + def eval(self): + return self.eval_from_fn(operator.mul) + +class Div(BinOp): + def eval(self): + return self.eval_from_fn(operator.floordiv) + +# TODO: remove Expr requirement +expr : Expr = Mul(Const(3), Add(Const(10), Const(5))) +print(expr.eval()) #: 45 + +#%% polymorphism_4 +class A(object): + a: int + def __init__(self, a: int): + self.a = a + + def test_a(self, n: int): + print("test_a:A", n) + + def test(self, n: int): + print("test:A", n) + + def test2(self, n: int): + print("test2:A", n) + +class B(A): + b: int + def __init__(self, a: int, b: int): + super().__init__(a) + self.b = b + + def test(self, n: int): + print("test:B", n) + + def test2(self, n: int): + print("test2:B", n) + +class C(B): + pass + +b = B(1, 2) +b.test_a(1) +b.test(1) +#: test_a:A 1 +#: test:B 1 + +a: A = b +a.test(1) +a.test2(2) +#: test:B 1 +#: test2:B 2 + + + +class AX(object): + value: u64 + + def __init__(self): + print('init/AX') + self.value = 15u64 + + def get_value(self) -> u64: + return self.value + +class BX(object): + a: AX + def __init__(self): + print('init/BX') + self.a = AX() + def hai(self): + return f"hai/BX: {self.a.value}" + +class CX(BX): + def __init__(self): + print('init/CX') + super().__init__() + + def getsuper(self): + return super() + + def test(self): + print('test/CX:', self.a.value) + return self.a.get_value() + + def hai(self): + return f"hai/CX: {self.a.value}" + +table = CX() +#: init/CX +#: init/BX +#: init/AX +print table.test() +#: test/CX: 15 +#: 15 + +s = table.getsuper() +print(s.hai()) +#: hai/BX: 15 +s.a.value += 1u64 +print(s.hai()) +#: hai/BX: 16 +table.a.value += 1u64 +print(s.hai()) +#: hai/BX: 17 +table.test() +#: test/CX: 17 + +c: List[BX] = [s, table] +print(c[0].hai()) #: hai/BX: 17 +print(c[1].hai()) #: hai/CX: 17 + + diff --git a/test/parser/typecheck/test_class.codon b/test/parser/typecheck/test_class.codon new file mode 100644 index 00000000..b2a80e18 --- /dev/null +++ b/test/parser/typecheck/test_class.codon @@ -0,0 +1,513 @@ +#%% class_err_1,barebones +@extend +@foo +class Foo: + pass +#! cannot combine '@extend' with other attributes or decorators + +#%% class_extend_typedef,barebones +size_t = i32 +@extend +class size_t: + def foo(self): + return f'Int{N}.foo.{self}' + +print size_t(1).foo() #: Int32.foo.1 +print Int[64](2).foo() #: Int64.foo.2 + +#%% class_err_2,barebones +def foo(): + @extend + class Foo: + pass +foo() +#! class extension must be a top-level statement +#! during the realization of foo() + +#%% class_nested,barebones +class Foo: + foo: int + class Bar: + bar: int + b: Optional[Foo.Bar] # TODO: allow this ONLY in type annotations + c: Optional[int] + class Moo: + # TODO: allow nested class reference to the upclass + # x: Foo.Bar + x: int +y = Foo(1) +z = Foo.Bar(2, None, 4) +m = Foo.Bar.Moo(5) +print y.foo #: 1 +print z.bar, z.b.__bool__(), z.c, m.x #: 2 False 4 5 + +#%% class_nested_2,barebones +@tuple +class Foo: + @tuple + class Bar: + x: int + x: int + b: Bar + c: Foo.Bar +f = Foo(5, Foo.Bar(6), Foo.Bar(7)) +print(f) #: (x: 5, b: (x: 6), c: (x: 7)) + +#%% class_nested_err,barebones +class Foo: + class Bar: + b: Ptr[Bar] +#! name 'Bar' is not defined + +#%% class_err_4,barebones +@extend +class Foo: + pass +#! class name 'Foo' is not defined + +#%% class_err_5,barebones +class Foo[T, U]: + pass +@extend +class Foo[T]: + pass +#! class extensions cannot define data attributes and generics or inherit other classes + +#%% class_err_7,barebones +class Foo: + a: int + a: int +#! duplicate data attribute 'a' in class definition + +#%% class_err_tuple_no_recursive,barebones +@tuple +class Foo: + a: Foo +#! name 'Foo' is not defined + +#%% class_err_8,barebones +class Foo: + while 0: pass +#! unexpected expression in class definition + +#%% class_err_9,barebones +class F[T: Static[float]]: + pass +#! expected 'int', 'bool' or 'str' + +#%% class_err_11,barebones +def foo(x): + class A: + def bar(): + print x +foo(1) +#! name 'x' cannot be captured +#! during the realization + +#%% class_err_12,barebones +def foo(x): + T = type(x) + class A: + def bar(): + print T() + A.bar() +foo(1) +#! name 'T' cannot be captured +#! during the realization of foo + +#%% recursive_class,barebones +class Node[T]: + data: T + children: List[Node[T]] + def __init__(self, data: T): + self.data = data + self.children = [] +print Node(2).data #: 2 + +class Node2: + data: int + children: List[Node2] + def __init__(self, data: int): + self.data = data + self.children = [] +print Node2(3).data #: 3 + +#%% class_auto_init,barebones +class X[T]: + a: int = 4 + b: int = 0 + c: T = T() + d: str = 'oops' + def __str__(self): + return f'X({self.a},{self.b},{self.c},{self.d})' +x = X[float]() +print x #: X(4,0,0,oops) +y = X(c='darius',a=5) +print y #: X(5,0,darius,oops) + +#%% magic,barebones +@tuple +class Foo: + x: int + y: int +a, b = Foo(1, 2), Foo(1, 3) +print a, b #: (x: 1, y: 2) (x: 1, y: 3) +print a.__len__() #: 2 +print a.__hash__(), b.__hash__() #: 175247769363 175247769360 +print a == a, a == b #: True False +print a != a, a != b #: False True +print a < a, a < b, b < a #: False True False +print a <= a, a <= b, b <= a #: True True False +print a > a, a > b, b > a #: False False True +print a >= a, a >= b, b >= a #: True False True +print a.__getitem__(1) #: 2 +print list(a.__iter__()) #: [1, 2] + +#%% magic_class,barebones +@dataclass(eq=True, order=True) +class Foo: + x: int + y: int + def __str__(self): return f'{self.x}_{self.y}' +a, b = Foo(1, 2), Foo(1, 3) +print a, b #: 1_2 1_3 +print a == a, a == b #: True False +print a != a, a != b #: False True +print a < a, a < b, b < a #: False True False +print a <= a, a <= b, b <= a #: True True False +print a > a, a > b, b > a #: False False True +print a >= a, a >= b, b >= a #: True False True + +# Right magic test +class X: + x: int +class Y: + y: int + def __eq__(self, o: X): return self.y == o.x + def __ne__(self, o: X): return self.y != o.x + def __le__(self, o: X): return self.y <= o.x + def __lt__(self, o: X): return self.y < o.x + def __ge__(self, o: X): return self.y >= o.x + def __gt__(self, o: X): return self.y > o.x + def __add__(self, o: X): return self.y + o.x + 1 + def __radd__(self, o: X): return self.y + o.x + 2 +print Y(1) == X(1), Y(1) != X(1) #: True False +print X(1) == Y(1), X(1) != Y(1) #: True False +print Y(1) <= X(2), Y(1) < X(2) #: True True +print X(1) <= Y(2), X(1) < Y(2) #: True True +print Y(1) >= X(2), Y(1) > X(2) #: False False +print X(1) >= Y(2), X(1) > Y(2) #: False False +print X(1) + Y(2) #: 5 +print Y(1) + X(2) #: 4 + + +class A: + def __radd__(self, n: int): + return 0 +def f(): + print('f') + return 1 +def g(): + print('g') + return A() +f() + g() +#: f +#: g + +#%% magic_2,barebones +@tuple +class Foo: + pass +a, b = Foo(), Foo() +print a, b #: () () +print a.__len__() #: 0 +print a.__hash__(), b.__hash__() #: 0 0 +print a == a, a == b #: True True +print a != a, a != b #: False False +print a < a, a < b, b < a #: False False False +print a <= a, a <= b, b <= a #: True True True +print a > a, a > b, b > a #: False False False +print a >= a, a >= b, b >= a #: True True True + +# TODO: pickle / to_py / from_py + +#%% magic_contains,barebones +sponge = (1, 'z', 1.55, 'q', 48556) +print 1.1 in sponge #: False +print 'q' in sponge #: True +print True in sponge #: False + +bob = (1, 2, 3) +print 1.1 in sponge #: False +print 1 in sponge #: True +print 0 in sponge #: False + +#%% magic_err_2,barebones +@tuple +class Foo: + pass +try: + print Foo().__getitem__(1) +except IndexError: + print 'error' #: error + +#%% magic_empty_tuple,barebones +@tuple +class Foo: + pass +print list(Foo().__iter__()) #: [] + +#%% magic_err_4,barebones +@tuple(eq=False) +class Foo: + x: int +Foo(1).__eq__(Foo(1)) #! 'Foo' object has no attribute '__eq__' + +#%% magic_err_5,barebones +@tuple(pickle=False) +class Foo: + x: int +p = Ptr[byte]() +Foo(1).__pickle__(p) #! 'Foo' object has no attribute '__pickle__' + +#%% magic_err_6,barebones +@tuple(container=False) +class Foo: + x: int +Foo(1).__getitem__(0) #! 'Foo' object has no attribute '__getitem__' + +#%% magic_err_7,barebones +@tuple(python=False) +class Foo: + x: int +p = Ptr[byte]() +Foo(1).__to_py__(p) #! 'Foo' object has no attribute '__to_py__' + +#%% inherit_class_4,barebones +class defdict[K,V](Static[Dict[K,V]]): + fx: Function[[],V] + def __init__(self, d: Dict[K,V], fx: Function[[], V]): + self.__init__() + for k,v in d.items(): self[k] = v + self.fx = fx + def __getitem__(self, key: K) -> V: + if key in self: + return self.values[self.keys.index(key)] + else: + self[key] = self.fx() + return self[key] +z = defdict({'ha':1}, lambda: -1) +print z +print z['he'] +print z +#: {'ha': 1} +#: -1 +#: {'ha': 1, 'he': -1} + +class Foo: + x: int = 0 + def foo(self): + return f'foo {self.x}' +class Bar[T]: + y: T = T() + def bar(self): + return f'bar {self.y}/{self.y.__class__.__name__}' +class FooBarBaz[T](Static[Foo], Static[Bar[T]]): + def baz(self): + return f'baz! {self.foo()} {self.bar()}' +print FooBarBaz[str]().foo() #: foo 0 +print FooBarBaz[float]().bar() #: bar 0/float +print FooBarBaz[str]().baz() #: baz! foo 0 bar /str + +#%% inherit_class_err_5,barebones +class defdict(Static[Dict[str,float]]): + def __init__(self, d: Dict[str, float]): + self.__init__(d.items()) +z = defdict() +z[1.1] #! 'float' does not match expected type 'str' + +#%% inherit_tuple,barebones +class Foo: + a: int + b: str + def __init__(self, a: int): + self.a, self.b = a, 'yoo' +@tuple +class FooTup(Static[Foo]): pass + +f = Foo(5) +print f.a, f.b #: 5 yoo +fp = FooTup(6, 's') +print fp #: (a: 6, b: 's') + +#%% inherit_class_err_1,barebones +class defdict(Static[Array[int]]): + pass #! reference classes cannot inherit tuple classes + +#%% inherit_class_err_2,barebones +@tuple +class defdict(Static[int]): + pass #! internal classes cannot inherit other classes + +#%% inherit_class_err_3,barebones +class defdict(Static[Dict[int, float, float]]): + pass #! Dict takes 2 generics (3 given) + +#%% inherit_class_err_4,barebones +class Foo: + x: int +class Bar: + x: float +class FooBar(Static[Foo], Static[Bar]): + pass +# right now works as we rename other fields + + +#%% class_deduce,barebones +@deduce +class Foo: + def __init__(self, x): + self.x = [x] + self.y = 1, x + +f = Foo(1) +print(f.x, f.y, f.__class__.__name__) #: [1] (1, 1) Foo[List[int],Tuple[int,int]] + +f: Foo = Foo('s') +print(f.x, f.y, f.__class__.__name__) #: ['s'] (1, 's') Foo[List[str],Tuple[int,str]] + +@deduce +class Bar: + def __init__(self, y: float): + self.y = Foo(y) + def __init__(self, y: str): + self.x = Foo(y) + +b = Bar(3.1) +print(b.x.__class__.__name__, b.y.__class__.__name__, b.y.x, b.__class__.__name__) +#: NoneType Foo[List[float],Tuple[int,float]] [3.1] Bar[NoneType,Foo[List[float],Tuple[int,float]]] +b = Bar('3.1') +print(b.x.__class__.__name__, b.y.__class__.__name__, b.x.x, b.__class__.__name__) +#: Foo[List[str],Tuple[int,str]] NoneType ['3.1'] Bar[Foo[List[str],Tuple[int,str]],NoneType] + +#%% class_var,barebones +class Foo: + cx = 15 + x: int = 10 + cy: ClassVar[str] = "ho" + class Bar: + bx = 1.1 +print(Foo.cx) #: 15 +f = Foo() +print(Foo.cy, f.cy) #: ho ho +print(Foo.Bar.bx) #: 1.1 + +Foo.cx = 10 +print(Foo.cx) #: 10 + +def x(): + class Foo: + i = 0 + f = Foo() + def __init__(self): + Foo.i += 1 + def __repr__(self): + return 'heh-cls' + Foo(), Foo(), Foo() + print Foo.f, Foo.i #: heh-cls 4 + return Foo() +f = x() +print f.f, f.i #: heh-cls 5 + +@tuple +class Fot: + f = Fot() + def __repr__(self): + return 'heh-tup' +print Fot.f #: heh-tup + +#%% extend,barebones +@extend +class int: + def run_lola_run(self): + while self > 0: + yield self + self -= 1 +print list((5).run_lola_run()) #: [5, 4, 3, 2, 1] + +#%% staticmethod,barebones +class Foo: + def __repr__(self): + return 'Foo' + def m(self): + print 'm', self + @staticmethod + def sm(i): + print 'sm', i +Foo.sm(1) #: sm 1 +Foo().sm(2) #: sm 2 +Foo().m() #: m Foo + +#%% class_setter,barebones +class Foo: + _x: int + + @property + def x(self): + print('getter') + return self._x + + @x.setter + def x(self, v): + print('setter') + self._x = v + +f = Foo(1) +print(f.x) +#: getter +#: 1 + +f.x = 99 +print(f.x) +print(f._x) +#: setter +#: getter +#: 99 +#: 99 + +#%% inherit_surrounding,barebones +# Fix 354 +class A: + pass +class B: + class C(B): pass +#! nested classes cannot inherit surrounding classes + +#%% inherit_no_member_middle,barebones +# Fix #532 +class A: + _map: Dict[str, str] + def __init__(self): + self._map = Dict[str, str]() +class B(A): + def __init__(self): + super().__init__() +class C(B): + placeholder: str + def __init__(self): + super().__init__() +test = C() + +#%% inherit_optional,barebones +# Fix 554 +class A: + pass +class B(A): + pass +def foo(val: Optional[A]): + if val: + print("A") + else: + print("None[A]") +foo(A()) #: A +foo(None) #: None[A] \ No newline at end of file diff --git a/test/parser/typecheck/test_collections.codon b/test/parser/typecheck/test_collections.codon new file mode 100644 index 00000000..5327f836 --- /dev/null +++ b/test/parser/typecheck/test_collections.codon @@ -0,0 +1,174 @@ +#%% list_unbound,barebones +a = [] +#! cannot typecheck +#! cannot typecheck + +#%% star_err,barebones +a = (1, 2, 3) +z = *a #! unexpected star expression + +#%% list,barebones +a = [4, 5, 6] +print a #: [4, 5, 6] +b = [1, 2, 3, *a] +print b #: [1, 2, 3, 4, 5, 6] + +#%% set,barebones +gs = {1.12} +print gs #: {1.12} +fs = {1, 2, 3, 1, 2, 3} +gs.add(1.12) +gs.add(1.13) +print fs, gs #: {1, 2, 3} {1.12, 1.13} +print {*fs, 5, *fs} #: {1, 2, 3, 5} + +#%% dict,barebones +gd = {1: 'jedan', 2: 'dva', 2: 'two', 3: 'tri'} +fd = {} +fd['jedan'] = 1 +fd['dva'] = 2 +print gd, fd #: {1: 'jedan', 2: 'two', 3: 'tri'} {'jedan': 1, 'dva': 2} + + + +#%% comprehension,barebones +l = [(i, j, f'i{i}/{j}') + for i in range(50) if i % 2 == 0 if i % 3 == 0 + for j in range(2) if j == 1] +print l #: [(0, 1, 'i0/1'), (6, 1, 'i6/1'), (12, 1, 'i12/1'), (18, 1, 'i18/1'), (24, 1, 'i24/1'), (30, 1, 'i30/1'), (36, 1, 'i36/1'), (42, 1, 'i42/1'), (48, 1, 'i48/1')] + +s = {i%3 for i in range(20)} +print s #: {0, 1, 2} + +d = {i: j for i in range(10) if i < 1 for j in range(10)} +print d #: {0: 9} + +t = 's' +x = {t: lambda x: x * t for t in range(5)} +print(x[3](10)) #: 40 +print(t) #: s + +#%% comprehension_opt,barebones +@extend +class List: + def __init__(self, cap: int): + print 'optimize', cap + self.arr = Array[T](cap) + self.len = 0 +def foo(): + yield 0 + yield 1 + yield 2 +print [i for i in range(3)] #: optimize 3 +#: [0, 1, 2] +print [i for i in foo()] #: [0, 1, 2] +print [i for i in range(3) if i%2 == 0] #: [0, 2] +print [i + j for i in range(1) for j in range(1)] #: [0] +print {i for i in range(3)} #: {0, 1, 2} + +#%% generator,barebones +z = 3 +g = (e for e in range(20) if e % z == 1) +print str(g)[:13] #: = a >= -5) #: True False + +#%% if_expr,barebones +c = 5 +a = 1 if c < 5 else 2 +b = -(1 if c else 2) +print a, b #: 2 -1 + + + +#%% range_err,barebones +1 ... 3 #! unexpected range expression + +#%% match +def foo(x): + match x: + case 1: + print 'int' + case 2 ... 10: + print 'range' + case 'ACGT': + print 'string' + case (a, 1): + print 'tuple_wild', a + case []: + print 'list' + case [[]]: + print 'list list' + case [1, 2]: + print 'list 2' + case [1, z, ...] if z < 5: + print 'list 3', z + case [1, _, ..., zz] | (1, zz): + print 'list 4', zz + case (1 ... 10, s := ('ACGT', 1 ... 4)): + print 'complex', s + case _: + print 'else' +foo(1) #: int +foo(5) #: range +foo('ACGT') #: string +foo((9, 1)) #: tuple_wild 9 +foo(List[int]()) #: list +foo([List[int]()]) #: list list +foo([1, 2]) #: list 2 +foo([1, 3]) #: list 3 3 +foo([1, 5]) #: else +foo([1, 5, 10]) #: list 4 10 +foo((1, 33)) #: list 4 33 +foo((9, ('ACGT', 3))) #: complex ('ACGT', 3) +foo(range(10)) #: else + +for op in 'MI=DXSN': + match op: + case 'M' | '=' | 'X': + print('case 1') + case 'I' or 'S': + print('case 2') + case _: + print('case 3') +#: case 1 +#: case 2 +#: case 1 +#: case 3 +#: case 1 +#: case 2 +#: case 3 + +#%% match_err_1,barebones +match [1, 2]: + case [1, ..., 2, ..., 3]: pass +#! multiple ellipses in a pattern + +#%% if_expr_2,barebones +y = 1 if True else 2 +print y.__class__.__name__ #: int + +a = None +b = 5 +z = a if bool(True) else b # needs bool to prevent static evaluation +print z, z.__class__.__name__ #: None Optional[int] + +zz = 1.11 if True else None +print zz, zz.__class__.__name__ #: 1.11 float + +#%% if,barebones +for a, b in [(1, 2), (3, 3), (5, 4)]: + if a > b: + print '1', + elif a == b: + print '=', + else: + print '2', +print '_' #: 2 = 1 _ + +if 1: + print '1' #: 1 + +#%% static_if,barebones +def foo(x, N: Static[int]): + if isinstance(x, int): + return x + 1 + elif isinstance(x, float): + return x.__pow__(.5) + elif isinstance(x, Tuple[int, str]): + return f'foo: {x[1]}' + elif isinstance(x, Tuple) and (N >= 3 or staticlen(x) > 2): + return x[2:] + elif hasattr(x, '__len__'): + return 'len ' + str(x.__len__()) + else: + compile_error('invalid type') +print foo(N=1, x=1) #: 2 +print foo(N=1, x=2.0) #: 1.41421 +print foo(N=1, x=(1, 'bar')) #: foo: bar +print foo(N=1, x=(1, 2)) #: len 2 +print foo(N=3, x=(1, 2)) #: () +print foo(N=1, x=(1, 2, 3)) #: (3,) + diff --git a/test/parser/typecheck/test_ctx.codon b/test/parser/typecheck/test_ctx.codon new file mode 100644 index 00000000..e69de29b diff --git a/test/parser/typecheck/test_error.codon b/test/parser/typecheck/test_error.codon new file mode 100644 index 00000000..d6d59381 --- /dev/null +++ b/test/parser/typecheck/test_error.codon @@ -0,0 +1,104 @@ +#%% assert,barebones +assert True +assert True, "blah" + +try: + assert False +except AssertionError as e: + print e.message[:15], e.message[-19:] #: Assert failed ( test_error.codon:6) + +try: + assert False, f"hehe {1}" +except AssertionError as e: + print e.message[:23], e.message[-20:] #: Assert failed: hehe 1 ( test_error.codon:11) + +#%% try_throw,barebones +class MyError(Static[Exception]): + def __init__(self, message: str): + super().__init__('MyError', message) +try: + raise MyError("hello!") +except MyError as e: + print str(e) #: hello! +try: + raise OSError("hello os!") +# TODO: except (MyError, OSError) as e: +# print str(e) +except MyError: + print "my" +except OSError as o: + print "os", o.typename, len(o.message), o.file[-16:], o.line + #: os OSError 9 test_error.codon 24 +finally: + print "whoa" #: whoa + +# Test function name +def foo(): + raise MyError("foo!") +try: + foo() +except MyError as e: + print e.typename, e.message #: MyError foo! + +#%% throw_error,barebones +raise 'hello' +#! exceptions must derive from BaseException + +#%% raise_from,barebones +def foo(bar): + try: + bar() + except ValueError as e: + raise RuntimeError("oops") from e + raise RuntimeError("oops") + +def bar1(): + raise ValueError("bar1") +try: + foo(bar1) +except RuntimeError as e: + print(e.message, e.__cause__) #: oops bar1 + +def bar2(): + raise ValueError("bar2") +try: + foo(bar2) +except RuntimeError as e: + print(e.message, e.__cause__) #: oops bar2 + +def bar3(): + pass +try: + foo(bar3) +except RuntimeError as e: + print(e.message, e.__cause__) #: oops None + +#%% try_else,barebones +def div(x, y): + if y == 0: raise ZeroDivisionError("oops!") + return x // y +def divide(x: int, y: int): + try: + result = div(x, y) + except ZeroDivisionError: + print("ZeroDivisionError") + else: + print(result) + finally: + print('Done!') + + try: + result = div(x, y) + except ZeroDivisionError: + print("ZeroDivisionError") + else: + print(result) + +divide(3, 2) +#: 1 +#: Done! +#: 1 +divide(3, 0) +#: ZeroDivisionError +#: Done! +#: ZeroDivisionError diff --git a/test/parser/typecheck/test_function.codon b/test/parser/typecheck/test_function.codon new file mode 100644 index 00000000..74aa90dd --- /dev/null +++ b/test/parser/typecheck/test_function.codon @@ -0,0 +1,690 @@ + +#%% lambda,barebones +l = lambda a, b: a + b +print l(1, 2) #: 3 + +e = 5 +lp = lambda x: x + e +print lp(1) #: 6 + +e = 7 +print lp(2) #: 9 + +def foo[T](a: T, l: Callable[[T], T]): + return l(a) +print foo(4, lp) #: 11 + +def foox(a, l): + return l(a) +print foox(4, lp) #: 11 + +# Fix 216 +g = lambda a, L=List[int]() : (L.append(a), L)[1] +print(g(1)) +#: [1] +g = lambda a, b=1, *s, **kw: ((a,b,*s),kw) +print(g('hey!', c=3)) +#: (('hey!', 1), (c: 3)) +print(g('hey!', 2, 3, 4, zz=3)) +#: (('hey!', 2, 3, 4), (zz: 3)) + +#%% nested_lambda,barebones +def foo(): + print list(a*a for a in range(3)) +foo() #: [0, 1, 4] + +#%% yieldexpr,barebones +def mysum(start): + m = start + while True: + a = (yield) + print a.__class__.__name__ #: int + if a == -1: + break + m += a + yield m +iadder = mysum(0) +next(iadder) +for i in range(10): + iadder.send(i) +#: int +#: int +#: int +#: int +#: int +#: int +#: int +#: int +#: int +#: int +print iadder.send(-1) #: 45 + +#%% return,barebones +def foo(): + return 1 +print foo() #: 1 + +def bar(): + print 2 + return + print 1 +bar() #: 2 + +#%% yield,barebones +def foo(): + yield 1 +print [i for i in foo()], str(foo())[:16] #: [1] foo! ' + str(self.i) + def __exit__(self: Foo): + print '< foo! ' + str(self.i) + def foo(self: Foo): + print 'woof' +class Bar: + s: str + def __enter__(self: Bar): + print '> bar! ' + self.s + def __exit__(self: Bar): + print '< bar! ' + self.s + def bar(self: Bar): + print 'meow' +with Foo(0) as f: +#: > foo! 0 + f.foo() #: woof +#: < foo! 0 +with Foo(1) as f, Bar('s') as b: +#: > foo! 1 +#: > bar! s + f.foo() #: woof + b.bar() #: meow +#: < bar! s +#: < foo! 1 +with Foo(2), Bar('t') as q: +#: > foo! 2 +#: > bar! t + print 'eeh' #: eeh + q.bar() #: meow +#: < bar! t +#: < foo! 2 + + +#%% function_err_0,barebones +def foo(a, b, a): + pass #! duplicate argument 'a' in function definition + +#%% function_err_0b,barebones +def foo(a, b=1, c): + pass #! non-default argument 'c' follows default argument + +#%% function_err_0b_ok,barebones +def foo(a, b=1, *c): + pass + +#%% function_err_0c,barebones +def foo(a, b=1, *c, *d): + pass #! multiple star arguments provided + +#%% function_err_0e,barebones +def foo(a, b=1, *c = 1): + pass #! star arguments cannot have default values + +#%% function_err_0f,barebones +def foo(a, b=1, **c, **kwargs): + pass #! kwargs must be the last argument + +#%% function_err_0h,barebones +def foo(a, b=1, **c = 1): + pass #! star arguments cannot have default values + +#%% function_err_0i,barebones +def foo(a, **c, d): + pass #! kwargs must be the last argument + +#%% function_err_1,barebones +def foo(): + @__force__ + def bar(): pass +foo() +#! builtin function must be a top-level statement +#! during the realization of foo() + +#%% function_err_2,barebones +def f[T: Static[float]](): + pass +#! expected 'int', 'bool' or 'str' + +#%% function_err_3,barebones +def f(a, b=a): + pass +#! name 'a' is not defined + +#%% function_llvm_err_1,barebones +@llvm +def foo(): + blah +#! return types required for LLVM and C functions + +#%% function_llvm_err_2,barebones +@llvm +def foo() -> int: + a{={=}} +#! invalid LLVM code + +#%% function_llvm_err_4,barebones +a = 5 +@llvm +def foo() -> int: + a{=a +#! invalid LLVM code + +#%% function_self,barebones +class Foo: + def foo(self): + return 'F' +f = Foo() +print f.foo() #: F + +#%% function_self_err,barebones +class Foo: + def foo(self): + return 'F' +Foo.foo(1) #! 'int' does not match expected type 'Foo' + +#%% function_nested,barebones +def foo(v): + value = v + def bar(): + return value + return bar +baz = foo(2) +print baz() #: 2 + +def f(x): + a=1 + def g(y): + return a+y + return g(x) +print f(5) #: 6 + +#%% nested_generic_static,barebones +def foo(): + N: Static[int] = 5 + Z: Static[int] = 15 + T = Int[Z] + def bar(): + x = __array__[T](N) + print(x.__class__.__name__) + return bar +foo()() #: Array[Int[15]] + +#%% nested_generic_error,barebones +def f[T](): + def g(): + return T() + return g() +print f(int) +#! name 'T' cannot be captured +#! during +#! during + +#%% block_unroll,barebones +# Ensure that block unrolling is done in RAII manner on error +def foo(): + while True: + def magic(a: x): + return + print b +foo() +#! name 'x' is not defined +#! during the realization of foo() + +#%% capture_recursive,barebones +def f(x: int) -> int: + z = 2 * x + def g(y: int) -> int: + if y == 0: + return 1 + else: + return g(y - 1) * z + return g(4) +print(f(3)) #: 1296 + +#%% id_static,barebones +def foo[N: Static[int]](): + print N +foo(5) #: 5 + +def fox(N: Static[int]): + print N +fox(6) #: 6 + +#%% function_typecheck_level,barebones +def foo(x): + def bar(z): # bar has a parent foo(), however its unbounds must not be generalized! + print z + bar(x) + bar('x') +foo(1) +#: 1 +#: x +foo('s') +#: s +#: x + +#%% function_builtin_error,barebones +@__force__ +def foo(x): + pass +#! builtin, exported and external functions cannot be generic + +#%% early_return,barebones +def foo(x): + print x-1 + return + print len(x) +foo(5) #: 4 + +def foo2(x): + if isinstance(x, int): + print x+1 + return + print len(x) +foo2(1) #: 2 +foo2('s') #: 1 + +#%% static_fn,barebones +class A[TA]: + a: TA + def dump(a, b, c): + print a, b, c + def m2(): + A.dump(1, 2, 's') + def __str__(self): + return 'A' +A.dump(1, 2, 3) #: 1 2 3 +A[int].m2() #: 1 2 s +A.m2() #: 1 2 s +c = A[str]('s') +c.dump('y', 1.1) #: A y 1.1 + +#%% static_fn_overload,barebones +def foo(x: Static[int]): + print('int', x) + +@overload +def foo(x: Static[str]): + print('str', x) + +foo(10) +#: int 10 +foo('s') +#: str s + +#%% instantiate_function_2,barebones +def fx[T](x: T) -> T: + def g[T](z): + return z(T()) + return g(fx, T) +print fx(1.1).__class__.__name__, fx(1).__class__.__name__ #: float int + +#%% void,barebones +def foo(): + print 'foo' +def bar(x): + print 'bar', x.__class__.__name__ +a = foo() #: foo +bar(a) #: bar NoneType + +def x(): + pass +b = lambda: x() +b() +x() if True else x() + +#%% void_2,barebones +def foo(): + i = 0 + while i < 10: + print i #: 0 + yield + i += 10 +a = list(foo()) +print(a) #: [None] + +#%% global_none,barebones +a, b = None, None +def foo(): + global a, b + a = [1, 2] + b = 3 +print a, b, +foo() +print a, b #: None None [1, 2] 3 + +#%% return_fn,barebones +def retfn(a): + def inner(b, *args, **kwargs): + print a, b, args, kwargs + print inner.__class__.__name__ #: inner[...,...,int,...] + return inner(15, ...) +f = retfn(1) +print f.__class__.__name__ #: inner[int,...,int,...] +f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar') + +#%% decorator_manual,barebones +def foo(x, *args, **kwargs): + print x, args, kwargs + return 1 +def dec(fn, a): + print 'decorating', fn.__class__.__name__ #: decorating foo[...,...,...] + def inner(*args, **kwargs): + print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True) + return fn(a, *args, **kwargs) + return inner(...) +ff = dec(foo(...), 10) +print ff(5.5, 's', z=True) +#: 10 (5.5, 's') (z: True) +#: 1 + + +#%% decorator,barebones +def foo(x, *args, **kwargs): + print x, args, kwargs + return 1 +def dec(a): + def f(fn): + print 'decorating', fn.__class__.__name__ + def inner(*args, **kwargs): + print 'decorator', args, kwargs + return fn(a, *args, **kwargs) + return inner + return f +ff = dec(10)(foo) +print ff(5.5, 's', z=True) +#: decorating foo[...,...,...] +#: decorator (5.5, 's') (z: True) +#: 10 (5.5, 's') (z: True) +#: 1 + +@dec(a=5) +def zoo(e, b, *args): + return f'zoo: {e}, {b}, {args}' +print zoo(2, 3) +print zoo('s', 3) +#: decorating zoo[...,...,...] +#: decorator (2, 3) () +#: zoo: 5, 2, (3,) +#: decorator ('s', 3) () +#: zoo: 5, s, (3,) + +def mydecorator(func): + def inner(): + print("before") + func() + print("after") + return inner +@mydecorator +def foo2(): + print("foo") +foo2() +#: before +#: foo +#: after + +def timeme(func): + def inner(*args, **kwargs): + begin = 1 + end = func(*args, **kwargs) - begin + print('time needed for', func.__class__.__name__, 'is', end) + return inner +@timeme +def factorial(num): + n = 1 + for i in range(1,num + 1): + n *= i + print(n) + return n +factorial(10) +#: 3628800 +#: time needed for factorial[...] is 3628799 + +def dx1(func): + def inner(): + x = func() + return x * x + return inner +def dx2(func): + def inner(): + x = func() + return 2 * x + return inner +@dx1 +@dx2 +def num(): + return 10 +print(num()) #: 400 + +def dy1(func): + def inner(*a, **kw): + x = func(*a, **kw) + return x * x + return inner +def dy2(func): + def inner(*a, **kw): + x = func(*a, **kw) + return 2 * x + return inner +@dy1 +@dy2 +def num2(a, b): + return a+b +print(num2(10, 20)) #: 3600 + +#%% c_void_return,barebones +from C import seq_print(str) +x = seq_print("not ") +print x #: not None + +#%% return_none_err_1,barebones +def foo(n: int): + if n > 0: + return + else: + return 1 +foo(1) +#! 'NoneType' does not match expected type 'int' +#! during the realization of foo(n: int) + +#%% return_none_err_2,barebones +def foo(n: int): + if n > 0: + return 1 + return +foo(1) +#! 'int' does not match expected type 'NoneType' +#! during the realization of foo(n: int) + +#%% return_fail,barebones +return #! 'return' outside function + +#%% yield_fail,barebones +yield 5 #! 'yield' outside function + +#%% yield_fail_2,barebones +(yield) #! 'yield' outside function + +#%% proxyfunc,barebones +def foo(x: ProxyFunc[[int, int], str]): + return x(1, 2) + +def f1(a, b): + return f'f1:{a}.{b}' +# Case 1: normal functions +print foo(f1) +#: f1:1.2 + +def f2(a, b): + return f'f2:{a}+{b}' +# Case 2: function pointers +f2p: Function[[int,int],str] = f2 +print foo(f2p) +#: f2:1+2 + +def f3(a, b, c): + return f'f3:<{a}+{b}+{c}>' +# Case 3: Partials +pt = f3(c='hey!', ...) +print foo(pt) +#: f3:<1+2+hey!> +print foo(f3(b='hey!', ...)) +#: f3:<1+hey!+2> + +# Case 4: expressions +def i2i_1(x: int) -> int: + return x + 1 +def i2i_2(x: int) -> int: + return x + 2 +# TODO: auto-deduction! +fn = ProxyFunc[[int], int](i2i_1) if int(1) else i2i_2 +print(fn(1)) #: 2 +print (ProxyFunc[[int], int](i2i_1) if int(0) else i2i_2)(1) #: 3 +# TODO: auto-deduction! +l = [ProxyFunc[[int, int],str](f1), f2p, pt] +for fn in l: print(fn(1, 2)) +#: f1:1.2 +#: f2:1+2 +#: f3:<1+2+hey!> + +#%% decorator_self_reference +store = Dict[int, int]() # need to manually configure cache for now. +def memoize(func): + def inner(val: int) -> int: + if val in store: + print(f"<- cache[{val}]") + return store[val] + else: + result = func(val) + store[val] = result + return result + return inner + +@memoize +def fib(n: int) -> int: + print(f"<- fib[{n}]") + if n < 2: + return n + else: + return fib(n - 1) + fib(n - 2) ## << not accessing decorated function + +f4 = fib(4) +print(f"{f4=} : {store=}") +#: <- fib[4] +#: <- fib[3] +#: <- fib[2] +#: <- fib[1] +#: <- fib[0] +#: <- cache[1] +#: <- cache[2] +#: f4=3 : store={0: 0, 1: 1, 2: 1, 3: 2, 4: 3} + +f6 = fib(6) +print(f"{f6=} : {store=}") +#: <- fib[6] +#: <- fib[5] +#: <- cache[4] +#: <- cache[3] +#: <- cache[4] +#: f6=8 : store={0: 0, 1: 1, 2: 1, 3: 2, 4: 3, 5: 5, 6: 8} + +f6 = fib(6) +print(f"{f6=} : {store=}") +#: <- cache[6] +#: f6=8 : store={0: 0, 1: 1, 2: 1, 3: 2, 4: 3, 5: 5, 6: 8} diff --git a/test/parser/typecheck/test_import.codon b/test/parser/typecheck/test_import.codon new file mode 100644 index 00000000..dbe9641b --- /dev/null +++ b/test/parser/typecheck/test_import.codon @@ -0,0 +1,139 @@ +#%% import_c,barebones +from C import sqrt(float) -> float +print sqrt(4.0) #: 2 + +from C import puts(cobj) +puts("hello".ptr) #: hello + +from C import atoi(cobj) -> int as s2i +print s2i("11".ptr) #: 11 + +@C +def log(x: float) -> float: + pass +print log(5.5) #: 1.70475 + +from C import seq_flags: Int[32] as e +# debug | standalone == 5 +print e #: 5 + +#%% import_c_shadow_error,barebones +# Issue #45 +from C import sqrt(float) -> float as foo +sqrt(100.0) #! name 'sqrt' is not defined + + +#%% import_c_dylib,barebones +from internal.dlopen import dlext +RT = "./libcodonrt." + dlext() +if RT[-3:] == ".so": + RT = "build/" + RT[2:] +from C import RT.seq_str_int(int, str, Ptr[bool]) -> str as sp +p = False +print sp(65, "", __ptr__(p)) #: 65 + +#%% import_c_dylib_error,barebones +from C import "".seq_print(str) as sp +sp("hi!") #! syntax error, unexpected '"' + +#%% import,barebones +zoo, _zoo = 1, 1 +print zoo, _zoo, __name__ #: 1 1 __main__ + +import a #: a +a.foo() #: a.foo + +from a import foo, bar as b +foo() #: a.foo +b() #: a.bar + +print str(a)[:9], str(a)[-18:] #: + +import a.b +print a.b.c #: a.b.c +a.b.har() #: a.b.har a.b.__init__ a.b.c + +print a.b.A.B.b_foo().__add__(1) #: a.b.A.B.b_foo() +#: 2 + +print str(a.b)[:9], str(a.b)[-20:] #: +print Int[a.b.stt].__class__.__name__ #: Int[5] + +from a.b import * +har() #: a.b.har a.b.__init__ a.b.c +a.b.har() #: a.b.har a.b.__init__ a.b.c +fx() #: a.foo +print(stt, Int[stt].__class__.__name__) #: 5 Int[5] + +from a import * +print zoo, _zoo, __name__ #: 5 1 __main__ + +f = Foo(Ptr[B]()) +print f.__class__.__name__, f.t.__class__.__name__ #: Foo Ptr[B] + +a.ha() #: B + +print par #: x + +#%% import_order,barebones +def foo(): + import a + a.foo() +def bar(): + import a + a.bar() + +bar() #: a +#: a.bar +foo() #: a.foo + +#%% import_class +import sys +print str(sys)[:20] #: +#! during the realization of + +#%% import_err_1,barebones +class Foo: + import bar #! unexpected expression in class definition + +#%% import_err_2,barebones +import "".a.b.c #! syntax error, unexpected '"' + +#%% import_err_3,barebones +from a.b import foo() #! function signatures only allowed when importing C or Python functions + +#%% import_err_4,barebones +from a.b.c import hai.hey #! expected identifier + +#%% import_err_4_x,barebones +import whatever #! no module named 'whatever' + +#%% import_err_5,barebones +import a.b +print a.b.x #! cannot import name 'x' from 'a.b.__init__' + +#%% import_err_6,barebones +from a.b import whatever #! cannot import name 'whatever' from 'a.b.__init__' + +#%% import_subimport,barebones +import a as xa #: a + +xa.foo() #: a.foo +#: a.sub +xa.sub.foo() #: a.sub.foo diff --git a/test/parser/typecheck/test_infer.codon b/test/parser/typecheck/test_infer.codon new file mode 100644 index 00000000..1c552ae2 --- /dev/null +++ b/test/parser/typecheck/test_infer.codon @@ -0,0 +1,969 @@ +#%% late_unify,barebones +a = [] +a.append(1) +print a #: [1] +print [1]+[1] #: [1, 1] + +#%% late_unify_2,barebones +class XX[T]: + y: T + def __init__(self): pass +a = XX() +def f(i: int) -> int: + return i +print a.y.__class__.__name__ #: int +f(a.y) +print a.__class__.__name__ #: XX[int] +print XX[bool].__class__.__name__ #: XX[bool] + +#%% map_unify +def map[T,S](l: List[T], f: Callable[[T], S]): + return [f(x) for x in l] +e = 1 +print map([1, 2, 3], lambda x: x+e) #: [2, 3, 4] + +def map2(l, f): + return [f(x) for x in l] +print map2([1, 2, 3], lambda x: x+e) #: [2, 3, 4] + +#%% nested,barebones +def m4[TD](a: int, d: TD): + def m5[TD,TE](a: int, d: TD, e: TE): + print a, d, e + m5(a, d, 1.12) +m4(1, 's') #: 1 s 1.12 +m4(1, True) #: 1 True 1.12 + +#%% nested_class,barebones +class A[TA]: + a: TA + # lots of nesting: + def m4[TD](self: A[TA], d: TD): + def m5[TA,TD,TE](a: TA, d: TD, e: TE): + print a, d, e + m5(self.a, d, d) +ax = A(42) +ax.m4(1) #: 42 1 1 + +#%% realization_big +class A[TA,TB,TC]: + a: TA + b: TB + c: TC + + def dump(a, b, c): + print a, b, c + + # non-generic method: + def m0(self: A[TA,TB,TC], a: int): + print a + + # basic generics: + def m1[X](self: A[TA,TB,TC], other: A[X,X,X]): + print other.a, other.b, other.c + + # non-generic method referencing outer generics: + def m2(a: TA, b: TB, c: TC): + A.dump(a, b, c) + + # generic args: + def m3(self, other): + return self.a + + # lots of nesting: + def m4[TD](self: A[TA,TB,TC], d: TD): + def m5[TA,TB,TC,TD,TE](a: TA, b: TB, c: TC, d: TD, e: TE): + print a, b, c, d, e + m5(self.a, self.b, self.c, d, d) + + # instantiating the type: + def m5(self): + x = A(self.a, self.b, self.c) + A.dump(x.a, x.b, x.c) + + # deeply nested generic type: + def m6[T](v: array[array[array[T]]]): + return v[0][0][0] +a1 = A(42, 3.14, "hello") +a2 = A(1, 2, 3) +a1.m1(a2) #: 1 2 3 +A[int,float,str].m2(1, 1.0, "one") #: 1 1 one +A[int,int,int].m2(11, 22, 33) #: 11 22 33 +print a1.m3(a2) #: 42 +print a1.m3(a2) #: 42 +print a2.m3(a1) #: 1 +a1.m4(True) #: 42 3.14 hello True True +a1.m4([1]) #: 42 3.14 hello [1] [1] +a2.m4("x") #: 1 2 3 x x +a1.m5() #: 42 3.14 hello +a2.m5() #: 1 2 3 + +v1 = array[array[array[str]]](1) +v2 = array[array[str]](1) +v3 = array[str](1) +v1[0] = v2 +v2[0] = v3 +v3[0] = "world" +print A.m6(v1) #: world + +f = a2.m0 +f(99) #: 99 + +#%% realization_small,barebones +class B1[T]: + a: T + def foo[S](self: S) -> B1[int]: + return B1[int](111) +b1 = B1[bool](True).foo() +print b1.foo().a #: 111 + +class B2[T]: + a: T + def foo[S](self: B2[S]): + return B2[int](222) +b2 = B2[str]("x").foo() +print b2.foo().a #: 222 + +# explicit realization: +def m7[T,S](): + print "works" +m7(str,float) #: works +m7(str,float) #: works +m7(float,str) #: works + +#%% recursive,barebones +def foo(a): + if not a: + foo(True) + print a +foo(0) +#: True +#: 0 + +def bar(a): + def baz(x): + if not x: + bar(True) + print (x) + baz(a) +bar(0) +#: True +#: 0 + +def rec2(x, y): + if x: + return rec2(y, x) + else: + return 1.0 +print rec2(1, False).__class__.__name__ #: float + +def pq(x): + return True +def rec3(x, y): + if pq(x): + return rec3(y, x) + else: + return y +print rec3('x', 's').__class__.__name__ #: str + +# Nested mutually recursive function +def f[T](x: T) -> T: + def g[T](z): + return z(T()) + return g(f, T=T) +print f(1.2).__class__.__name__ #: float +print f('s').__class__.__name__ #: str + +def f2[T](x: T): + return f2(x - 1, T) if x else 1 +print f2(1) #: 1 +print f2(1.1).__class__.__name__ #: int + + +#%% recursive_error,barebones +def pq(x): + return True +def rec3(x, y): #- ('a, 'b) -> 'b + if pq(x): + return rec3(y, x) + else: + return y +rec3(1, 's') +#! 'int' does not match expected type 'str' +#! during the realization of rec3(x: int, y: str) + +#%% optionals,barebones +y = None +print y #: None +y = 5 +print y #: 5 + +def foo(x: optional[int], y: int): + print 'foo', x, y +foo(y, 6) #: foo 5 6 +foo(5, 6) #: foo 5 6 +foo(5, y) #: foo 5 5 +y = None +try: + foo(5, y) +except ValueError: + print 'unwrap failed' #: unwrap failed + +class Cls: + x: int +c = None +for i in range(2): + if c: c.x += 1 # check for unwrap() dot access + c = Cls(1) +print(c.x) #: 1 + +#%% optional_methods,barebones +@extend +class int: + def x(self): + print 'x()!', self + +y = None +z = 1 if y else None +print z #: None + +y = 6 +z = 1 + y if y else None +print z #: 7 +z.x() #: x()! 7 +if 1: # otherwise compiler won't compile z.x() later + z = None +try: + z.x() +except ValueError: + print 'unwrap failed' #: unwrap failed + +print Optional(1) + Optional(2) #: 3 +print Optional(1) + 3 #: 4 +print 1 + Optional(1) #: 2 + +#%% optional_tuple,barebones +a = None +if True: + a = ('x', 's') +print(a) #: ('x', 's') +print(*a, (1, *a)) #: x s (1, 'x', 's') +x,y=a +print(x,y,[*a]) #: x s ['x', 's'] + +#%% default_type_none +class Test: + value: int + def __init__(self, value: int): + self.value = value + def __repr__(self): + return str(self.value) +def key_func(k: Test): + return k.value +print sorted([Test(1), Test(3), Test(2)], key=key_func) #: [1, 2, 3] +print sorted([Test(1), Test(3), Test(2)], key=lambda x: x.value) #: [1, 2, 3] +print sorted([1, 3, 2]) #: [1, 2, 3] + +#%% nested_map +print list(map(lambda i: i-2, map(lambda i: i+1, range(5)))) +#: [-1, 0, 1, 2, 3] + +def h(x: list[int]): + return x +print h(list(map(lambda i: i-1, map(lambda i: i+2, range(5))))) +#: [1, 2, 3, 4, 5] + +#%% func_unify_error,barebones +def foo(x:int): + print x +z = 1 & foo #! 'foo[...]' does not match expected type 'int' + +#%% tuple_type_late,barebones +coords = [] +for i in range(2): + coords.append( ('c', i, []) ) +coords[0][2].append((1, 's')) +print(coords) #: [('c', 0, [(1, 's')]), ('c', 1, [])] + +#%% instantiate_swap,barebones +class Foo[T, U]: + t: T + u: U + def __init__(self): + self.t = T() + self.u = U() + def __str__(self): + return f'{self.t} {self.u}' +print Foo[int, bool](), Foo[bool, int]() #: 0 False False 0 + +#%% static_fail,barebones +def test(i: Int[32]): + print int(i) +test(Int[5](1)) #! 'Int[5]' does not match expected type 'Int[32]' + +#%% static_fail_2,barebones +zi = Int[32](6) +def test3[N](i: Int[N]): + print int(i) +test3(zi) #! expected type expression +# TODO: nicer error message! + +#%% static_fail_3,barebones +zi = Int[32](6) +def test3[N: Static[int]](i: Int[N]): + print int(i) +test3(1, int) #! expected static expression +# TODO: nicer error message! + +#%% nested_fn_generic,barebones +def f(x): + def g(y): + return y + return g(x) +print f(5), f('s') #: 5 s + +def f2[U](x: U, y): + def g[T, U](x: T, y: U): + return (x, y) + return g(y, x) +x, y = 1, 'haha' +print f2(x, y).__class__.__name__ #: Tuple[str,int] +print f2('aa', 1.1, U=str).__class__.__name__ #: Tuple[float,str] + +#%% nested_fn_generic_error,barebones +def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u] + def g[T, U](x: T, y: U): # ('t, 'u) -> tuple['t, 'u] + return (x, y) + return g(y, x) +print f(1.1, 1, int).__class__.__name__ #! 'float' does not match expected type 'int' + +#%% fn_realization,barebones +def ff[T](x: T, y: tuple[T]): + print ff(T=str,...).__fn_name__ #: ff[str;str,Tuple[str]] + return x +x = ff(1, (1,)) +print x, x.__class__.__name__ #: 1 int +# print f.__class__.__name__ # TODO ERRORS + +def fg[T](x:T): + def g[T](y): + z = T() + return z + print fg(T=str,...).__fn_name__ #: fg[str;str] + print g(1, T).__class__.__name__ #: int +fg(1) +print fg(1).__class__.__name__ #: NoneType + +def f[T](x: T): + print f(x, T).__class__.__name__ #: int + print f(x).__class__.__name__ #: int + print f(x, int).__class__.__name__ #: int + return x +print f(1), f(1).__class__.__name__ #: 1 int +print f(1, int).__class__.__name__ #: int + +#%% fn_realization_error,barebones +def f[T](x: T): + print f(x, int).__class__.__name__ + return x +f('s') +#! 'str' does not match expected type 'int' +#! during the realization of f(x: str, T: str) + +#%% func_arg_instantiate,barebones +class A[T]: + y: T = T() + def foo(self, y: T): + self.y = y + return y + def bar(self, y): + return y +a = A() +print a.__class__.__name__ #: A[int] +a.y = 5 +print a.__class__.__name__ #: A[int] + +b = A() +print b.foo(5) #: 5 +print b.__class__.__name__, b.y #: A[int] 5 +print b.bar('s'), b.bar('s').__class__.__name__ #: s str +print b.bar(5), b.bar(5).__class__.__name__ #: 5 int + +aa = A() +print aa.foo('s') #: s +print aa.__class__.__name__, aa.y, aa.bar(5.1).__class__.__name__ #: A[str] s float + +#%% no_func_arg_instantiate_err,barebones +# TODO: allow unbound self? +class A[T]: + y: T = T() + def foo(self, y): self.y = y +a = A() +a.foo(1) +#! cannot typecheck +#! cannot typecheck +#! cannot typecheck + +#%% return_deduction,barebones +def fun[T, R](x, y: T) -> R: + def ffi[T, R, Z](x: T, y: R, z: Z): + return (x, y, z) + yy = ffi(False, byte(2), 's', T=bool, Z=str, R=R) + yz = ffi(1, byte(2), 's', T=int, Z=str, R=R) + return byte(1) +print fun(2, 1.1, float, byte).__class__.__name__ #: byte + +#%% return_auto_deduction_err,barebones +def fun[T, R](x, y: T) -> R: + return byte(1) +print fun(2, 1.1).__class__.__name__ #! cannot typecheck + +#%% random +# shuffle used to fail before for some reason (sth about unbound variables)... +def foo(): + from random import shuffle + v = list(range(10)) + shuffle(v) + print sorted(v) #: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +foo() + +#%% function_type,barebones +class F: + f: Function[[int], int] + g: function[[int], None] + x: int +def foo(x: int): + return x+1 +def goo(x: int): + print x+2 +f = F(foo, goo, 2) +print f.f(f.x) #: 3 +f.g(f.x) #: 4 + +def hoo(z): + print z+3 +f.g = hoo +f.g(f.x) #: 5 + +def hai(x, y, z): + print f'hai({x},{y},{z})' +fn = Function[[int, int, int], None](hai) +fn(1, 2, 3) #: hai(1,2,3) +print str(fn)[:12] #: 0: + z: Static[int] = 88 + print z #: 88 +print x #: 5 +x : Static[int] = 3 +print x #: 3 + +def fox(N: Static[int] = 4): + print Int[N].__class__.__name__, N +fox(5) #: Int[5] 5 +fox() #: Int[4] 4 + +#%% new_syntax_err,barebones +class Foo[T,U: Static[int]]: + a: T + b: Static[int] + c: Int[U] + d: type + e: List[d] + f: UInt[b] +print Foo[float,6].__class__.__name__ #! Foo takes 4 generics (2 given) + +#%% type_arg_transform,barebones +print list(map(str, range(5))) +#: ['0', '1', '2', '3', '4'] + + +#%% traits,barebones +def t[T](x: T, key: Optional[Callable[[T], S]] = None, S: type = NoneType): + if isinstance(S, NoneType): + return x + else: + return (key.__val__())(x) +print t(5) #: 5 +print t(6, lambda x: f'#{x}') #: #6 + +z: Callable[[int],int] = lambda x: x+1 +print z(5) #: 6 + +def foo[T](x: T, func: Optional[Callable[[], T]] = None) -> T: + return x +print foo(1) #: 1 + +#%% trait_callable +foo = [1,2,11] +print(sorted(foo, key=str)) +#: [1, 11, 2] + +foo = {1: "a", 2: "a", 11: "c"} +print(sorted(foo.items(), key=str)) +#: [(1, 'a'), (11, 'c'), (2, 'a')] + +def call(f: Callable[[int,int], Tuple[str,int]]): + print(f(1, 2)) + +def foo(*x): return f"{x}_{x.__class__.__name__}",1 +call(foo) +#: ('(1, 2)_Tuple[int,int]', 1) + +def foo(a:int, *b: float): return f"{a}_{b}", a+b[0] +call(foo) +#: ('1_(2,)', 3) + +def call(f: Callable[[int,int],str]): + print(f(1, 2)) +def foo(a: int, *b: int, **kw): return f"{a}_{b}_{kw}" +call(foo(zzz=1.1, ...)) +#: 1_(2,)_(zzz: 1.1) + +#%% traits_error,barebones +def t[T](x: T, key: Optional[Callable[[T], S]] = None, S: type = NoneType): + if isinstance(S, NoneType): + return x + else: + return (key.__val__())(x) +print t(6, Optional(1)) #! 'Optional[int]' does not match expected type 'Optional[Callable[[int],S]]' + +#%% traits_error_2,barebones +z: Callable[[int],int] = 4 #! 'Callable[[int],int]' does not match expected type 'int' + +#%% trait_defdict +class dd(Static[Dict[K,V]]): + fn: S + K: type + V: type + S: TypeVar[Callable[[], V]] + + def __init__(self: dd[K, VV, Function[[], V]], VV: TypeVar[V]): + self.fn = lambda: VV() + + def __init__(self, f: S): + self.fn = f + + def __getitem__(self, key: K) -> V: + if key not in self: + self.__setitem__(key, self.fn()) + return super().__getitem__(key) + + +x = dd(list) +x[1] = [1, 2] +print(x[2]) +#: [] +print(x) +#: {1: [1, 2], 2: []} + +z = 5 +y = dd(lambda: z+1) +y.update({'a': 5}) +print(y['b']) +#: 6 +z = 6 +print(y['c']) +#: 6 +# TODO: should be 7 once by-ref capture lands +print(y) +#: {'a': 5, 'b': 6, 'c': 6} + +xx = dd(lambda: 'empty') +xx.update({1: 's', 2: 'b'}) +print(xx[1], xx[44]) +#: s empty +print(xx) +#: {44: 'empty', 1: 's', 2: 'b'} + +s = 'mississippi' +d = dd(int) +for k in s: + d[k] = d["x" + k] +print(sorted(d.items())) +#: [('i', 0), ('m', 0), ('p', 0), ('s', 0), ('xi', 0), ('xm', 0), ('xp', 0), ('xs', 0)] + + +#%% kwargs_getattr,barebones +def foo(**kwargs): + print kwargs['foo'], kwargs['bar'] + +foo(foo=1, bar='s') +#: 1 s + + + +#%% union_types,barebones +def foo_int(x: int): + print(f'{x} {x.__class__.__name__}') +def foo_str(x: str): + print(f'{x} {x.__class__.__name__}') +def foo(x): + print(f'{x} {int(__internal__.union_get_tag(x))} {x.__class__.__name__}') + +a: Union[int, str] = 5 +foo_int(a) #: 5 int +foo(a) #: 5 0 Union[int | str] +print(staticlen(a)) #: 2 +print(staticlen(Union[int, int]), staticlen(Tuple[int, float, int])) #: 1 3 + +@extend +class str: + def __add__(self, i: int): + return int(self) + i + +a += 6 ## this is U.__new__(a.__getter__(__add__)(59)) +b = a + 59 +print(a, b, a.__class__.__name__, b.__class__.__name__) #: 11 70 Union[int | str] int + +if True: + a = 'hello' + foo_str(a) #: hello str + foo(a) #: hello 1 Union[int | str] + b = a[1:3] + print(b) #: el +print(a) #: hello + +a: Union[Union[Union[str], int], Union[int, int, str]] = 9 +foo(a) #: 9 0 Union[int | str] + +def ret(x): + z : Union = x + if x < 1: z = 1 + elif x < 10: z = False + else: z = 'oops' + return z +r = ret(2) +print(r, r.__class__.__name__) #: False Union[bool | int | str] +r = ret(33.3) +print(r, r.__class__.__name__) #: oops Union[bool | float | int | str] + +def ret2(x) -> Union: + if x < 1: return 1 + elif x < 10: return 2.2 + else: return ['oops'] +r = ret2(20) +print(r, r.__class__.__name__) #: ['oops'] Union[List[str] | float | int] + +class A: + x: int + def foo(self): + return f"A: {self.x}" +class B: + y: str + def foo(self): + return f"B: {self.y}" +x : Union[A,B] = A(5) # TODO: just Union does not work in test mode :/ +print(x.foo()) #: A: 5 +print(x.x) #: 5 +if True: + x = B("bee") +print(x.foo()) #: B: bee +print(x.y) #: bee +try: + print(x.x) +except TypeError as e: + print(e.message) #: invalid union call 'x' + +def do(x: A): + print('do', x.x) +try: + do(x) +except TypeError: + print('error') #: error + +def do2(x: B): + print('do2', x.y) +do2(x) #: do2 bee + +z: Union[int, str] = 1 +print isinstance(z, int), isinstance(z, str), isinstance(z, float), isinstance(z, Union[int, float]), isinstance(z, Union[int, str]) +#: True False False False True + +print isinstance(z, Union[int]), isinstance(z, Union[int, float, str]) +#: False False + +if True: + z = 's' +print isinstance(z, int), isinstance(z, str), isinstance(z, float), isinstance(z, Union[int, float]), isinstance(z, Union[int, str]) +#: False True False False True + +class A: + def foo(self): return 1 +class B: + def foo(self): return 's' +class C: + def foo(self): return [True, False] +x : Union[A,B,C] = A() +print x.foo(), x.foo().__class__.__name__ +#: 1 Union[List[bool] | int | str] + +xx = Union[int, str](0) +print(xx) #: 0 + +#%% union_error,barebones +a: Union[int, str] = 123 +print(123 == a) #: True +print(a == 123) #: True +try: + a = "foo" + print(a == 123) +except TypeError: + print("oops", a) #: oops foo + + +#%% delayed_lambda_realization,barebones +x = [] +for i in range(2): + print(all(x[j] < 0 for j in range(i))) + x.append(i) +#: True +#: False + +#%% no_generic,barebones +def foo(a, b: Static[int]): + pass +foo(5) #! generic 'b' not provided + + +#%% no_generic_2,barebones +def f(a, b, T: type): + print(a, b) +f(1, 2) #! generic 'T' not provided + +#%% variardic_tuples,barebones +na: Tuple[5, str] = ('a', 'b', 'c', 'd', 'e') +print(na, na.__class__.__name__) +#: ('a', 'b', 'c', 'd', 'e') Tuple[str,str,str,str,str] + +nb = Tuple[5, str]('a', 'b', 'c', 'd', 'e') +print(nb, nb.__class__.__name__) +#: ('a', 'b', 'c', 'd', 'e') Tuple[str,str,str,str,str] + +class Foo[N: Static[int], T: type]: + x: Tuple[N, T] + def __init__(self, t: T): + self.x = (t, ) * N + +f = Foo[5, str]('hi') +print(f.__class__.__name__) +#: Foo[5,str] +print(f.x.__class__.__name__) +#: Tuple[str,str,str,str,str] +print(f.x) +#: ('hi', 'hi', 'hi', 'hi', 'hi') + +f = Foo[2,int](1) +print(f.__class__.__name__) +#: Foo[2,int] +print(f.x.__class__.__name__) +#: Tuple[int,int] +print(f.x) +#: (1, 1) +f.x = (3, 4) +print(f.x) +#: (3, 4) + +print(Tuple[int, int].__class__.__name__) +#: Tuple[int,int] +print(Tuple[3, int].__class__.__name__) +#: Tuple[int,int,int] +print(Tuple[0].__class__.__name__) +#: Tuple +print(Tuple[-5, int].__class__.__name__) +#: Tuple +print(Tuple[5, int, str].__class__.__name__) +#: Tuple[int,str,int,str,int,str,int,str,int,str] + +def foo(t: Tuple[N, int], N: Static[int]): + print("foo", N, t) +foo((1, 2, 3)) +#: foo 3 (1, 2, 3) +foo((1, 2, 3, 4, 5)) +#: foo 5 (1, 2, 3, 4, 5) + + +#%% union_hasattr,barebones +class A: + def foo(self): + print('foo') + def bar(self): + print('bar') +class B: + def foo(self): + print('foo') + def baz(self): + print('baz') + +a = A() +print(hasattr(a, 'foo'), hasattr(a, 'bar'), hasattr(a, 'baz')) +#: True True False +b = B() +print(hasattr(b, 'foo'), hasattr(b, 'bar'), hasattr(b, 'baz')) +#: True False True + +c: Union[A, B] = A() +print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) +#: True True False + +c = B() +print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) +#: True False True + + +#%% delayed_dispatch +import math +def fox(a, b, key=None): # key=None delays it! + return a if a <= b else b + +a = 1.0 +b = 2.0 +c = fox(a, b) +print(math.log(c) / 2) #: 0 + +#%% repeated_lambda,barebones +def acc(i, func=lambda a, b: a + b): + return i + func(i, i) +print acc(1) #: 3 +print acc('i') #: iii + +x = 1 +def const(value): + return lambda: (value, x) +print const(5)() #: (5, 1) +print const('s')() #: ('s', 1) +x = 's' +print const(5)() #: (5, 's') +print const('s')() #: ('s', 's') + + +#%% type_variables_pass,barebones +def foo(a): + print(a.__class__.__name__, a) + print(a().__class__.__name__, a()) + +foo(float) +#: float +#: float 0 +print(float) +#: +foo(list[int]) +#: List[int] +#: List[int] [] +print(list[int]) +#: +foo(type(list[int])) +#: List[int] +#: List[int] [] + +# TODO: print(list) + +def typtest(a, b): + print isinstance(a, b) + print isinstance(a, int) + print(a) + print(b) + print(a.__repr__()) + +typtest(int, int) +#: True +#: True +#: +#: +#: +typtest(int, float) +#: False +#: True +#: +#: +#: + +print(List[int]) +print(List[int].__repr__()) +# print(int.__repr__()) # this catches int.__repr__ as it should... +print(type(int).__repr__()) +#: +#: +#: diff --git a/test/parser/typecheck/test_loops.codon b/test/parser/typecheck/test_loops.codon new file mode 100644 index 00000000..d33754f7 --- /dev/null +++ b/test/parser/typecheck/test_loops.codon @@ -0,0 +1,207 @@ +#%% while_else,barebones +a = 1 +while a: + print a #: 1 + a -= 1 +else: + print 'else' #: else +a = 1 +while a: + print a #: 1 + a -= 1 +else not break: + print 'else' #: else +while True: + print 'infinite' #: infinite + break +else: + print 'nope' + +#%% for_assignment,barebones +l = [[1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]] +for a, *m, b in l: + print a + b, len(m) +#: 5 2 +#: 14 3 +#: 21 0 + +#%% for_else,barebones +for i in [1]: + print i #: 1 +else: + print 'else' #: else +for i in [1]: + print i #: 1 +else not break: + print 'else' #: else +for i in [1]: + print i #: 1 + break +else: + print 'nope' + +best = 4 +for s in [3, 4, 5]: + for i in [s]: + if s >= best: + print('b:', best) + break + else: + print('s:', s) + best = s +#: s: 3 +#: b: 3 +#: b: 3 + + +#%% loop_domination,barebones +for i in range(2): + try: dat = 1 + except: pass + print(dat) +#: 1 +#: 1 + +def comprehension_test(x): + for n in range(3): + print('>', n) + l = ['1', '2', str(x)] + x = [n for n in l] + print(x, n) +comprehension_test(5) +#: > 0 +#: > 1 +#: > 2 +#: ['1', '2', '5'] 2 + +#%% while,barebones +a = 3 +while a: + print a + a -= 1 +#: 3 +#: 2 +#: 1 + +#%% for_break_continue,barebones +for i in range(10): + if i % 2 == 0: + continue + print i + if i >= 5: + break +#: 1 +#: 3 +#: 5 + +#%% for_error,barebones +for i in 1: + pass +#! '1' object has no attribute '__iter__' + +#%% for_void,barebones +def foo(): yield +for i in foo(): + print i.__class__.__name__ #: NoneType + +#%% hetero_iter,barebones +e = (1, 2, 3, 'foo', 5, 'bar', 6) +for i in e: + if isinstance(i, int): + if i == 1: continue + if isinstance(i, str): + if i == 'bar': break + print i + +#%% static_for,barebones +def foo(i: Static[int]): + print('static', i, Int[i].__class__.__name__) + +for i in statictuple(1, 2, 3, 4, 5): + foo(i) + if i == 3: break +#: static 1 Int[1] +#: static 2 Int[2] +#: static 3 Int[3] +for i in staticrange(9, 4, -2): + foo(i) + if i == 3: + break +#: static 9 Int[9] +#: static 7 Int[7] +#: static 5 Int[5] +for i in statictuple("x", 1, 3.3, 2): + print(i) +#: x +#: 1 +#: 3.3 +#: 2 + +print tuple(Int[i+10](i) for i in statictuple(1, 2, 3)).__class__.__name__ +#: Tuple[Int[11],Int[12],Int[13]] + +for i in staticrange(0, 10): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[1] +#: xyz Int[3] +#: xyz Int[5] +#: xyz Int[7] +#: whoa + +for i in staticrange(15): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[1] +#: xyz Int[3] +#: xyz Int[5] +#: xyz Int[7] +#: whoa + +print tuple(Int[i-10](i) for i in staticrange(30,33)).__class__.__name__ +#: Tuple[Int[20],Int[21],Int[22]] + +for i in statictuple(0, 2, 4, 7, 11, 12, 13): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[7] +#: whoa + +for i in staticrange(10): # TODO: large values are too slow! + pass +print('done') +#: done + +tt = (5, 'x', 3.14, False, [1, 2]) +for i, j in staticenumerate(tt): + print(foo(i * 2 + 1), j) +#: static 1 Int[1] +#: None 5 +#: static 3 Int[3] +#: None x +#: static 5 Int[5] +#: None 3.14 +#: static 7 Int[7] +#: None False +#: static 9 Int[9] +#: None [1, 2] + +print tuple((Int[i+1](i), j) for i, j in staticenumerate(tt)).__class__.__name__ +#: Tuple[Tuple[Int[1],int],Tuple[Int[2],str],Tuple[Int[3],float],Tuple[Int[4],bool],Tuple[Int[5],List[int]]] + +#%% static_range_error,barebones +for i in staticrange(1000, -2000, -2): + pass +#! staticrange too large (expected 0..1024, got instead 1500) + +#%% continue_error,barebones +continue #! 'continue' outside loop + +#%% break_error,barebones +break #! 'break' outside loop diff --git a/test/parser/typecheck/test_op.codon b/test/parser/typecheck/test_op.codon new file mode 100644 index 00000000..e5df113d --- /dev/null +++ b/test/parser/typecheck/test_op.codon @@ -0,0 +1,451 @@ + +#%% unary,barebones +a, b = False, 1 +print not a, not b, ~b, +b, -b, -(+(-b)) #: True False -2 1 -1 1 + +#%% binary_simple,barebones +x, y = 1, 0 +c = [1, 2, 3] + +print x and y, x or y #: False True +print x in c, x not in c #: True False +print c is c, c is not c #: True False + +z: Optional[int] = None +print z is None, None is z, None is not z, None is None #: True True False True + +#%% chain_binary,barebones +def foo(): + print 'foo' + return 15 +a = b = c = foo() #: foo +print a, b, c #: 15 15 15 + +x = y = [] +x.append(1) +print x, y #: [1] [1] + +print 1 <= foo() <= 10 #: foo +#: False +print 15 >= foo()+1 < 30 > 20 > foo() +#: foo +#: False +print 15 >= foo()-1 < 30 > 20 > foo() +#: foo +#: foo +#: True + +print True == (b == 15) #: True + +#%% pipe_error,barebones +def b(a, b, c, d): + pass +1 |> b(1, ..., 2, ...) #! multiple ellipsis expressions + +#%% index_normal,barebones +t: tuple[int, int] = (1, 2) +print t #: (1, 2) + +tt: Tuple[int] = (1, ) +print tt #: (1,) + +def foo(i: int) -> int: + return i + 1 +f: Callable[[int], int] = foo +print f(1) #: 2 +fx: function[[int], int] = foo +print fx(2) #: 3 +fxx: Function[[int], int] = foo +print fxx(3) #: 4 + +#%% index_special,barebones +class Foo: + def __getitem__(self, foo): + print foo +f = Foo() +f[0,0] #: (0, 0) +f[0,:] #: (0, slice(None, None, None)) +f[:,:] #: (slice(None, None, None), slice(None, None, None)) +f[:,0] #: (slice(None, None, None), 0) + +#%% index_error,barebones +Ptr[9.99] #! expected type expression + +#%% index_error_b,barebones +Ptr['s'] #! expected type expression + +#%% index_error_static,barebones +Ptr[1] #! expected type expression + +#%% index_error_2,barebones +Ptr[int, 's'] #! Ptr takes 1 generics (2 given) + +#%% index_error_3,barebones +Ptr[1, 's'] #! Ptr takes 1 generics (2 given) + +#%% callable_error,barebones +def foo(x: Callable[[]]): pass #! Callable takes 2 generics (1 given) + +#%% binary,barebones +@extend +class float: + def __add__(self, i: int): + print 'add'; return 0 + def __sub__(self, i: int): + print 'sub'; return 0 + def __mul__(self, i: int): + print 'mul'; return 0 + def __pow__(self, i: int): + print 'pow'; return 0 + def __truediv__(self, i: int): + print 'truediv'; return 0 + def __floordiv__(self, i: int): + print 'div'; return 0 + def __matmul__(self, i: int): + print 'matmul'; return 0 + def __mod__(self, i: int): + print 'mod'; return 0 + def __lt__(self, i: int): + print 'lt'; return 0 + def __le__(self, i: int): + print 'le'; return 0 + def __gt__(self, i: int): + print 'gt'; return 0 + def __ge__(self, i: int): + print 'ge'; return 0 + def __eq__(self, i: int): + print 'eq'; return 0 + def __ne__(self, i: int): + print 'ne'; return 0 + def __lshift__(self, i: int): + print 'lshift'; return 0 + def __rshift__(self, i: int): + print 'rshift'; return 0 + def __and__(self, i: int): + print 'and'; return 0 + def __or__(self, i: int): + print 'or'; return 0 + def __xor__(self, i: int): + print 'xor'; return 0 +# double assignment to disable propagation +def f(x): return x +a = f(1.0) +a = f(5.0) +a + f(1) #: add +# wrap in function to disable canonicalization +a - f(1) #: sub +a * f(2) #: mul +a ** f(2) #: pow +a // f(2) #: div +a / f(2) #: truediv +a @ f(1) #: matmul +a % f(1) #: mod +a < f(1) #: lt +a <= f(1) #: le +a > f(1) #: gt +a >= f(1) #: ge +a == f(1) #: eq +a != f(1) #: ne +a << f(1) #: lshift +a >> f(1) #: rshift +a & f(1) #: and +a | f(1) #: or +a ^ f(1) #: xor + +#%% binary_rmagic,barebones +class Foo[T]: + def __add__(self, other: T): + print 'add' + return self + def __radd__(self, other: T): + print 'radd' + return self +foo = Foo[int]() +foo + 1 #: add +1 + foo #: radd + +#%% binary_short_circuit,barebones +def moo(): + print 'moo' + return True +print True or moo() #: True +print moo() or True #: moo +#: True +print False and moo() #: False +print moo() and False #: moo +#: False + +#%% binary_is,barebones +print 5 is None #: False +print None is None #: True +print (None if bool(True) else 1) is None #: True +print (None if bool(False) else 1) is None #: False + +print 5 is 5.0 #: False +print 5 is 6 #: False +print 5 is 5 #: True +print 5 is 1.12 #: False +class Foo: + a: int +x = Foo(1) +y = Foo(1) +z = x +print x is x, x is y, x is z, z is x, z is y #: True False True True False + +a, b, c, d = Optional(5), Optional[int](), Optional(5), Optional(4) +print a is a, a is b, b is b, a is c, a is d #: True False True True False +aa, bb, cc, dd = Optional(Foo(1)), Optional[Foo](), Optional(Foo(1)), Optional(Foo(2)) +print aa is aa, aa is bb, bb is bb, aa is cc, aa is dd #: True False True False False + + +#%% pipe,barebones +def foo(a, b): + return a+b +bar = lambda c, d: c+d +def hai(e): + while e > 0: + yield e + e -= 2 +def echo(s): + print s +foo(1,2) |> bar(4) |> echo #: 7 +foo(1,2) |> bar(4) |> hai |> echo +#: 7 +#: 5 +#: 3 +#: 1 + +#%% pipe_prepend,barebones +def foo(a: Optional[int]): + print a + return 1 +5 |> foo #: 5 +None |> foo #: None +print (None |> foo).__class__.__name__ #: int + +def foo2(a: int): + print a + return 1 +Optional(5) |> foo2 #: 5 +try: + Optional[int]() |> foo2 +except ValueError as e: + print e.message #: optional unpack failed: expected int, got None + +#%% pipe_prepend_error,barebones +def foo2(a: int): + print a + return 1 +try: + None |> foo2 +except ValueError: + print 'exception' #: exception +# Explanation: None can also be Optional[Generator[int]] +# We cannot decide if this is a generator to be unrolled in a pipe, +# or just an argument to be passed to a function. +# So this will default to NoneType at the end. + +#%% instantiate_err,barebones +def foo[N](): + return N() +foo(int, float) #! foo() takes 1 arguments (2 given) + +#%% instantiate_err_2,barebones +def foo[N, T](): + return N() +foo(int) #! generic 'T' not provided + +#%% instantiate_err_3,barebones +Ptr[int, float]() #! Ptr takes 1 generics (2 given) + +#%% slice,barebones +z = [1, 2, 3, 4, 5] +y = (1, 'foo', True) +print z[2], y[1] #: 3 foo +print z[:1], z[1:], z[1:3], z[:4:2], z[::-1] #: [1] [2, 3, 4, 5] [2, 3] [1, 3] [5, 4, 3, 2, 1] + +#%% static_index,barebones +a = (1, '2s', 3.3) +print a[1] #: 2s +print a[0:2], a[:2], a[1:] #: (1, '2s') (1, '2s') ('2s', 3.3) +print a[0:3:2], a[-1:] #: (1, 3.3) (3.3,) + +#%% static_index_side,barebones +def foo(a): + print(a) + return a + +print (foo(2), foo(1))[::-1] +#: 2 +#: 1 +#: (1, 2) +print (foo(1), foo(2), foo(3), foo(4))[2] +#: 1 +#: 2 +#: 3 +#: 4 +#: 3 + +#%% static_index_lenient,barebones +a = (1, 2) +print a[3:5] #: () + +#%% static_index_err,barebones +a = (1, 2) +a[5] #! tuple index out of range (expected 0..1, got instead 5) + +#%% static_index_err_2,barebones +a = (1, 2) +a[-3] #! tuple index out of range (expected 0..1, got instead -1) + +#%% index_func_instantiate,barebones +class X: + def foo[T](self, x: T): + print x.__class__.__name__, x +x = X() +x.foo(5, int) #: int 5 + +#%% index,barebones +l = [1, 2, 3] +print l[2] #: 3 + +#%% index_two_rounds,barebones +l = [] +print l[::-1] #: [] +l.append(('str', 1, True, 5.15)) +print l, l.__class__.__name__ #: [('str', 1, True, 5.15)] List[Tuple[str,int,bool,float]] + +#%% nested_generic,barebones +x = Array[Array[int]](0) +f = Optional[Optional[Optional[int]]](Optional[Optional[int]](Optional[int](5))) +print x.len, f #: 0 5 + +#%% static,barebones +class Num[N_: Static[int]]: + def __str__(self): + return f'[{N_}]' + def __init__(self): + pass +def foo[N: Static[int]](): + print Num[N*2]() +foo(3) #: [6] + +class XX[N_: Static[int]]: + a: Num[N_*2] + def __init__(self): + self.a = Num() +y = XX[5]() +print y.a, y.__class__.__name__, y.a.__class__.__name__ #: [10] XX[5] Num[10] + +@tuple +class FooBar[N: Static[int]]: + x: Int[N] +z = FooBar(i32(5)) +print z, z.__class__.__name__, z.x.__class__.__name__ #: (x: Int[32](5)) FooBar[32] Int[32] + +@tuple +class Foo[N: Static[int]]: + x: Int[2*N] + def ctr(x: Int[2*N]) -> Foo[N]: + return Foo[N](x) +foo = Foo[10].ctr(Int[20](0)) +print foo.__class__.__name__, foo.x.__class__.__name__ #: Foo[10] Int[20] + +#%% static_2,barebones +class Num[N: Static[int]]: + def __str__(self): + return f'~{N}' + def __init__(self): + pass +class Foo[T, A: Static[int], B: Static[int]]: + a: Num[A+B] + b: Num[A-B] + c: Num[A if A > 3 else B] + t: T + def __init__(self): + self.a = Num() + self.b = Num() + self.c = Num() + self.t = T() + def __str__(self): + return f'<{self.a} {self.b} {self.c} :: {self.t}>' +print Foo[int, 3, 4](), Foo[int, 5, 4]() +#: <~7 ~-1 ~4 :: 0> <~9 ~1 ~5 :: 0> + +#%% static_int,barebones +def foo(n: Static[int]): + print n +@overload +def foo(n: Static[bool]): + print n + +a: Static[int] = 5 +foo(a < 1) #: False +foo(a <= 1) #: False +foo(a > 1) #: True +foo(a >= 1) #: True +foo(a == 1) #: False +foo(a != 1) #: True +foo(a and 1) #: True +foo(a or 1) #: True +foo(a + 1) #: 6 +foo(a - 1) #: 4 +foo(a * 1) #: 5 +foo(a // 2) #: 2 +foo(a % 2) #: 1 +foo(a & 2) #: 0 +foo(a | 2) #: 7 +foo(a ^ 1) #: 4 + +#%% static_str,barebones +class X: + s: Static[str] + i: Int[1 + (s == "abc")] + def __init__(self: X[s], s: Static[str]): + i = Int[1+(s=="abc")]() + print s, self.s, self.i.__class__.__name__ +def foo(x: Static[str], y: Static[str]): + print x+y +z: Static[str] = "woo" +foo("he", z) #: hewoo +X(s='lolo') #: lolo lolo Int[1] +X('abc') #: abc abc Int[2] + +def foo2(x: Static[str]): + print(x, x.__is_static__) +s: Static[str] = "abcdefghijkl" +foo2(s) #: abcdefghijkl True +foo2(s[1]) #: b True +foo2(s[1:5]) #: bcde True +foo2(s[10:50]) #: kl True +foo2(s[1:30:3]) #: behk True +foo2(s[::-1]) #: lkjihgfedcba True + +#%% static_short_circuit,barebones +x = 3.14 +if isinstance(x, List) and x.T is float: + print('is list') +else: + print('not list') #: not list + +#%% partial_star_pipe_args,barebones +iter(['A', 'C']) |> print +#: A +#: C +iter(range(4)) |> print('x', ..., 1) +#: x 0 1 +#: x 1 1 +#: x 2 1 +#: x 3 1 + +#%% partial_static_keep,barebones +def foo(x: Static[int]): + return lambda: str(x) +f = foo(5) +print foo(5)() #: 5 +print foo(8)() #: 8 + +def itemgetter(item: Static[int]): + return lambda o: o[item] +print itemgetter(1)([1, 2, 3]) #: 2 +print itemgetter(2)("abc") #: c diff --git a/test/parser/typecheck/test_parser.codon b/test/parser/typecheck/test_parser.codon new file mode 100644 index 00000000..38987bd5 --- /dev/null +++ b/test/parser/typecheck/test_parser.codon @@ -0,0 +1,29 @@ +#%% keyword_prefix,barebones +def foo(return_, pass_, yield_, break_, continue_, print_, assert_): + return_.append(1) + pass_.append(2) + yield_.append(3) + break_.append(4) + continue_.append(5) + print_.append(6) + assert_.append(7) + return return_, pass_, yield_, break_, continue_, print_, assert_ +print foo([1], [1], [1], [1], [1], [1], [1]) +#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7]) + +#%% spaces,barebones +def space_test(): + x = 0.77 + y = 0.86 + z = x/(1if((1if(y==0)else(y))==0)else((1if(y==0)else(y)))) + print(z) #: 0.895349 + + h = "hello" + b = ((True)or(False))or((("sR2Kt7"))==(h)) + print(b) #: True + + h2: Optional[str] = "hi" + h3 = "r" + b2 = (((h2)==None)and(h3)==("r")) + print(b2) #: False +space_test() \ No newline at end of file diff --git a/test/parser/typecheck/test_python.codon b/test/parser/typecheck/test_python.codon new file mode 100644 index 00000000..913f5837 --- /dev/null +++ b/test/parser/typecheck/test_python.codon @@ -0,0 +1,136 @@ +#%% python +from python import os +print os.name #: posix + +from python import datetime +z = datetime.datetime.utcfromtimestamp(0) +print z #: 1970-01-01 00:00:00 + +#%% python_numpy +from python import numpy as np +a = np.arange(9).reshape(3, 3) +print a +#: [[0 1 2] +#: [3 4 5] +#: [6 7 8]] +print a.dtype.name #: int64 +print np.transpose(a) +#: [[0 3 6] +#: [1 4 7] +#: [2 5 8]] +n = np.array([[1, 2], [3, 4]]) +print n[0], n[0][0] + 1 #: [1 2] 2 + +a = np.array([1,2,3]) +print(a + 1) #: [2 3 4] +print(a - 1) #: [0 1 2] +print(1 - a) #: [ 0 -1 -2] + +#%% python_import_fn +from python import re.split(str, str) -> List[str] as rs +print rs(r'\W+', 'Words, words, words.') #: ['Words', 'words', 'words', ''] + +#%% python_import_fn_2 +from python import os.system(str) -> int +system("echo 'hello!'") #: hello! + +#%% python_pydef +@python +def test_pydef(n) -> str: + return ''.join(map(str,range(n))) +print test_pydef(5) #: 01234 + +#%% python_pydef_nested +def foo(): + @python + def pyfoo(): + return 1 + print pyfoo() #: 1 + if True: + @python + def pyfoo2(): + return 2 + print pyfoo2() #: 2 + pass + @python + def pyfoo3(): + if 1: + return 3 + return str(pyfoo3()) +print foo() #: 3 + +#%% python_pyobj +@python +def foofn() -> Dict[pyobj, pyobj]: + return {"str": "hai", "int": 1} + +foo = foofn() +print(sorted(foo.items(), key=lambda x: str(x)), foo.__class__.__name__) +#: [('int', 1), ('str', 'hai')] Dict[pyobj,pyobj] +foo["codon"] = 5.15 +print(sorted(foo.items(), key=lambda x: str(x)), foo["codon"].__class__.__name__, foo.__class__.__name__) +#: [('codon', 5.15), ('int', 1), ('str', 'hai')] pyobj Dict[pyobj,pyobj] + +a = {1: "s", 2: "t"} +a[3] = foo["str"] +print(sorted(a.items())) #: [(1, 's'), (2, 't'), (3, 'hai')] + + +#%% python_isinstance +import python + +@python +def foo(): + return 1 + +z = foo() +print(z.__class__.__name__) #: pyobj + +print isinstance(z, pyobj) #: True +print isinstance(z, int) #: False +print isinstance(z, python.int) #: True +print isinstance(z, python.ValueError) #: False + +print isinstance(z, (int, str, python.int)) #: True +print isinstance(z, (int, str, python.AttributeError)) #: False + +try: + foo().x +except python.ValueError: + pass +except python.AttributeError as e: + print('caught', e, e.__class__.__name__) #: caught 'int' object has no attribute 'x' pyobj + + +#%% python_exceptions +import python + +@python +def foo(): + return 1 + +try: + foo().x +except python.AttributeError as f: + print 'py.Att', f #: py.Att 'int' object has no attribute 'x' +except ValueError: + print 'Val' +except PyError as e: + print 'PyError', e +try: + foo().x +except python.ValueError as f: + print 'py.Att', f +except ValueError: + print 'Val' +except PyError as e: + print 'PyError', e #: PyError 'int' object has no attribute 'x' +try: + raise ValueError("ho") +except python.ValueError as f: + print 'py.Att', f +except ValueError: + print 'Val' #: Val +except PyError as e: + print 'PyError', e + diff --git a/test/parser/typecheck/test_typecheck.codon b/test/parser/typecheck/test_typecheck.codon new file mode 100644 index 00000000..e27ee051 --- /dev/null +++ b/test/parser/typecheck/test_typecheck.codon @@ -0,0 +1,127 @@ +#%% pass,barebones +pass + +#%% print,barebones +print 1, +print 1, 2 #: 1 1 2 + +print 1, 2 #: 1 2 +print(3, "4", sep="-", end=" !\n") #: 3-4 ! + +print(1, 2) #: 1 2 +print (1, 2) #: (1, 2) + +def foo(i, j): + return i + j +print 3 |> foo(1) #: 4 + +#%% typeof_definitions,barebones +a = 10 +def foo(a)->type(a): return a +print(foo(5).__class__.__name__) #: int + +b: type(a) = 1 +print(b.__class__.__name__) #: int + +#%% multi_error,barebones +# TODO in new parser! +# a = 55 +# print z # name 'z' is not defined +# print(a, q, w) # name 'q' is not defined +# print quit # name 'quit' is not defined + +#%% static_unify,barebones +def foo(x: Callable[[1,2], 3]): pass #! Callable cannot take static types + +#%% static_unify_2,barebones +def foo(x: List[1]): pass #! expected type expression + +#%% expr,barebones +a = 5; b = 3 +print a, b #: 5 3 + +#%% delayed_instantiation_correct_context,barebones +# Test timing of the statements; ensure that delayed blocks still +# use correct names. +def foo(): + l = [] + + s = 1 # CH1 + if isinstance(l, List[int]): # delay typechecking this block + print(s) #: 1 + # if this is done badly, this print will print 's' + # or result in assertion error + print(s) #: 1 + + s = 's' # CH2 + print(s) #: s + + # instantiate l so that the block above + # is typechecked in the next iteration + l.append(1) +foo() + +# check that this does not mess up comprehensions +# (where variable names are used BEFORE their declaration) +slice_prefixes = [(start, end) + for start, end in [(1, 2), (3, 4)]] +print(slice_prefixes) #: [(1, 2), (3, 4)] + +def foo(): + # fn itself must be delayed and unbound for this to reproduce + fn = (lambda _: lambda x: x)(None) + + zizzer = 1 + y = fn(zizzer) + print(y) #: 1 + + zizzer = 's' + y = fn(zizzer) + print(y) #: s +foo() + +#%% do_not_resolve_default_generics_on_partial,barebones +def coerce(): + def foo(): pass + def bar(T1: type, I1: type = T1): + print(T1 is I1) #: False + foo() + bar(int, I1=Int[64]) # creates bar=bar(foo,...) first +coerce() + +#%% compile_error_realization,barebones +def ctx(): + def foo(): compile_error("bah!") + def bar(err: Static[bool]): + if err: foo() + else: print("ok") + bar(False) +ctx() #: ok + +#%% ctx_time_resolver,barebones +def bar(j): + # captures stdlib range, not foo's range + for i in range(*j): print(i) +def foo(range): + bar(range) +foo((1, 2)) #: 1 + +# Test whether for loop variables respect ctx->add() time +slopes = [1.0,2,3] if len("abc") <= 5 else None +if slopes is not None: + for ixx in range(3): + slopes[ixx] = ixx +for ixx in range(5): + ixx + +#%% capture_function_partial_proper_realize,barebones +def concatenate(arrays, axis = 0, out = None, dtype: type = NoneType): + def concat_inner(arrays, axis, out, dtype: type): + return 1 + + def concat_tuple(arrays, axis = 0, out = None, dtype: type = NoneType): + return concat_inner(arrays, axis, out=None, dtype=dtype) + + return 1 + +print concatenate((1, 2)) #: 1 diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index 78ba102a..e69de29b 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -1,881 +0,0 @@ -#%% bool,barebones -a = True -print a.__class__.__name__ #: bool - -#%% int,barebones -i = 15 -print i.__class__.__name__ #: int - -#%% float,barebones -a = 1.11 -print a.__class__.__name__ #: float - -#%% str,barebones -a = 'hi' -print a.__class__.__name__ #: str - -#%% none_unbound,barebones -a = None - -#%% list_unbound,barebones -a = [] -#! cannot typecheck the program - - -#%% id_static,barebones -def foo[N: Static[int]](): - print N -foo(5) #: 5 - -def fox(N: Static[int]): - print N -fox(6) #: 6 - -#%% if,barebones -y = 1 if True else 2 -print y.__class__.__name__ #: int - -a = None -b = 5 -z = a if bool(True) else b # needs bool to prevent static evaluation -print z, z.__class__.__name__ #: None Optional[int] - -zz = 1.11 if True else None -print zz, zz.__class__.__name__ #: 1.11 float - -#%% binary,barebones -@extend -class float: - def __add__(self, i: int): print 'add'; return 0 - def __sub__(self, i: int): print 'sub'; return 0 - def __mul__(self, i: int): print 'mul'; return 0 - def __pow__(self, i: int): print 'pow'; return 0 - def __truediv__(self, i: int): print 'truediv'; return 0 - def __floordiv__(self, i: int): print 'div'; return 0 - def __matmul__(self, i: int): print 'matmul'; return 0 - def __mod__(self, i: int): print 'mod'; return 0 - def __lt__(self, i: int): print 'lt'; return 0 - def __le__(self, i: int): print 'le'; return 0 - def __gt__(self, i: int): print 'gt'; return 0 - def __ge__(self, i: int): print 'ge'; return 0 - def __eq__(self, i: int): print 'eq'; return 0 - def __ne__(self, i: int): print 'ne'; return 0 - def __lshift__(self, i: int): print 'lshift'; return 0 - def __rshift__(self, i: int): print 'rshift'; return 0 - def __and__(self, i: int): print 'and'; return 0 - def __or__(self, i: int): print 'or'; return 0 - def __xor__(self, i: int): print 'xor'; return 0 -# double assignment to disable propagation -def f(x): return x -a = f(1.0) -a = f(5.0) -a + f(1) #: add -# wrap in function to disable canonicalization -a - f(1) #: sub -a * f(2) #: mul -a ** f(2) #: pow -a // f(2) #: div -a / f(2) #: truediv -a @ f(1) #: matmul -a % f(1) #: mod -a < f(1) #: lt -a <= f(1) #: le -a > f(1) #: gt -a >= f(1) #: ge -a == f(1) #: eq -a != f(1) #: ne -a << f(1) #: lshift -a >> f(1) #: rshift -a & f(1) #: and -a | f(1) #: or -a ^ f(1) #: xor - -#%% binary_rmagic,barebones -class Foo[T]: - def __add__(self, other: T): - print 'add' - return self - def __radd__(self, other: T): - print 'radd' - return self -foo = Foo[int]() -foo + 1 #: add -1 + foo #: radd - -#%% binary_short_circuit,barebones -def moo(): - print 'moo' - return True -print True or moo() #: True -print moo() or True #: moo -#: True -print False and moo() #: False -print moo() and False #: moo -#: False - -#%% binary_is,barebones -print 5 is None #: False -print None is None #: True -print (None if bool(True) else 1) is None #: True -print (None if bool(False) else 1) is None #: False - -print 5 is 5.0 #: False -print 5 is 6 #: False -print 5 is 5 #: True -print 5 is 1.12 #: False -class Foo: - a: int -x = Foo(1) -y = Foo(1) -z = x -print x is x, x is y, x is z, z is x, z is y #: True False True True False - -a, b, c, d = Optional(5), Optional[int](), Optional(5), Optional(4) -print a is a, a is b, b is b, a is c, a is d #: True False True True False -aa, bb, cc, dd = Optional(Foo(1)), Optional[Foo](), Optional(Foo(1)), Optional(Foo(2)) -print aa is aa, aa is bb, bb is bb, aa is cc, aa is dd #: True False True False False - - -#%% pipe,barebones -def foo(a, b): - return a+b -bar = lambda c, d: c+d -def hai(e): - while e > 0: - yield e - e -= 2 -def echo(s): - print s -foo(1,2) |> bar(4) |> echo #: 7 -foo(1,2) |> bar(4) |> hai |> echo -#: 7 -#: 5 -#: 3 -#: 1 - -#%% pipe_prepend,barebones -def foo(a: Optional[int]): - print a - return 1 -5 |> foo #: 5 -None |> foo #: None -print (None |> foo).__class__.__name__ #: int - -def foo2(a: int): - print a - return 1 -Optional(5) |> foo2 #: 5 -try: - Optional[int]() |> foo2 -except ValueError as e: - print e.message #: optional is None - -#%% pipe_prepend_error,barebones -def foo2(a: int): - print a - return 1 -try: - None |> foo2 -except ValueError: - print 'exception' #: exception -# Explanation: None can also be Optional[Generator[int]] -# We cannot decide if this is a generator to be unrolled in a pipe, -# or just an argument to be passed to a function. -# So this will default to NoneType at the end. - -#%% instantiate_err,barebones -def foo[N](): - return N() -foo(int, float) #! foo() takes 1 arguments (2 given) - -#%% instantiate_err_2,barebones -def foo[N, T](): - return N() -foo(int) #! generic 'T' not provided - -#%% instantiate_err_3,barebones -Ptr[int, float]() #! Ptr takes 1 generics (2 given) - -#%% slice,barebones -z = [1, 2, 3, 4, 5] -y = (1, 'foo', True) -print z[2], y[1] #: 3 foo -print z[:1], z[1:], z[1:3], z[:4:2], z[::-1] #: [1] [2, 3, 4, 5] [2, 3] [1, 3] [5, 4, 3, 2, 1] - -#%% static_index,barebones -a = (1, '2s', 3.3) -print a[1] #: 2s -print a[0:2], a[:2], a[1:] #: (1, '2s') (1, '2s') ('2s', 3.3) -print a[0:3:2], a[-1:] #: (1, 3.3) (3.3,) - -#%% static_index_side,barebones -def foo(a): - print(a) - return a - -print (foo(2), foo(1))[::-1] -#: 2 -#: 1 -#: (1, 2) -print (foo(1), foo(2), foo(3), foo(4))[2] -#: 1 -#: 2 -#: 3 -#: 4 -#: 3 - -#%% static_index_lenient,barebones -a = (1, 2) -print a[3:5] #: () - -#%% static_index_err,barebones -a = (1, 2) -a[5] #! tuple index out of range (expected 0..1, got instead 5) - -#%% static_index_err_2,barebones -a = (1, 2) -a[-3] #! tuple index out of range (expected 0..1, got instead -1) - -#%% index_func_instantiate,barebones -class X: - def foo[T](self, x: T): - print x.__class__.__name__, x -x = X() -x.foo(5, int) #: int 5 - -#%% index,barebones -l = [1, 2, 3] -print l[2] #: 3 - -#%% index_two_rounds,barebones -l = [] -print l[::-1] #: [] -l.append(('str', 1, True, 5.15)) -print l, l.__class__.__name__ #: [('str', 1, True, 5.15)] List[Tuple[str,int,bool,float]] - -#%% dot_case_1,barebones -a = [] -print a[0].loop() #! 'int' object has no attribute 'loop' -a.append(5) - -#%% dot_case_2,barebones -a = Optional(0) -print a.__bool__() #: False -print a.__add__(1) #: 1 - -#%% dot_case_4,barebones -a = [5] -print a.len #: 1 - -#%% dot_case_4_err,barebones -a = [5] -a.foo #! 'List[int]' object has no attribute 'foo' - -#%% dot_case_6,barebones -# Did heavy changes to this testcase because -# of the automatic optional wraps/unwraps and promotions -class Foo: - def bar(self, a): - print 'generic', a, a.__class__.__name__ - def bar(self, a: Optional[float]): - print 'optional', a - def bar(self, a: int): - print 'normal', a -f = Foo() -f.bar(1) #: normal 1 -f.bar(1.1) #: optional 1.1 -f.bar(Optional('s')) #: generic s Optional[str] -f.bar('hehe') #: generic hehe str - - -#%% dot_case_6b,barebones -class Foo: - def bar(self, a, b): - print '1', a, b - def bar(self, a, b: str): - print '2', a, b - def bar(self, a: str, b): - print '3', a, b -f = Foo() -# Take the newest highest scoring method -f.bar('s', 't') #: 3 s t -f.bar(1, 't') #: 2 1 t -f.bar('s', 1) #: 3 s 1 -f.bar(1, 2) #: 1 1 2 - -#%% dot,barebones -class Foo: - def clsmethod(): - print 'foo' - def method(self, a): - print a -Foo.clsmethod() #: foo -Foo.method(Foo(), 1) #: 1 -m1 = Foo.method -m1(Foo(), 's') #: s -m2 = Foo().method -m2(1.1) #: 1.1 - -#%% dot_error_static,barebones -class Foo: - def clsmethod(): - print 'foo' - def method(self, a): - print a -Foo().clsmethod() #! 'Foo' object has no method 'clsmethod' with arguments (Foo) - -#%% call,barebones -def foo(a, b, c='hi'): - print 'foo', a, b, c - return 1 -class Foo: - def __init__(self): - print 'Foo.__init__' - def foo(self, a): - print 'Foo.foo', a - return 's' - def bar[T](self, a: T): - print 'Foo.bar', a - return a.__class__.__name__ - def __call__(self, y): - print 'Foo.__call__' - return foo(2, y) - -foo(1, 2.2, True) #: foo 1 2.2 True -foo(1, 2.2) #: foo 1 2.2 hi -foo(b=2.2, a=1) #: foo 1 2.2 hi -foo(b=2.2, c=12u, a=1) #: foo 1 2.2 12 - -f = Foo() #: Foo.__init__ -print f.foo(a=5) #: Foo.foo 5 -#: s -print f.bar(a=1, T=int) #: Foo.bar 1 -#: int -print Foo.bar(Foo(), 1.1, T=float) #: Foo.__init__ -#: Foo.bar 1.1 -#: float -print Foo.bar(Foo(), 's') #: Foo.__init__ -#: Foo.bar s -#: str -print f('hahaha') #: Foo.__call__ -#: foo 2 hahaha hi -#: 1 - -@tuple -class Moo: - moo: int - def __new__(i: int) -> Moo: - print 'Moo.__new__' - return (i,) -print Moo(1) #: Moo.__new__ -#: (moo: 1) - -#%% call_err_2,barebones -class A: - a: A -a = A() #! argument 'a' has recursive default value - -#%% call_err_3,barebones -class G[T]: - t: T -class A: - ga: G[A] -a = A() #! argument 'ga' has recursive default value - -#%% call_err_4,barebones -seq_print_full(1, name="56", name=2) #! keyword argument repeated: name - -#%% call_partial,barebones -def foo(i, j, k): - return i + j + k -print foo(1.1, 2.2, 3.3) #: 6.6 -p = foo(6, ...) -print p.__class__.__name__ #: foo[int,...,...] -print p(2, 1) #: 9 -print p(k=3, j=6) #: 15 -q = p(k=1, ...) -print q(3) #: 10 -qq = q(2, ...) -print qq() #: 9 -# -add_two = foo(3, k=-1, ...) -print add_two(42) #: 44 -print 3 |> foo(1, 2) #: 6 -print 42 |> add_two #: 44 -# -def moo(a, b, c=3): - print a, b, c -m = moo(b=2, ...) -print m.__class__.__name__ #: moo[...,int,...] -m('s', 1.1) #: s 2 1.1 -# # -n = m(c=2.2, ...) -print n.__class__.__name__ #: moo[...,int,float] -n('x') #: x 2 2.2 -print n('y').__class__.__name__ #: NoneType - -def ff(a, b, c): - return a, b, c -print ff(1.1, 2, True).__class__.__name__ #: Tuple[float,int,bool] -print ff(1.1, ...)(2, True).__class__.__name__ #: Tuple[float,int,bool] -y = ff(1.1, ...)(c=True, ...) -print y.__class__.__name__ #: ff[float,...,bool] -print ff(1.1, ...)(2, ...)(True).__class__.__name__ #: Tuple[float,int,bool] -print y('hei').__class__.__name__ #: Tuple[float,str,bool] -z = ff(1.1, ...)(c='s', ...) -print z.__class__.__name__ #: ff[float,...,str] - -def fx(*args, **kw): - print(args, kw) -f1 = fx(1, x=1, ...) -f2 = f1(2, y=2, ...) -f3 = f2(3, z=3, ...) -f3() -#: (1, 2, 3) (x: 1, y: 2, z: 3) - -#%% call_arguments_partial,barebones -def doo[R, T](a: Callable[[T], R], b: Generator[T], c: Optional[T], d: T): - print R.__class__.__name__, T.__class__.__name__ - print a.__class__.__name__[:8], b.__class__.__name__ - for i in b: - print a(i) - print c, c.__class__.__name__ - print d, d.__class__.__name__ - -l = [1, 2, 3] -doo(b=l, d=Optional(5), c=l[0], a=lambda x: x+1) -#: int int -#: ._lambda Generator[int] -#: 2 -#: 3 -#: 4 -#: 1 Optional[int] -#: 5 int - -l = [1] -def adder(a, b): return a+b -doo(b=l, d=Optional(5), c=l[0], a=adder(b=4, ...)) -#: int int -#: adder[.. Generator[int] -#: 5 -#: 1 Optional[int] -#: 5 int - -#%% call_partial_star,barebones -def foo(x, *args, **kwargs): - print x, args, kwargs -p = foo(...) -p(1, z=5) #: 1 () (z: 5) -p('s', zh=65) #: s () (zh: 65) -q = p(zh=43, ...) -q(1) #: 1 () (zh: 43) -r = q(5, 38, ...) -r() #: 5 (38,) (zh: 43) -r(1, a=1) #: 5 (38, 1) (zh: 43, a: 1) - -#%% call_args_kwargs_type,barebones -def foo(*args: float, **kwargs: int): - print(args, kwargs, args.__class__.__name__) - -foo(1, f=1) #: (1,) (f: 1) Tuple[float] -foo(1, 2.1, 3, z=2) #: (1, 2.1, 3) (z: 2) Tuple[float,float,float] - -def sum(x: Generator[int]): - a = 0 - for i in x: - a += i - return a - -def sum_gens(*x: Generator[int]) -> int: - a = 0 - for i in x: - a += sum(i) - return a -print sum_gens([1, 2, 3]) #: 6 -print sum_gens({1, 2, 3}) #: 6 -print sum_gens(iter([1, 2, 3])) #: 6 - -#%% call_kwargs,barebones -def kwhatever(**kwargs): - print 'k', kwargs -def whatever(*args): - print 'a', args -def foo(a, b, c=1, *args, **kwargs): - print a, b, c, args, kwargs - whatever(a, b, *args, c) - kwhatever(x=1, **kwargs) -foo(1, 2, 3, 4, 5, arg1='s', kwa=2) -#: 1 2 3 (4, 5) (arg1: 's', kwa: 2) -#: a (1, 2, 4, 5, 3) -#: k (arg1: 's', kwa: 2, x: 1) -foo(1, 2) -#: 1 2 1 () () -#: a (1, 2, 1) -#: k (x: 1) -foo(1, 2, 3) -#: 1 2 3 () () -#: a (1, 2, 3) -#: k (x: 1) -foo(1, 2, 3, 4) -#: 1 2 3 (4,) () -#: a (1, 2, 4, 3) -#: k (x: 1) -foo(1, 2, zamboni=3) -#: 1 2 1 () (zamboni: 3) -#: a (1, 2, 1) -#: k (x: 1, zamboni: 3) - -#%% call_unpack,barebones -def foo(*args, **kwargs): - print args, kwargs - -@tuple -class Foo: - x: int = 5 - y: bool = True - -t = (1, 's') -f = Foo(6) -foo(*t, **f) #: (1, 's') (x: 6, y: True) -foo(*(1,2)) #: (1, 2) () -foo(3, f) #: (3, (x: 6, y: True)) () -foo(k = 3, **f) #: () (k: 3, x: 6, y: True) - -#%% call_partial_args_kwargs,barebones -def foo(*args): - print(args) -a = foo(1, 2, ...) -b = a(3, 4, ...) -c = b(5, ...) -c('zooooo') -#: (1, 2, 3, 4, 5, 'zooooo') - -def fox(*args, **kwargs): - print(args, kwargs) -xa = fox(1, 2, x=5, ...) -xb = xa(3, 4, q=6, ...) -xc = xb(5, ...) -xd = xc(z=5.1, ...) -xd('zooooo', w='lele') -#: (1, 2, 3, 4, 5, 'zooooo') (x: 5, q: 6, z: 5.1, w: 'lele') - -class Foo: - i: int - def __str__(self): - return f'#{self.i}' - def foo(self, a): - return f'{self}:generic' - def foo(self, a: float): - return f'{self}:float' - def foo(self, a: int): - return f'{self}:int' -f = Foo(4) - -def pacman(x, f): - print f(x, '5') - print f(x, 2.1) - print f(x, 4) -pacman(f, Foo.foo) -#: #4:generic -#: #4:float -#: #4:int - -def macman(f): - print f('5') - print f(2.1) - print f(4) -macman(f.foo) -#: #4:generic -#: #4:float -#: #4:int - -class Fox: - i: int - def __str__(self): - return f'#{self.i}' - def foo(self, a, b): - return f'{self}:generic b={b}' - def foo(self, a: float, c): - return f'{self}:float, c={c}' - def foo(self, a: int): - return f'{self}:int' - def foo(self, a: int, z, q): - return f'{self}:int z={z} q={q}' -ff = Fox(5) -def maxman(f): - print f('5', b=1) - print f(2.1, 3) - print f(4) - print f(5, 1, q=3) -maxman(ff.foo) -#: #5:generic b=1 -#: #5:float, c=3 -#: #5:int -#: #5:int z=1 q=3 - - -#%% call_static,barebones -print isinstance(1, int), isinstance(2.2, float), isinstance(3, bool) -#: True True False -print isinstance((1, 2), Tuple), isinstance((1, 2), Tuple[int, int]), isinstance((1, 2), Tuple[float, int]) -#: True True False -print isinstance([1, 2], List), isinstance([1, 2], List[int]), isinstance([1, 2], List[float]) -#: True True False -print isinstance({1, 2}, List), isinstance({1, 2}, Set[float]) -#: False False -print isinstance(Optional(5), Optional[int]), isinstance(Optional(), Optional) -#: True True -print isinstance(Optional(), Optional[int]), isinstance(Optional('s'), Optional[int]) -#: False False -print isinstance(None, Optional), isinstance(None, Optional[int]) -#: True False -print isinstance(None, Optional[NoneType]) -#: True -print isinstance({1, 2}, List) -#: False - -print staticlen((1, 2, 3)), staticlen((1, )), staticlen('hehe') -#: 3 1 2 - -print hasattr([1, 2], "__getitem__") -#: True -print hasattr(type([1, 2]), "__getitem__") -#: True -print hasattr(int, "__getitem__") -#: False -print hasattr([1, 2], "__getitem__", str) -#: False - -#%% isinstance_inheritance,barebones -class AX[T]: - a: T - def __init__(self, a: T): - self.a = a -class Side: - def __init__(self): - pass -class BX[T,U](Static[AX[T]], Static[Side]): - b: U - def __init__(self, a: T, b: U): - super().__init__(a) - self.b = b -class CX[T,U](Static[BX[T,U]]): - c: int - def __init__(self, a: T, b: U): - super().__init__(a, b) - self.c = 1 -c = CX('a', False) -print isinstance(c, CX), isinstance(c, BX), isinstance(c, AX), isinstance(c, Side) -#: True True True True -print isinstance(c, BX[str, bool]), isinstance(c, BX[str, str]), isinstance(c, AX[int]) -#: True False False - -#%% staticlen_err,barebones -print staticlen([1, 2]) #! expected tuple type - -#%% compile_error,barebones -compile_error("woo-hoo") #! woo-hoo - -#%% stack_alloc,barebones -a = __array__[int](2) -print a.__class__.__name__ #: Array[int] - -#%% typeof,barebones -a = 5 -z = [] -z.append(6) -print z.__class__.__name__, z, type(1.1).__class__.__name__ #: List[int] [6] float - -#%% ptr,barebones -v = 5 -c = __ptr__(v) -print c.__class__.__name__ #: Ptr[int] - -#%% yieldexpr,barebones -def mysum(start): - m = start - while True: - a = (yield) - print a.__class__.__name__ #: int - if a == -1: - break - m += a - yield m -iadder = mysum(0) -next(iadder) -for i in range(10): - iadder.send(i) -#: int -#: int -#: int -#: int -#: int -#: int -#: int -#: int -#: int -#: int -print iadder.send(-1) #: 45 - -#%% function_typecheck_level,barebones -def foo(x): - def bar(z): # bar has a parent foo(), however its unbounds must not be generalized! - print z - bar(x) - bar('x') -foo(1) -#: 1 -#: x -foo('s') -#: s -#: x - -#%% tuple_generator,barebones -a = (1, 2) -b = ('f', 'g') -print a, b #: (1, 2) ('f', 'g') -c = (*a, True, *b) -print c #: (1, 2, True, 'f', 'g') -print a + b + c #: (1, 2, 'f', 'g', 1, 2, True, 'f', 'g') -print () + (1, ) + ('a', 'b') #: (1, 'a', 'b') - -t = tuple(i+1 for i in (1,2,3)) -print t #: (2, 3, 4) -print tuple((j, i) for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))) -#: (('a', 1), ('b', 2), ('c', 3)) - -#%% tuple_fn,barebones -@tuple -class unpackable_plain: - a: int - b: str - -u = unpackable_plain(1, 'str') -a, b = tuple(u) -print a, b #: 1 str - -@tuple -class unpackable_gen: - a: int - b: T - T: type - -u2 = unpackable_gen(1, 'str') -a2, b2 = tuple(u2) -print a2,b2 #: 1 str - -class plain: - a: int - b: str - -c = plain(3, 'heh') -z = tuple(c) -print z, z.__class__.__name__ #: (3, 'heh') Tuple[int,str] - -#%% static_unify,barebones -def foo(x: Callable[[1,2], 3]): pass #! '2' does not match expected type 'T1' - -#%% static_unify_2,barebones -def foo(x: List[1]): pass #! '1' does not match expected type 'T' - -#%% super,barebones -class A[T]: - a: T - def __init__(self, t: T): - self.a = t - def foo(self): - return f'A:{self.a}' -class B(Static[A[str]]): - b: int - def __init__(self): - super().__init__('s') - self.b = 6 - def baz(self): - return f'{super().foo()}::{self.b}' -b = B() -print b.foo() #: A:s -print b.baz() #: A:s::6 - -class AX[T]: - a: T - def __init__(self, a: T): - self.a = a - def foo(self): - return f'[AX:{self.a}]' -class BX[T,U](Static[AX[T]]): - b: U - def __init__(self, a: T, b: U): - print super().__class__.__name__ - super().__init__(a) - self.b = b - def foo(self): - return f'[BX:{super().foo()}:{self.b}]' -class CX[T,U](Static[BX[T,U]]): - c: int - def __init__(self, a: T, b: U): - print super().__class__.__name__ - super().__init__(a, b) - self.c = 1 - def foo(self): - return f'CX:{super().foo()}:{self.c}' -c = CX('a', False) -print c.__class__.__name__, c.foo() -#: BX[str,bool] -#: AX[str] -#: CX[str,bool] CX:[BX:[AX:a]:False]:1 - -#%% super_vtable_2 -class Base: - def test(self): - print('base.test') -class A(Base): - def test(self): - super().test() - Base.test(self) - print('a.test') -a = A() -a.test() -def moo(x: Base): - x.test() -moo(a) -Base.test(a) -#: base.test -#: base.test -#: a.test -#: base.test -#: base.test -#: a.test -#: base.test - -#%% super_tuple,barebones -@tuple -class A[T]: - a: T - x: int - def __new__(a: T) -> A[T]: - return (a, 1) - def foo(self): - return f'A:{self.a}' -@tuple -class B(Static[A[str]]): - b: int - def __new__() -> B: - return (*(A('s')), 6) - def baz(self): - return f'{super().foo()}::{self.b}' - -b = B() -print b.foo() #: A:s -print b.baz() #: A:s::6 - - -#%% super_error,barebones -class A: - def __init__(self): - super().__init__() -a = A() -#! no super methods found -#! during the realization of __init__(self: A) - -#%% super_error_2,barebones -super().foo(1) #! no super methods found diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index dfa56741..e69de29b 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -1,393 +0,0 @@ -#%% expr,barebones -a = 5; b = 3 -print a, b #: 5 3 - -#%% assign_optional,barebones -a = None -print a #: None -a = 5 -print a #: 5 - -b: Optional[float] = Optional[float](6.5) -c: Optional[float] = 5.5 -print b, c #: 6.5 5.5 - -#%% assign_type_alias,barebones -I = int -print I(5) #: 5 - -L = Dict[int, str] -l = L() -print l #: {} -l[5] = 'haha' -print l #: {5: 'haha'} - -#%% assign_type_annotation,barebones -a: List[int] = [] -print a #: [] - -#%% assign_type_err,barebones -a = 5 -if 1: - a = 3.3 #! 'float' does not match expected type 'int' -a - -#%% assign_atomic,barebones -i = 1 -f = 1.1 - -@llvm -def xchg(d: Ptr[int], b: int) -> None: - %tmp = atomicrmw xchg i64* %d, i64 %b seq_cst - ret {} {} -@llvm -def aadd(d: Ptr[int], b: int) -> int: - %tmp = atomicrmw add i64* %d, i64 %b seq_cst - ret i64 %tmp -@llvm -def amin(d: Ptr[int], b: int) -> int: - %tmp = atomicrmw min i64* %d, i64 %b seq_cst - ret i64 %tmp -@llvm -def amax(d: Ptr[int], b: int) -> int: - %tmp = atomicrmw max i64* %d, i64 %b seq_cst - ret i64 %tmp -def min(a, b): return a if a < b else b -def max(a, b): return a if a > b else b - -@extend -class int: - def __atomic_xchg__(self: Ptr[int], i: int): - print 'atomic:', self[0], '<-', i - xchg(self, i) - def __atomic_add__(self: Ptr[int], i: int): - print 'atomic:', self[0], '+=', i - return aadd(self, i) - def __atomic_min__(self: Ptr[int], b: int): - print 'atomic:', self[0], '?=', b - return amax(self, b) - -@atomic -def foo(x): - global i, f - - i += 1 #: atomic: 1 += 1 - print i #: 2 - i //= 2 #: atomic: 2 <- 1 - print i #: 1 - i = 3 #: atomic: 1 <- 3 - print i #: 3 - i = min(i, 10) #: atomic: 3 ?= 10 - print i #: 10 - i = max(20, i) #: atomic: 10 <- 20 - print i #: 20 - - f += 1.1 - f = 3.3 - f = max(f, 5.5) -foo(1) -print i, f #: 20 5.5 - -#%% assign_atomic_real -i = 1 -f = 1.1 -@atomic -def foo(x): - global i, f - - i += 1 - print i #: 2 - i //= 2 - print i #: 1 - i = 3 - print i #: 3 - i = min(i, 10) - print i #: 3 - i = max(i, 10) - print i #: 10 - - f += 1.1 - f = 3.3 - f = max(f, 5.5) -foo(1) -print i, f #: 10 5.5 - -#%% assign_member,barebones -class Foo: - x: Optional[int] -f = Foo() -print f.x #: None -f.x = 5 -print f.x #: 5 - -fo = Optional(Foo()) -fo.x = 6 -print fo.x #: 6 - -#%% assign_member_err_1,barebones -class Foo: - x: Optional[int] -Foo().y = 5 #! 'Foo' object has no attribute 'y' - -#%% assign_member_err_2,barebones -@tuple -class Foo: - x: Optional[int] -Foo().x = 5 #! cannot modify tuple attributes - -#%% return,barebones -def foo(): - return 1 -print foo() #: 1 - -def bar(): - print 2 - return - print 1 -bar() #: 2 - -#%% yield,barebones -def foo(): - yield 1 -print [i for i in foo()], str(foo())[:16] #: [1] 0: - return - else: - return 1 -foo(1) -#! 'NoneType' does not match expected type 'int' -#! during the realization of foo(n: int) - -#%% return_none_err_2,barebones -def foo(n: int): - if n > 0: - return 1 - return -foo(1) -#! 'int' does not match expected type 'NoneType' -#! during the realization of foo(n: int) - -#%% while,barebones -a = 3 -while a: - print a - a -= 1 -#: 3 -#: 2 -#: 1 - -#%% for_break_continue,barebones -for i in range(10): - if i % 2 == 0: - continue - print i - if i >= 5: - break -#: 1 -#: 3 -#: 5 - -#%% for_error,barebones -for i in 1: - pass -#! 'int' object has no attribute '__iter__' - -#%% for_void,barebones -def foo(): yield -for i in foo(): - print i.__class__.__name__ #: NoneType - -#%% if,barebones -for a, b in [(1, 2), (3, 3), (5, 4)]: - if a > b: - print '1', - elif a == b: - print '=', - else: - print '2', -print '_' #: 2 = 1 _ - -if 1: - print '1' #: 1 - -#%% static_if,barebones -def foo(x, N: Static[int]): - if isinstance(x, int): - return x + 1 - elif isinstance(x, float): - return x.__pow__(.5) - elif isinstance(x, Tuple[int, str]): - return f'foo: {x[1]}' - elif isinstance(x, Tuple) and (N >= 3 or staticlen(x) > 2): - return x[2:] - elif hasattr(x, '__len__'): - return 'len ' + str(x.__len__()) - else: - compile_error('invalid type') -print foo(N=1, x=1) #: 2 -print foo(N=1, x=2.0) #: 1.41421 -print foo(N=1, x=(1, 'bar')) #: foo: bar -print foo(N=1, x=(1, 2)) #: len 2 -print foo(N=3, x=(1, 2)) #: () -print foo(N=1, x=(1, 2, 3)) #: (3,) - -#%% try_throw,barebones -class MyError(Static[Exception]): - def __init__(self, message: str): - super().__init__('MyError', message) -try: - raise MyError("hello!") -except MyError as e: - print str(e) #: hello! -try: - raise OSError("hello os!") -# TODO: except (MyError, OSError) as e: -# print str(e) -except MyError: - print "my" -except OSError as o: - print "os", o.typename, len(o.message), o.file[-20:], o.line - #: os OSError 9 typecheck_stmt.codon 284 -finally: - print "whoa" #: whoa - -# Test function name -def foo(): - raise MyError("foo!") -try: - foo() -except MyError as e: - print e.typename, e.message #: MyError foo! - -#%% throw_error,barebones -raise 'hello' -#! exceptions must derive from BaseException - -#%% function_builtin_error,barebones -@__force__ -def foo(x): - pass -#! builtin, exported and external functions cannot be generic - -#%% extend,barebones -@extend -class int: - def run_lola_run(self): - while self > 0: - yield self - self -= 1 -print list((5).run_lola_run()) #: [5, 4, 3, 2, 1] - - -#%% early_return,barebones -def foo(x): - print x-1 - return - print len(x) -foo(5) #: 4 - -def foo2(x): - if isinstance(x, int): - print x+1 - return - print len(x) -foo2(1) #: 2 -foo2('s') #: 1 - -#%% superf,barebones -class Foo: - def foo(a): - # superf(a) - print 'foo-1', a - def foo(a: int): - superf(a) - print 'foo-2', a - def foo(a: str): - superf(a) - print 'foo-3', a - def foo(a): - superf(a) - print 'foo-4', a -Foo.foo(1) -#: foo-1 1 -#: foo-2 1 -#: foo-4 1 - -class Bear: - def woof(x): - return f'bear woof {x}' -@extend -class Bear: - def woof(x): - return superf(x) + f' bear w--f {x}' -print Bear.woof('!') -#: bear woof ! bear w--f ! - -class PolarBear(Static[Bear]): - def woof(): - return 'polar ' + superf('@') -print PolarBear.woof() -#: polar bear woof @ bear w--f @ - -#%% superf_error,barebones -class Foo: - def foo(a): - superf(a) - print 'foo-1', a -Foo.foo(1) -#! no superf methods found -#! during the realization of foo(a: int) - -#%% staticmethod,barebones -class Foo: - def __repr__(self): - return 'Foo' - def m(self): - print 'm', self - @staticmethod - def sm(i): - print 'sm', i -Foo.sm(1) #: sm 1 -Foo().sm(2) #: sm 2 -Foo().m() #: m Foo diff --git a/test/parser/types.codon b/test/parser/types.codon index cf645622..e69de29b 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -1,2102 +0,0 @@ -#%% basic,barebones -a = 5 -b: float = 6.16 -c: optional[str] = None -print a, b, c #: 5 6.16 None - -#%% late_unify,barebones -a = [] -a.append(1) -print a #: [1] -print [1]+[1] #: [1, 1] - -#%% late_unify_2,barebones -class XX[T]: - y: T -a = XX() -def f(i: int) -> int: - return i -print a.y.__class__.__name__ #: int -f(a.y) -print a.__class__.__name__ #: XX[int] -print XX[bool].__class__.__name__ #: XX[bool] - -#%% nested_generic,barebones -x = Array[Array[int]](0) -f = Optional[Optional[Optional[int]]](Optional[Optional[int]](Optional[int](5))) -print x.len, f #: 0 5 - -#%% map_unify -def map[T,S](l: List[T], f: Callable[[T], S]): - return [f(x) for x in l] -e = 1 -print map([1, 2, 3], lambda x: x+e) #: [2, 3, 4] - -def map2(l, f): - return [f(x) for x in l] -print map2([1, 2, 3], lambda x: x+e) #: [2, 3, 4] - -#%% nested,barebones -def m4[TD](a: int, d: TD): - def m5[TD,TE](a: int, d: TD, e: TE): - print a, d, e - m5(a, d, 1.12) -m4(1, 's') #: 1 s 1.12 -m4(1, True) #: 1 True 1.12 - -#%% nested_class,barebones -class A[TA]: - a: TA - # lots of nesting: - def m4[TD](self: A[TA], d: TD): - def m5[TA,TD,TE](a: TA, d: TD, e: TE): - print a, d, e - m5(self.a, d, d) -ax = A(42) -ax.m4(1) #: 42 1 1 - -#%% static_fn,barebones -class A[TA]: - a: TA - def dump(a, b, c): - print a, b, c - def m2(): - A.dump(1, 2, 's') - def __str__(self): - return 'A' -A.dump(1, 2, 3) #: 1 2 3 -A[int].m2() #: 1 2 s -A.m2() #: 1 2 s -c = A[str]('s') -c.dump('y', 1.1) #: A y 1.1 - -#%% static_fn_overload,barebones -def foo(x: Static[int]): - print('int', x) - -@overload -def foo(x: Static[str]): - print('str', x) - -foo(10) -#: int 10 -foo('s') -#: str s - -#%% realization_big -class A[TA,TB,TC]: - a: TA - b: TB - c: TC - - def dump(a, b, c): - print a, b, c - - # non-generic method: - def m0(self: A[TA,TB,TC], a: int): - print a - - # basic generics: - def m1[X](self: A[TA,TB,TC], other: A[X,X,X]): - print other.a, other.b, other.c - - # non-generic method referencing outer generics: - def m2(a: TA, b: TB, c: TC): - A.dump(a, b, c) - - # generic args: - def m3(self, other): - return self.a - - # lots of nesting: - def m4[TD](self: A[TA,TB,TC], d: TD): - def m5[TA,TB,TC,TD,TE](a: TA, b: TB, c: TC, d: TD, e: TE): - print a, b, c, d, e - m5(self.a, self.b, self.c, d, d) - - # instantiating the type: - def m5(self): - x = A(self.a, self.b, self.c) - A.dump(x.a, x.b, x.c) - - # deeply nested generic type: - def m6[T](v: array[array[array[T]]]): - return v[0][0][0] -a1 = A(42, 3.14, "hello") -a2 = A(1, 2, 3) -a1.m1(a2) #: 1 2 3 -A[int,float,str].m2(1, 1.0, "one") #: 1 1 one -A[int,int,int].m2(11, 22, 33) #: 11 22 33 -print a1.m3(a2) #: 42 -print a1.m3(a2) #: 42 -print a2.m3(a1) #: 1 -a1.m4(True) #: 42 3.14 hello True True -a1.m4([1]) #: 42 3.14 hello [1] [1] -a2.m4("x") #: 1 2 3 x x -a1.m5() #: 42 3.14 hello -a2.m5() #: 1 2 3 - -v1 = array[array[array[str]]](1) -v2 = array[array[str]](1) -v3 = array[str](1) -v1[0] = v2 -v2[0] = v3 -v3[0] = "world" -print A.m6(v1) #: world - -f = a2.m0 -f(99) #: 99 - -#%% realization_small,barebones -class B1[T]: - a: T - def foo[S](self: S) -> B1[int]: - return B1[int](111) -b1 = B1[bool](True).foo() -print b1.foo().a #: 111 - -class B2[T]: - a: T - def foo[S](self: B2[S]): - return B2[int](222) -b2 = B2[str]("x").foo() -print b2.foo().a #: 222 - -# explicit realization: -def m7[T,S](): - print "works" -m7(str,float) #: works -m7(str,float) #: works -m7(float,str) #: works - -#%% recursive,barebones -def foo(a): - if not a: - foo(True) - print a -foo(0) -#: True -#: 0 - -def bar(a): - def baz(x): - if not x: - bar(True) - print (x) - baz(a) -bar(0) -#: True -#: 0 - -def rec2(x, y): - if x: - return rec2(y, x) - else: - return 1.0 -print rec2(1, False).__class__.__name__ #: float - -def pq(x): - return True -def rec3(x, y): - if pq(x): - return rec3(y, x) - else: - return y -print rec3('x', 's').__class__.__name__ #: str - -# Nested mutually recursive function -def f[T](x: T) -> T: - def g[T](z): - return z(T()) - return g(f, T=T) -print f(1.2).__class__.__name__ #: float -print f('s').__class__.__name__ #: str - -def f2[T](x: T): - return f2(x - 1, T) if x else 1 -print f2(1) #: 1 -print f2(1.1).__class__.__name__ #: int - - -#%% recursive_error,barebones -def pq(x): - return True -def rec3(x, y): #- ('a, 'b) -> 'b - if pq(x): - return rec3(y, x) - else: - return y -rec3(1, 's') -#! 'int' does not match expected type 'str' -#! during the realization of rec3(x: int, y: str) - -#%% instantiate_function_2,barebones -def fx[T](x: T) -> T: - def g[T](z): - return z(T()) - return g(fx, T) -print fx(1.1).__class__.__name__, fx(1).__class__.__name__ #: float int - -#%% optionals,barebones -y = None -print y #: None -y = 5 -print y #: 5 - -def foo(x: optional[int], y: int): - print 'foo', x, y -foo(y, 6) #: foo 5 6 -foo(5, 6) #: foo 5 6 -foo(5, y) #: foo 5 5 -y = None -try: - foo(5, y) -except ValueError: - print 'unwrap failed' #: unwrap failed - -class Cls: - x: int -c = None -for i in range(2): - if c: c.x += 1 # check for unwrap() dot access - c = Cls(1) -print(c.x) #: 1 - -#%% optional_methods,barebones -@extend -class int: - def x(self): - print 'x()!', self - -y = None -z = 1 if y else None -print z #: None - -y = 6 -z = 1 + y if y else None -print z #: 7 -z.x() #: x()! 7 -if 1: # otherwise compiler won't compile z.x() later - z = None -try: - z.x() -except ValueError: - print 'unwrap failed' #: unwrap failed - -print Optional(1) + Optional(2) #: 3 -print Optional(1) + 3 #: 4 -print 1 + Optional(1) #: 2 - -#%% optional_tuple,barebones -a = None -if True: - a = ('x', 's') -print(a) #: ('x', 's') -print(*a, (1, *a)) #: x s (1, 'x', 's') -x,y=a -print(x,y,[*a]) #: x s ['x', 's'] - -#%% global_none,barebones -a, b = None, None -def foo(): - global a, b - a = [1, 2] - b = 3 -print a, b, -foo() -print a, b #: None None [1, 2] 3 - -#%% default_type_none -class Test: - value: int - def __init__(self, value: int): - self.value = value - def __repr__(self): - return str(self.value) -def key_func(k: Test): - return k.value -print sorted([Test(1), Test(3), Test(2)], key=key_func) #: [1, 2, 3] -print sorted([Test(1), Test(3), Test(2)], key=lambda x: x.value) #: [1, 2, 3] -print sorted([1, 3, 2]) #: [1, 2, 3] - -#%% nested_map -print list(map(lambda i: i-2, map(lambda i: i+1, range(5)))) -#: [-1, 0, 1, 2, 3] - -def h(x: list[int]): - return x -print h(list(map(lambda i: i-1, map(lambda i: i+2, range(5))))) -#: [1, 2, 3, 4, 5] - -#%% func_unify_error,barebones -def foo(x:int): - print x -z = 1 & foo #! unsupported operand type(s) for &: 'int' and 'foo[int]' - -#%% tuple_type_late,barebones -coords = [] -for i in range(2): - coords.append( ('c', i, []) ) -coords[0][2].append((1, 's')) -print(coords) #: [('c', 0, [(1, 's')]), ('c', 1, [])] - -#%% void,barebones -def foo(): - print 'foo' -def bar(x): - print 'bar', x.__class__.__name__ -a = foo() #: foo -bar(a) #: bar NoneType - -def x(): - pass -b = lambda: x() -b() -x() if True else x() - -#%% void_2,barebones -def foo(): - i = 0 - while i < 10: - print i #: 0 - yield - i += 10 -a = list(foo()) -print(a) #: [None] - -#%% instantiate_swap,barebones -class Foo[T, U]: - t: T - u: U - def __init__(self): - self.t = T() - self.u = U() - def __str__(self): - return f'{self.t} {self.u}' -print Foo[int, bool](), Foo[bool, int]() #: 0 False False 0 - -#%% static,barebones -class Num[N_: Static[int]]: - def __str__(self): - return f'[{N_}]' - def __init__(self): - pass -def foo[N: Static[int]](): - print Num[N*2]() -foo(3) #: [6] - -class XX[N_: Static[int]]: - a: Num[N_*2] - def __init__(self): - self.a = Num() -y = XX[5]() -print y.a, y.__class__.__name__, y.a.__class__.__name__ #: [10] XX[5] Num[10] - -@tuple -class FooBar[N: Static[int]]: - x: Int[N] -z = FooBar(i32(5)) -print z, z.__class__.__name__, z.x.__class__.__name__ #: (x: Int[32](5)) FooBar[32] Int[32] - -@tuple -class Foo[N: Static[int]]: - x: Int[2*N] - def __new__(x: Int[2*N]) -> Foo[N]: - return (x,) -foo = Foo[10](Int[20](0)) -print foo.__class__.__name__, foo.x.__class__.__name__ #: Foo[10] Int[20] - -#%% static_2,barebones -class Num[N: Static[int]]: - def __str__(self): - return f'~{N}' - def __init__(self): - pass -class Foo[T, A: Static[int], B: Static[int]]: - a: Num[A+B] - b: Num[A-B] - c: Num[A if A > 3 else B] - t: T - def __init__(self): - self.a = Num() - self.b = Num() - self.c = Num() - self.t = T() - def __str__(self): - return f'<{self.a} {self.b} {self.c} :: {self.t}>' -print Foo[int, 3, 4](), Foo[int, 5, 4]() -#: <~7 ~-1 ~4 :: 0> <~9 ~1 ~5 :: 0> - -#%% static_int,barebones -def foo(n: Static[int]): - print n - -a: Static[int] = 5 -foo(a < 1) #: 0 -foo(a <= 1) #: 0 -foo(a > 1) #: 1 -foo(a >= 1) #: 1 -foo(a == 1) #: 0 -foo(a != 1) #: 1 -foo(a and 1) #: 1 -foo(a or 1) #: 1 -foo(a + 1) #: 6 -foo(a - 1) #: 4 -foo(a * 1) #: 5 -foo(a // 2) #: 2 -foo(a % 2) #: 1 -foo(a & 2) #: 0 -foo(a | 2) #: 7 -foo(a ^ 1) #: 4 - -#%% static_str,barebones -class X: - s: Static[str] - i: Int[1 + (s == "abc")] - def __init__(self: X[s], s: Static[str]): - i = Int[1+(s=="abc")]() - print s, self.s, self.i.__class__.__name__ -def foo(x: Static[str], y: Static[str]): - print x+y -z: Static[str] = "woo" -foo("he", z) #: hewoo -X(s='lolo') #: lolo lolo Int[1] -X('abc') #: abc abc Int[2] - - -def foo2(x: Static[str]): - print(x, x.__is_static__) -s: Static[str] = "abcdefghijkl" -foo2(s) #: abcdefghijkl True -foo2(s[1]) #: b True -foo2(s[1:5]) #: bcde True -foo2(s[10:50]) #: kl True -foo2(s[1:30:3]) #: behk True -foo2(s[::-1]) #: lkjihgfedcba True - - -#%% static_getitem -print Int[staticlen("ee")].__class__.__name__ #: Int[2] - -y = [1, 2] -print getattr(y, "len") #: 2 -print y.len #: 2 -getattr(y, 'append')(1) -print y #: [1, 2, 1] - -@extend -class Dict: - def __getitem2__(self, attr: Static[str]): - if hasattr(self, attr): - return getattr(self, attr) - else: - return self[attr] - def __getitem1__(self, attr: Static[int]): - return self[attr] - -d = {'s': 3.19} -print d.__getitem2__('_upper_bound') #: 3 -print d.__getitem2__('s') #: 3.19 -e = {1: 3.33} -print e.__getitem1__(1) #: 3.33 - -#%% static_fail,barebones -def test(i: Int[32]): - print int(i) -test(Int[5](1)) #! 'Int[5]' does not match expected type 'Int[32]' - -#%% static_fail_2,barebones -zi = Int[32](6) -def test3[N](i: Int[N]): - print int(i) -test3(zi) #! 'N' does not match expected type 'N' -# TODO: nicer error message! - -#%% static_fail_3,barebones -zi = Int[32](6) -def test3[N: Static[int]](i: Int[N]): - print int(i) -test3(1, int) #! expected static expression -# TODO: nicer error message! - -#%% nested_fn_generic,barebones -def f(x): - def g(y): - return y - return g(x) -print f(5), f('s') #: 5 s - -def f2[U](x: U, y): - def g[T, U](x: T, y: U): - return (x, y) - return g(y, x) -x, y = 1, 'haha' -print f2(x, y).__class__.__name__ #: Tuple[str,int] -print f2('aa', 1.1, U=str).__class__.__name__ #: Tuple[float,str] - -#%% nested_fn_generic_error,barebones -def f[U](x: U, y): # ('u, 'a) -> tuple['a, 'u] - def g[T, U](x: T, y: U): # ('t, 'u) -> tuple['t, 'u] - return (x, y) - return g(y, x) -print f(1.1, 1, int).__class__.__name__ #! 'float' does not match expected type 'int' - -#%% fn_realization,barebones -def ff[T](x: T, y: tuple[T]): - print ff(T=str,...).__class__.__name__ #: ff[str,Tuple[str],str] - return x -x = ff(1, (1,)) -print x, x.__class__.__name__ #: 1 int -# print f.__class__.__name__ # TODO ERRORS - -def fg[T](x:T): - def g[T](y): - z = T() - return z - print fg(T=str,...).__class__.__name__ #: fg[str,str] - print g(1, T).__class__.__name__ #: int -fg(1) -print fg(1).__class__.__name__ #: NoneType - -def f[T](x: T): - print f(x, T).__class__.__name__ #: int - print f(x).__class__.__name__ #: int - print f(x, int).__class__.__name__ #: int - return x -print f(1), f(1).__class__.__name__ #: 1 int -print f(1, int).__class__.__name__ #: int - -#%% fn_realization_error,barebones -def f[T](x: T): - print f(x, int).__class__.__name__ - return x -f('s') -#! 'str' does not match expected type 'int' -#! during the realization of f(x: str, T: str) - -#%% nested_class_error,barebones -class X: - def foo(self, x): - return x - class Y: - def bar(self, x): - return x -y = X.Y() -y.foo(1) #! 'X.Y' object has no attribute 'foo' - -#%% nested_deep_class,barebones -class A[T]: - a: T - class B[U]: - b: U - class C[V]: - c: V - def foo[W](t: V, u: V, v: V, w: W): - return (t, u, v, w) - -print A.B.C[bool].foo(W=str, ...).__class__.__name__ #: foo[bool,bool,bool,str,str] -print A.B.C.foo(1,1,1,True) #: (1, 1, 1, True) -print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') -print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') -print A.B.C.foo('x', 'x', 'x', 'x') #: ('x', 'x', 'x', 'x') - -x = A.B.C[bool]() -print x.__class__.__name__ #: A.B.C[bool] - -#%% nested_deep_class_error,barebones -class A[T]: - a: T - class B[U]: - b: U - class C[V]: - c: V - def foo[W](t: V, u: V, v: V, w: W): - return (t, u, v, w) - -print A.B.C[str].foo(1,1,1,True) #! 'A.B.C[str]' object has no method 'foo' with arguments (int, int, int, bool) - -#%% nested_deep_class_error_2,barebones -class A[T]: - a: T - class B[U]: - b: U - class C[V]: - c: V - def foo[W](t: V, u: V, v: V, w: W): - return (t, u, v, w) -print A.B[int].C[float].foo(1,1,1,True) #! 'A.B[int]' object has no attribute 'C' - -#%% nested_class_function,barebones -def f(x): - def g(y): - return y - a = g(1) - b = g('s') - c = g(x) - return a, b, c -print f(1.1).__class__.__name__ #: Tuple[int,str,float] -print f(False).__class__.__name__ #: Tuple[int,str,bool] - -class A[T]: - a: T - class B[U]: - b: U - class C[V]: - c: V - def f(x): - def g(y): - return y - a = g(1) - b = g('s') - c = g(x) - return a, b, c -print A.B.C.f(1.1).__class__.__name__ #: Tuple[int,str,float] -print A.B.C[Optional[int]].f(False).__class__.__name__ #: Tuple[int,str,bool] - -#%% rec_class_1,barebones -class A: - y: A - def __init__(self): pass # necessary to prevent recursive instantiation! -x = A() -print x.__class__.__name__, x.y.__class__.__name__ #: A A - -#%% rec_class_2,barebones -class A[T]: - a: T - b: A[T] - c: A[str] - def __init__(self): pass -a = A[int]() -print a.__class__.__name__, a.b.__class__.__name__, a.c.__class__.__name__, a.b.b.__class__.__name__, a.b.c.__class__.__name__ -#: A[int] A[int] A[str] A[int] A[str] -print a.c.b.__class__.__name__, a.c.c.__class__.__name__, a.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.b.__class__.__name__ -#: A[str] A[str] A[int] - -#%% rec_class_3,barebones -class X: - x: int - rec: X - def __init__(self): pass - def foo(x: X, y: int): - return y - class Y: - y: int - def bar(self, y): - print y - return self.y -x, y = X(), X.Y() -print x.__class__.__name__, y.__class__.__name__ -#: X X.Y -print X.foo(x, 4), x.foo(5) -#: 4 5 -print y.bar(1), y.bar('s'), X.Y.bar(y, True) -#: 1 -#: s -#: True -#: 0 0 0 - -#%% rec_class_4,barebones -class A[T]: - a: T - b: A[T] - c: A[str] - def __init__(self): pass -class B[T]: - a: T - b: A[T] - c: B[T] - def __init__(self): pass - class Nest1[U]: - n: U - class Nest2[T, U]: - m: T - n: U -b = B[float]() -print b.__class__.__name__, b.a.__class__.__name__, b.b.__class__.__name__, b.c.__class__.__name__, b.c.b.c.a.__class__.__name__ -#: B[float] float A[float] B[float] str - -n1 = B.Nest1[int]() -print n1.n, n1.__class__.__name__, n1.n.__class__.__name__ #: 0 B.Nest1[int] int - -n1: B.Nest2 = B.Nest2[float, int]() -print (n1.m, n1.n), n1.__class__.__name__, n1.m.__class__.__name__, n1.n.__class__.__name__ #: (0, 0) B.Nest2[float,int] float int - -#%% func_arg_instantiate,barebones -class A[T]: - y: T - def foo(self, y: T): - self.y = y - return y - def bar(self, y): - return y -a = A() -print a.__class__.__name__ #: A[int] -a.y = 5 -print a.__class__.__name__ #: A[int] - -b = A() -print b.foo(5) #: 5 -print b.__class__.__name__, b.y #: A[int] 5 -print b.bar('s'), b.bar('s').__class__.__name__ #: s str -print b.bar(5), b.bar(5).__class__.__name__ #: 5 int - -aa = A() -print aa.foo('s') #: s -print aa.__class__.__name__, aa.y, aa.bar(5.1).__class__.__name__ #: A[str] s float - -#%% no_func_arg_instantiate_err,barebones -# TODO: allow unbound self? -class A[T]: - y: T - def foo(self, y): self.y = y -a = A() -a.foo(1) #! cannot typecheck the program - -#%% return_deduction,barebones -def fun[T, R](x, y: T) -> R: - def ffi[T, R, Z](x: T, y: R, z: Z): - return (x, y, z) - yy = ffi(False, byte(2), 's', T=bool, Z=str, R=R) - yz = ffi(1, byte(2), 's', T=int, Z=str, R=R) - return byte(1) -print fun(2, 1.1, float, byte).__class__.__name__ #: byte - -#%% return_auto_deduction_err,barebones -def fun[T, R](x, y: T) -> R: - return byte(1) -print fun(2, 1.1).__class__.__name__ #! cannot typecheck the program - -#%% random -# shuffle used to fail before for some reason (sth about unbound variables)... -def foo(): - from random import shuffle - v = list(range(10)) - shuffle(v) - print sorted(v) #: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -foo() - -#%% function_type,barebones -class F: - f: Function[[int], int] - g: function[[int], None] - x: int -def foo(x: int): - return x+1 -def goo(x: int): - print x+2 -f = F(foo, goo, 2) -print f.f(f.x) #: 3 -f.g(f.x) #: 4 - -def hoo(z): - print z+3 -f.g = hoo -f.g(f.x) #: 5 - -def hai(x, y, z): - print f'hai({x},{y},{z})' -fn = Function[[int, int, int], None](hai) -fn(1, 2, 3) #: hai(1,2,3) -print str(fn)[:12] #: X.foo[X[X[int]],...,...] -print y.foo(1, 2.2, float) #: (2, 4.4) - -#%% forward,barebones -def foo(f, x): - f(x, type(x)) - print f.__class__.__name__ -def bar[T](x): - print x, T.__class__.__name__ -foo(bar, 1) -#: 1 int -#: bar[...] -foo(bar(...), 's') -#: s str -#: bar[...] -z = bar -z('s', int) -#: s int -z(1, T=str) -#: 1 str - -zz = bar(T=int,...) -zz(1) -#: 1 int - -#%% forward_error,barebones -def foo(f, x): - f(x, type(x)) - print f.__class__.__name__ -def bar[T](x): - print x, T.__class__.__name__ -foo(bar(T=int,...), 1) -#! bar() takes 2 arguments (2 given) -#! during the realization of foo(f: bar[...], x: int) -# TODO fix this error message - -#%% sort_partial -def foo(x, y): - return y**x -print sorted([1,2,3,4,5], key=foo(y=2, ...)) -print sorted([1,2,3,4,5], key=foo(y=-2, ...)) -#: [1, 2, 3, 4, 5] -#: [5, 3, 1, 2, 4] - -#%% mutually_recursive_error,barebones -def bl(x): - return True -def frec(x, y): - def grec(x, y): - return frec(y, x) - return grec(x, y) if bl(y) else 2 -print frec(1, 2).__class__.__name__, frec('s', 1).__class__.__name__ -#! 'NoneType' does not match expected type 'int' -#! during the realization of frec(x: int, y: int) - -#%% return_fn,barebones -def retfn(a): - def inner(b, *args, **kwargs): - print a, b, args, kwargs - print inner.__class__.__name__ #: inner[...,...,int,...] - return inner(15, ...) -f = retfn(1) -print f.__class__.__name__ #: inner[int,...,int,...] -f(2,3,foo='bar') #: 1 15 (2, 3) (foo: 'bar') - -#%% decorator_manual,barebones -def foo(x, *args, **kwargs): - print x, args, kwargs - return 1 -def dec(fn, a): - print 'decorating', fn.__class__.__name__ #: decorating foo[...,...,...] - def inner(*args, **kwargs): - print 'decorator', args, kwargs #: decorator (5.5, 's') (z: True) - return fn(a, *args, **kwargs) - return inner(...) -ff = dec(foo(...), 10) -print ff(5.5, 's', z=True) -#: 10 (5.5, 's') (z: True) -#: 1 - - -#%% decorator,barebones -def foo(x, *args, **kwargs): - print x, args, kwargs - return 1 -def dec(a): - def f(fn): - print 'decorating', fn.__class__.__name__ - def inner(*args, **kwargs): - print 'decorator', args, kwargs - return fn(a, *args, **kwargs) - return inner - return f -ff = dec(10)(foo) -print ff(5.5, 's', z=True) -#: decorating foo[...,...,...] -#: decorator (5.5, 's') (z: True) -#: 10 (5.5, 's') (z: True) -#: 1 - -@dec(a=5) -def zoo(e, b, *args): - return f'zoo: {e}, {b}, {args}' -print zoo(2, 3) -print zoo('s', 3) -#: decorating zoo[...,...,...] -#: decorator (2, 3) () -#: zoo: 5, 2, (3,) -#: decorator ('s', 3) () -#: zoo: 5, s, (3,) - -def mydecorator(func): - def inner(): - print("before") - func() - print("after") - return inner -@mydecorator -def foo2(): - print("foo") -foo2() -#: before -#: foo -#: after - -def timeme(func): - def inner(*args, **kwargs): - begin = 1 - end = func(*args, **kwargs) - begin - print('time needed for', func.__class__.__name__, 'is', end) - return inner -@timeme -def factorial(num): - n = 1 - for i in range(1,num + 1): - n *= i - print(n) - return n -factorial(10) -#: 3628800 -#: time needed for factorial[...] is 3628799 - -def dx1(func): - def inner(): - x = func() - return x * x - return inner -def dx2(func): - def inner(): - x = func() - return 2 * x - return inner -@dx1 -@dx2 -def num(): - return 10 -print(num()) #: 400 - -def dy1(func): - def inner(*a, **kw): - x = func(*a, **kw) - return x * x - return inner -def dy2(func): - def inner(*a, **kw): - x = func(*a, **kw) - return 2 * x - return inner -@dy1 -@dy2 -def num2(a, b): - return a+b -print(num2(10, 20)) #: 3600 - -#%% hetero_iter,barebones -e = (1, 2, 3, 'foo', 5, 'bar', 6) -for i in e: - if isinstance(i, int): - if i == 1: continue - if isinstance(i, str): - if i == 'bar': break - print i - -#%% type_loc,barebones -a = 1 -T = type(a) -print T.__class__.__name__ #: int - -#%% empty_tuple,barebones -T = type(()) # only errors with empty tuple type -p = Ptr[T](cobj()) -print p.__class__.__name__ #: Ptr[Tuple] - -print [a for a in ()] #: [] - -def foo(*args): - return [a for a in args] -args, result = ((), [()]) -print list(foo(*args)) #: [] -print result #: [()] - - -#%% type_error_reporting -# TODO: improve this certainly -def tee(iterable, n=2): - from collections import deque - it = iter(iterable) - deques = [deque() for i in range(n)] - def gen(mydeque): - while True: - if not mydeque: # when the local deque is empty - if it.done(): - return - newval = it.next() - for d in deques: # load it to all the deques - d.append(newval) - yield mydeque.popleft() - return list(gen(d) for d in deques) -it = [1,2,3,4] -a, b = tee(it) -#! cannot typecheck the program -#! during the realization of tee(iterable: List[int], n: int) - -#%% new_syntax,barebones -def foo[T,U](x: type, y, z: Static[int] = 10): - print T.__class__.__name__, U.__class__.__name__, x.__class__.__name__, y.__class__.__name__, Int[z+1].__class__.__name__ - return List[x]() -print foo(T=int,U=str,...).__class__.__name__ #: foo[T1,x,z,int,str] -print foo(T=int,U=str,z=5,x=bool,...).__class__.__name__ #: foo[T1,bool,5,int,str] -print foo(float,3,T=int,U=str,z=5).__class__.__name__ #: List[float] -foo(float,1,10,str,int) #: str int float int Int[11] - - -class Foo[T,U: Static[int]]: - a: T - b: Static[int] - c: Int[U] - d: type - e: List[d] - f: UInt[b] -print Foo[5,int,float,6].__class__.__name__ #: Foo[5,int,float,6] -print Foo(1.1, 10i32, [False], 10u66).__class__.__name__ #: Foo[66,bool,float,32] - - -def foo2[N: Static[int]](): - print Int[N].__class__.__name__, N -x: Static[int] = 5 -y: Static[int] = 105 - x * 2 -foo2(y-x) #: Int[90] 90 - -if 1.1+2.2 > 0: - z: Static[int] = 88 - print z #: 88 -print x #: 5 -x : Static[int] = 3 -print x #: 3 - -def fox(N: Static[int] = 4): - print Int[N].__class__.__name__, N -fox(5) #: Int[5] 5 -fox() #: Int[4] 4 - -#%% new_syntax_err,barebones -class Foo[T,U: Static[int]]: - a: T - b: Static[int] - c: Int[U] - d: type - e: List[d] - f: UInt[b] -print Foo[float,6].__class__.__name__ #! Foo takes 4 generics (2 given) - -#%% partial_star_pipe_args,barebones -iter(['A', 'C']) |> print -#: A -#: C -iter(range(4)) |> print('x', ..., 1) -#: x 0 1 -#: x 1 1 -#: x 2 1 -#: x 3 1 - -#%% type_arg_transform,barebones -print list(map(str, range(5))) -#: ['0', '1', '2', '3', '4'] - - -#%% traits,barebones -def t[T](x: T, key: Optional[Callable[[T], S]] = None, S: type = NoneType): - if isinstance(S, NoneType): - return x - else: - return (key.__val__())(x) -print t(5) #: 5 -print t(6, lambda x: f'#{x}') #: #6 - -z: Callable[[int],int] = lambda x: x+1 -print z(5) #: 6 - -def foo[T](x: T, func: Optional[Callable[[], T]] = None) -> T: - return x -print foo(1) #: 1 - -#%% trait_callable -foo = [1,2,11] -print(sorted(foo, key=str)) -#: [1, 11, 2] - -foo = {1: "a", 2: "a", 11: "c"} -print(sorted(foo.items(), key=str)) -#: [(1, 'a'), (11, 'c'), (2, 'a')] - -def call(f: Callable[[int,int], Tuple[str,int]]): - print(f(1, 2)) - -def foo(*x): return f"{x}_{x.__class__.__name__}",1 -call(foo) -#: ('(1, 2)_Tuple[int,int]', 1) - -def foo(a:int, *b: float): return f"{a}_{b}", a+b[0] -call(foo) -#: ('1_(2,)', 3) - -def call(f: Callable[[int,int],str]): - print(f(1, 2)) -def foo(a: int, *b: int, **kw): return f"{a}_{b}_{kw}" -call(foo(zzz=1.1, ...)) -#: 1_(2,)_(zzz: 1.1) - -#%% traits_error,barebones -def t[T](x: T, key: Optional[Callable[[T], S]] = None, S: type = NoneType): - if isinstance(S, NoneType): - return x - else: - return (key.__val__())(x) -print t(6, Optional(1)) #! 'Optional[int]' does not match expected type 'Optional[Callable[[int],S]]' - -#%% traits_error_2,barebones -z: Callable[[int],int] = 4 #! 'Callable[[int],int]' does not match expected type 'int' - -#%% assign_wrappers,barebones -a = 1.5 -print a #: 1.5 -if 1: - a = 1 -print a, a.__class__.__name__ #: 1 float - -a: Optional[int] = None -if 1: - a = 5 -print a.__class__.__name__, a #: Optional[int] 5 - -b = 5 -c = Optional(6) -if 1: - b = c -print b.__class__.__name__, c.__class__.__name__, b, c #: int Optional[int] 6 6 - -z: Generator[int] = [1, 2] -print z.__class__.__name__ #: Generator[int] - -zx: float = 1 -print zx.__class__.__name__, zx #: float 1 - -def test(v: Optional[int]): - v: int = v if v is not None else 3 - print v.__class__.__name__ -test(5) #: int -test(None) #: int - -#%% methodcaller,barebones -def foo(): - def bar(a, b): - print 'bar', a, b - return bar -foo()(1, 2) #: bar 1 2 - -def methodcaller(foo: Static[str]): - def caller(foo: Static[str], obj, *args, **kwargs): - if isinstance(getattr(obj, foo)(*args, **kwargs), None): - getattr(obj, foo)(*args, **kwargs) - else: - return getattr(obj, foo)(*args, **kwargs) - return caller(foo=foo, ...) -v = [1] -methodcaller('append')(v, 42) -print v #: [1, 42] -print methodcaller('index')(v, 42) #: 1 - -#%% fn_overloads,barebones -def foo(x): - return 1, x - -print(foo('')) #: (1, '') - -@overload -def foo(x, y): - def foo(x, y): - return f'{x}_{y}' - return 2, foo(x, y) - -@overload -def foo(x): - if x == '': - return 3, 0 - return 3, 1 + foo(x[1:])[1] - -print foo('hi') #: (3, 2) -print foo('hi', 1) #: (2, 'hi_1') - - -def fox(a: int, b: int, c: int, dtype: type = int): - print('fox 1:', a, b, c) - -@overload -def fox(a: int, b: int, dtype: type = int): - print('fox 2:', a, b, dtype.__class__.__name__) - -fox(1, 2, float) -#: fox 2: 1 2 float -fox(1, 2) -#: fox 2: 1 2 int -fox(1, 2, 3) -#: fox 1: 1 2 3 - -#%% fn_shadow,barebones -def foo(x): - return 1, x -print foo('hi') #: (1, 'hi') - -def foo(x): - return 2, x -print foo('hi') #: (2, 'hi') - -#%% fn_overloads_error,barebones -def foo(x): - return 1, x -@overload -def foo(x, y): - return 2, x, y -foo('hooooooooy!', 1, 2) #! no function 'foo' with arguments (str, int, int) - -#%% c_void_return,barebones -from C import seq_print(str) -x = seq_print("not ") -print x #: not None - - -#%% static_for,barebones -def foo(i: Static[int]): - print('static', i, Int[i].__class__.__name__) - -for i in statictuple(1, 2, 3, 4, 5): - foo(i) - if i == 3: break -#: static 1 Int[1] -#: static 2 Int[2] -#: static 3 Int[3] -for i in staticrange(9, 4, -2): - foo(i) - if i == 3: - break -#: static 9 Int[9] -#: static 7 Int[7] -#: static 5 Int[5] -for i in statictuple("x", 1, 3.3, 2): - print(i) -#: x -#: 1 -#: 3.3 -#: 2 - -print tuple(Int[i+10](i) for i in statictuple(1, 2, 3)).__class__.__name__ -#: Tuple[Int[11],Int[12],Int[13]] - -for i in staticrange(0, 10): - if i % 2 == 0: continue - if i > 8: break - print('xyz', Int[i].__class__.__name__) -print('whoa') -#: xyz Int[1] -#: xyz Int[3] -#: xyz Int[5] -#: xyz Int[7] -#: whoa - -for i in staticrange(15): - if i % 2 == 0: continue - if i > 8: break - print('xyz', Int[i].__class__.__name__) -print('whoa') -#: xyz Int[1] -#: xyz Int[3] -#: xyz Int[5] -#: xyz Int[7] -#: whoa - -print tuple(Int[i-10](i) for i in staticrange(30,33)).__class__.__name__ -#: Tuple[Int[20],Int[21],Int[22]] - -for i in statictuple(0, 2, 4, 7, 11, 12, 13): - if i % 2 == 0: continue - if i > 8: break - print('xyz', Int[i].__class__.__name__) -print('whoa') -#: xyz Int[7] -#: whoa - -for i in staticrange(10): # TODO: large values are too slow! - pass -print('done') -#: done - -tt = (5, 'x', 3.14, False, [1, 2]) -for i, j in staticenumerate(tt): - print(foo(i * 2 + 1), j) -#: static 1 Int[1] -#: None 5 -#: static 3 Int[3] -#: None x -#: static 5 Int[5] -#: None 3.14 -#: static 7 Int[7] -#: None False -#: static 9 Int[9] -#: None [1, 2] - -print tuple((Int[i+1](i), j) for i, j in staticenumerate(tt)).__class__.__name__ -#: Tuple[Tuple[Int[1],int],Tuple[Int[2],str],Tuple[Int[3],float],Tuple[Int[4],bool],Tuple[Int[5],List[int]]] - -#%% static_range_error,barebones -for i in staticrange(1000, -2000, -2): - pass -#! staticrange too large (expected 0..1024, got instead 1500) - -#%% trait_defdict -class dd(Static[Dict[K,V]]): - fn: S - K: type - V: type - S: TypeVar[Callable[[], V]] - - def __init__(self: dd[K, VV, Function[[], V]], VV: TypeVar[V]): - self.fn = lambda: VV() - - def __init__(self, f: S): - self.fn = f - - def __getitem__(self, key: K) -> V: - if key not in self: - self.__setitem__(key, self.fn()) - return super().__getitem__(key) - - -x = dd(list) -x[1] = [1, 2] -print(x[2]) -#: [] -print(x) -#: {1: [1, 2], 2: []} - -z = 5 -y = dd(lambda: z+1) -y.update({'a': 5}) -print(y['b']) -#: 6 -z = 6 -print(y['c']) -#: 7 -print(y) -#: {'a': 5, 'b': 6, 'c': 7} - -xx = dd(lambda: 'empty') -xx.update({1: 's', 2: 'b'}) -print(xx[1], xx[44]) -#: s empty -print(xx) -#: {44: 'empty', 1: 's', 2: 'b'} - -s = 'mississippi' -d = dd(int) -for k in s: - d[k] = d["x" + k] -print(sorted(d.items())) -#: [('i', 0), ('m', 0), ('p', 0), ('s', 0), ('xi', 0), ('xm', 0), ('xp', 0), ('xs', 0)] - - -#%% kwargs_getattr,barebones -def foo(**kwargs): - print kwargs['foo'], kwargs['bar'] - -foo(foo=1, bar='s') -#: 1 s - - - -#%% union_types,barebones -def foo_int(x: int): - print(f'{x} {x.__class__.__name__}') -def foo_str(x: str): - print(f'{x} {x.__class__.__name__}') -def foo(x): - print(f'{x} {int(__internal__.union_get_tag(x))} {x.__class__.__name__}') - -a: Union[int, str] = 5 -foo_int(a) #: 5 int -foo(a) #: 5 0 Union[int,str] -print(staticlen(a)) #: 2 -print(staticlen(Union[int, int]), staticlen(Tuple[int, float, int])) #: 1 3 - -@extend -class str: - def __add__(self, i: int): - return int(self) + i - -a += 6 ## this is U.__new__(a.__getter__(__add__)(59)) -b = a + 59 -print(a, b, a.__class__.__name__, b.__class__.__name__) #: 11 70 Union[int,str] int - -if True: - a = 'hello' - foo_str(a) #: hello str - foo(a) #: hello 1 Union[int,str] - b = a[1:3] - print(b) #: el -print(a) #: hello - -a: Union[Union[Union[str], int], Union[int, int, str]] = 9 -foo(a) #: 9 0 Union[int,str] - -def ret(x): - z : Union = x - if x < 1: z = 1 - elif x < 10: z = False - else: z = 'oops' - return z -r = ret(2) -print(r, r.__class__.__name__) #: False Union[bool,int,str] -r = ret(33.3) -print(r, r.__class__.__name__) #: oops Union[bool,float,int,str] - -def ret2(x) -> Union: - if x < 1: return 1 - elif x < 10: return 2.2 - else: return ['oops'] -r = ret2(20) -print(r, r.__class__.__name__) #: ['oops'] Union[List[str],float,int] - -class A: - x: int - def foo(self): - return f"A: {self.x}" -class B: - y: str - def foo(self): - return f"B: {self.y}" -x : Union[A,B] = A(5) # TODO: just Union does not work in test mode :/ -print(x.foo()) #: A: 5 -print(x.x) #: 5 -if True: - x = B("bee") -print(x.foo()) #: B: bee -print(x.y) #: bee -try: - print(x.x) -except TypeError as e: - print(e.message) #: invalid union call 'x' - -def do(x: A): - print('do', x.x) -try: - do(x) -except TypeError: - print('error') #: error - -def do2(x: B): - print('do2', x.y) -do2(x) #: do2 bee - -z: Union[int, str] = 1 -print isinstance(z, int), isinstance(z, str), isinstance(z, float), isinstance(z, Union[int, float]), isinstance(z, Union[int, str]) -#: True False False False True - -print isinstance(z, Union[int]), isinstance(z, Union[int, float, str]) -#: False False - -if True: - z = 's' -print isinstance(z, int), isinstance(z, str), isinstance(z, float), isinstance(z, Union[int, float]), isinstance(z, Union[int, str]) -#: False True False False True - -class A: - def foo(self): return 1 -class B: - def foo(self): return 's' -class C: - def foo(self): return [True, False] -x : Union[A,B,C] = A() -print x.foo(), x.foo().__class__.__name__ -#: 1 Union[List[bool],int,str] - -xx = Union[int, str](0) -print(xx) #: 0 - -#%% union_error,barebones -a: Union[int, str] = 123 -print(123 == a) #: True -print(a == 123) #: True -try: - a = "foo" - print(a == 123) -except TypeError: - print("oops", a) #: oops foo - - -#%% generator_capture_nonglobal,barebones -# Issue #49 -def foo(iter): - print(iter.__class__.__name__, list(iter)) - -for x in range(2): - foo(1 for _ in range(x)) -#: Generator[int] [] -#: Generator[int] [1] -for x in range(2): - for y in range(x): - foo('z' for _ in range(y)) -#: Generator[str] [] - -#%% nonlocal_capture_loop,barebones -# Issue #51 -def kernel(fn): - def wrapper(*args, grid, block): - print(grid, block, fn(*args)) - return wrapper -def test_mandelbrot(): - MAX = 10 # maximum Mandelbrot iterations - N = 2 # width and height of image - pixels = [0 for _ in range(N)] - def scale(x, a, b): - return a + (x/N)*(b - a) - @kernel - def k(pixels): - i = 0 - while i < MAX: i += 1 # this is needed for test to make sense - return (MAX, N, pixels, scale(N, -2, 0.4)) - k(pixels, grid=(N*N)//1024, block=1024) -test_mandelbrot() #: 0 1024 (10, 2, [0, 0], 0.4) - -#%% delayed_lambda_realization,barebones -x = [] -for i in range(2): - print(all(x[j] < 0 for j in range(i))) - x.append(i) -#: True -#: False - -#%% constructor_passing -class A: - s: str - def __init__(self, x): - self.s = str(x)[::-1] - def __lt__(self, o): return self.s < o.s - def __eq__(self, o): return self.s == o.s - def __ge__(self, o): return self.s >= o.s -foo = [1,2,11,30] -print(sorted(foo, key=str)) -#: [1, 11, 2, 30] -print(sorted(foo, key=A)) -#: [30, 1, 11, 2] - - -#%% polymorphism,barebones -class A: - a: int - def foo(self, a: int): return (f'A({self.a})', a) - def bar(self): return 'A.bar' - def aaz(self): return 'A.aaz' -class B(A): - b: int - def foo(self, a): return (f'B({self.a},{self.b})', a + self.b) - def bar(self): return 'B.bar' - def baz(self): return 'B.baz' -class M[T]: - m: T - def moo(self): return (f'M_{T.__class__.__name__}', self.m) -class X(B,M[int]): - def foo(self, a): return (f'X({self.a},{self.b},{self.m})', a + self.b + self.m) - def bar(self): return 'X.bar' - -def foo(i): - x = i.foo(1) - y = i.bar() - z = i.aaz() - print(*x, y, z) -a = A(1) -l = [a, B(2,3), X(2,3,-1)] -for i in l: foo(i) -#: A(1) 1 A.bar A.aaz -#: B(2,3) 4 B.bar A.aaz -#: X(2,3,-1) 3 X.bar A.aaz - -def moo(m: M): - print(m.moo()) -moo(M[float](5.5)) -moo(X(1,2,3)) -#: ('M_float', 5.5) -#: ('M_int', 3) - - -class A[T]: - def __init__(self): - print("init A", T.__class__.__name__) -class Ho: - def __init__(self): - print("init Ho") -# TODO: this throws and error: B[U](U) -class B[U](A[U], Ho): - def __init__(self): - super().__init__() - print("init B", U.__class__.__name__) -B[Ho]() -#: init A Ho -#: init B Ho - - -class Vehicle: - def drive(self): - return "I'm driving a vehicle" - -class Car(Vehicle): - def drive(self): - return "I'm driving a car" - -class Truck(Vehicle): - def drive(self): - return "I'm driving a truck" - -class SUV(Car, Truck): - def drive(self): - return "I'm driving an SUV" - -suv = SUV() -def moo(s): - print(s.drive()) -moo(suv) -moo(Truck()) -moo(Car()) -moo(Vehicle()) -#: I'm driving an SUV -#: I'm driving a truck -#: I'm driving a car -#: I'm driving a vehicle - - -#%% polymorphism_error_1,barebones -class M[T]: - m: T -class X(M[int]): - pass -l = [M[float](1.1), X(2)] -#! 'List[M[float]]' object has no method 'append' with arguments (List[M[float]], X) - -#%% polymorphism_2 -class Expr: - def __init__(self): - pass - def eval(self): - raise ValueError('invalid expr') - return 0.0 - def __str__(self): - return "Expr" -class Const(Expr): - x: float - def __init__(self, x): - self.x=x - def __str__(self): - return f"{self.x}" - def eval(self): - return self.x -class Add(Expr): - lhs: Expr - rhs: Expr - def __init__(self, lhs, rhs): - self.lhs=lhs - self.rhs=rhs - # print(f'ctr: {self}') - def eval(self): - return self.lhs.eval()+self.rhs.eval() - def __str__(self): - return f"({self.lhs}) + ({self.rhs})" -class Mul(Expr): - lhs: Expr - rhs: Expr - def __init__(self, lhs, rhs): - self.lhs=lhs - self.rhs=rhs - def eval(self): - return self.lhs.eval()*self.rhs.eval() - def __str__(self): - return f"({self.lhs}) * ({self.rhs})" - -c1 = Const(5) -c2 = Const(4) -m = Add(c1, c2) -c3 = Const(2) -a : Expr = Mul(m, c3) -print(f'{a} = {a.eval()}') -#: ((5) + (4)) * (2) = 18 - -from random import random, seed -seed(137) -def random_expr(depth) -> Expr: - if depth<=0: - return Const(int(random()*42.0)) - else: - lhs=random_expr(depth-1) - rhs=random_expr(depth-1) - ctorid = int(random()*3) - if ctorid==0: - return Mul(lhs,rhs) - else: - return Add(lhs,rhs) -for i in range(11): - print(random_expr(i).eval()) -#: 17 -#: 71 -#: 1760 -#: 118440 -#: 94442 -#: 8.02435e+15 -#: 1.07463e+13 -#: 1.43017e+19 -#: 2.40292e+34 -#: 6.1307e+28 -#: 5.16611e+49 - - -#%% collection_common_type,barebones -l = [1, 2, 3] -print(l, l.__class__.__name__) -#: [1, 2, 3] List[int] - -l = [1.1, 2, 3] -print(l, l.__class__.__name__) -#: [1.1, 2, 3] List[float] - -l = [1, 2, 3.3] -print(l, l.__class__.__name__) -#: [1, 2, 3.3] List[float] - -l = [1, None] -print(l, l.__class__.__name__) -#: [1, None] List[Optional[int]] - -l = [None, 2.2] -print(l, l.__class__.__name__) -#: [None, 2.2] List[Optional[float]] - -class A: - def __repr__(self): return 'A' -class B(A): - def __repr__(self): return 'B' -class C(B): - def __repr__(self): return 'C' -class D(A): - def __repr__(self): return 'D' - -l = [A(), B(), C(), D()] -print(l, l.__class__.__name__) -#: [A, B, C, D] List[A] - -l = [D(), C(), B(), A()] -print(l, l.__class__.__name__) -#: [D, C, B, A] List[A] - -l = [C(), B()] -print(l, l.__class__.__name__) -#: [C, B] List[B] - -l = [C(), A(), B()] -print(l, l.__class__.__name__) -#: [C, A, B] List[A] - -l = [None, *[1, 2], None] -print(l, l.__class__.__name__) -#: [None, 1, 2, None] List[Optional[int]] - -# l = [C(), D(), B()] # does not work (correct behaviour) -# print(l, l.__class__.__name__) - -d = {1: None, 2.2: 's'} -print(d, d.__class__.__name__) -#: {1: None, 2.2: 's'} Dict[float,Optional[str]] - -#%% polymorphism_3 -import operator - -class Expr: - def eval(self): - return 0 - -class Const(Expr): - value: int - - def __init__(self, value): - self.value = value - - def eval(self): - return self.value - -class BinOp(Expr): - lhs: Expr - rhs: Expr - - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs - - def eval_from_fn(self, fn): - return fn(self.lhs.eval(), self.rhs.eval()) - -class Add(BinOp): - def eval(self): - return self.eval_from_fn(operator.add) - -class Sub(BinOp): - def eval(self): - return self.eval_from_fn(operator.sub) - -class Mul(BinOp): - def eval(self): - return self.eval_from_fn(operator.mul) - -class Div(BinOp): - def eval(self): - return self.eval_from_fn(operator.floordiv) - -# TODO: remove Expr requirement -expr : Expr = Mul(Const(3), Add(Const(10), Const(5))) -print(expr.eval()) #: 45 - - -#%% polymorphism_4 -class A(object): - a: int - def __init__(self, a: int): - self.a = a - - def test_a(self, n: int): - print("test_a:A", n) - - def test(self, n: int): - print("test:A", n) - - def test2(self, n: int): - print("test2:A", n) - -class B(A): - b: int - def __init__(self, a: int, b: int): - super().__init__(a) - self.b = b - - def test(self, n: int): - print("test:B", n) - - def test2(self, n: int): - print("test2:B", n) - -class C(B): - pass - -b = B(1, 2) -b.test_a(1) -b.test(1) -#: test_a:A 1 -#: test:B 1 - -a: A = b -a.test(1) -a.test2(2) -#: test:B 1 -#: test2:B 2 - - - -class AX(object): - value: u64 - - def __init__(self): - print('init/AX') - self.value = 15u64 - - def get_value(self) -> u64: - return self.value - -class BX(object): - a: AX - def __init__(self): - print('init/BX') - self.a = AX() - def hai(self): - return f"hai/BX: {self.a.value}" - -class CX(BX): - def __init__(self): - print('init/CX') - super().__init__() - - def getsuper(self): - return super() - - def test(self): - print('test/CX:', self.a.value) - return self.a.get_value() - - def hai(self): - return f"hai/CX: {self.a.value}" - -table = CX() -#: init/CX -#: init/BX -#: init/AX -print table.test() -#: test/CX: 15 -#: 15 - -s = table.getsuper() -print(s.hai()) -#: hai/BX: 15 -s.a.value += 1u64 -print(s.hai()) -#: hai/BX: 16 -table.a.value += 1u64 -print(s.hai()) -#: hai/BX: 17 -table.test() -#: test/CX: 17 - -c: List[BX] = [s, table] -print(c[0].hai()) #: hai/BX: 17 -print(c[1].hai()) #: hai/CX: 17 - - - -#%% no_generic,barebones -def foo(a, b: Static[int]): - pass -foo(5) #! generic 'b' not provided - - -#%% no_generic_2,barebones -def f(a, b, T: type): - print(a, b) -f(1, 2) #! generic 'T' not provided - -#%% variardic_tuples,barebones - -class Foo[N: Static[int]]: - x: Tuple[N, str] - - def __init__(self): - self.x = ("hi", ) * N - -f = Foo[5]() -print(f.__class__.__name__) -#: Foo[5] -print(f.x.__class__.__name__) -#: Tuple[str,str,str,str,str] -print(f.x) -#: ('hi', 'hi', 'hi', 'hi', 'hi') - -print(Tuple[int, int].__class__.__name__) -#: Tuple[int,int] -print(Tuple[3, int].__class__.__name__) -#: Tuple[int,int,int] -print(Tuple[0].__class__.__name__) -#: Tuple -print(Tuple[-5, int].__class__.__name__) -#: Tuple -print(Tuple[5, int, str].__class__.__name__) -#: Tuple[int,str,int,str,int,str,int,str,int,str] - - -#%% domination_nested,barebones -def correlate(a, b, mode = 'valid'): - if mode == 'valid': - if isinstance(a, List): - xret = '1' - else: - xret = '2' - for i in a: - for j in b: - xret += 'z' - elif mode == 'same': - if isinstance(a, List): - xret = '3' - else: - xret = '4' - for i in a: - for j in b: - xret += 'z' - elif mode == 'full': - if isinstance(a, List): - xret = '5' - else: - xret = '6' - for i in a: - for j in b: - xret += 'z' - else: - raise ValueError(f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})") - return xret -print(correlate([1], [2], 'full')) #: 5z - -def foo(x, y): - a = 5 - if isinstance(a, int): - if staticlen(y) == 0: - a = 0 - elif staticlen(y) == 1: - a = 1 - else: - for i in range(10): - a = 40 - return a - return a -print foo(5, (1, 2, 3)) #: 40 - -#%% union_hasattr,barebones -class A: - def foo(self): - print('foo') - def bar(self): - print('bar') -class B: - def foo(self): - print('foo') - def baz(self): - print('baz') - -a = A() -print(hasattr(a, 'foo'), hasattr(a, 'bar'), hasattr(a, 'baz')) -#: True True False -b = B() -print(hasattr(b, 'foo'), hasattr(b, 'bar'), hasattr(b, 'baz')) -#: True False True - -c: Union[A, B] = A() -print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) -#: True True False - -c = B() -print(hasattr(c, 'foo'), hasattr(c, 'bar'), hasattr(c, 'baz')) -#: True False True - - -#%% delayed_dispatch -import math -def fox(a, b, key=None): # key=None delays it! - return a if a <= b else b - -a = 1.0 -b = 2.0 -c = fox(a, b) -print(math.log(c) / 2) #: 0 diff --git a/test/python/myextension.codon b/test/python/myextension.codon index 48f7a45c..51a7dc1e 100644 --- a/test/python/myextension.codon +++ b/test/python/myextension.codon @@ -323,7 +323,7 @@ class Foo: x: Dict[str, int] def __new__(a: List[str]) -> Foo: - return (a, {s: i for i, s in enumerate(a)}) + return Foo(a, {s: i for i, s in enumerate(a)}) def __iter__(self): return iter(self.a) diff --git a/test/python/myextension2.codon b/test/python/myextension2.codon index 6e8dcb3f..b5f45e53 100644 --- a/test/python/myextension2.codon +++ b/test/python/myextension2.codon @@ -8,9 +8,6 @@ class Vec: n: ClassVar[int] = 0 d: ClassVar[int] = 0 - def __new__(a: float, b: float, tag: str) -> Vec: - return (a, b, tag) - def __new__(a: float = 0.0, b: float = 0.0): v = Vec(a, b, 'v' + str(Vec.n)) Vec.n += 1 diff --git a/test/stdlib/datetime_test.codon b/test/stdlib/datetime_test.codon index 73e4d784..dd7101b8 100644 --- a/test/stdlib/datetime_test.codon +++ b/test/stdlib/datetime_test.codon @@ -688,7 +688,7 @@ class TestDate(Static[TestCase]): test_cases.append((new_date, new_iso)) for d, exp_iso in test_cases: - self.assertEqual(d.isocalendar(), exp_iso) + self.assertEqual(tuple(d.isocalendar()), exp_iso) # Check that the tuple contents are accessible by field name t = d.isocalendar() diff --git a/test/stdlib/itertools_test.codon b/test/stdlib/itertools_test.codon index 001dc75e..089358af 100644 --- a/test/stdlib/itertools_test.codon +++ b/test/stdlib/itertools_test.codon @@ -540,9 +540,6 @@ def test_accumulate_from_cpython(): assert list(accumulate(List[int](), initial=100)) == [100] -test_accumulate_from_cpython() - - @test def test_chain_from_cpython(): assert list(chain("abc", "def")) == list("abcdef") @@ -551,9 +548,6 @@ def test_chain_from_cpython(): assert list(take(4, chain("abc", "def"))) == list("abcd") -test_chain_from_cpython() - - @test def test_chain_from_iterable_from_cpython(): assert list(chain.from_iterable(["abc", "def"])) == list("abcdef") @@ -562,9 +556,6 @@ def test_chain_from_iterable_from_cpython(): assert take(4, chain.from_iterable(["abc", "def"])) == list("abcd") -test_chain_from_iterable_from_cpython() - - @test def test_combinations_from_cpython(): f = lambda x: x # hack to get non-static argument @@ -645,7 +636,6 @@ def test_combinations_from_cpython(): e for e in values if e in c ] # comb is a subsequence of the input iterable -test_combinations_from_cpython() @test @@ -741,9 +731,6 @@ def test_combinations_with_replacement_from_cpython(): ] # comb is a subsequence of the input iterable -test_combinations_with_replacement_from_cpython() - - @test def test_permutations_from_cpython(): f = lambda x: x # hack to get non-static argument @@ -813,9 +800,6 @@ def test_permutations_from_cpython(): assert result == list(permutations(values, r)) -test_permutations_from_cpython() - - @extend class List: def __lt__(self, other: List[T]): @@ -952,9 +936,6 @@ def test_combinatorics_from_cpython(): assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm -test_combinatorics_from_cpython() - - @test def test_compress_from_cpython(): assert list(compress(data="ABCDEF", selectors=[1, 0, 1, 0, 1, 1])) == list("ACEF") @@ -969,9 +950,6 @@ def test_compress_from_cpython(): assert list(compress(data, selectors)) == [1, 3, 5] * n -test_compress_from_cpython() - - @test def test_count_from_cpython(): assert lzip("abc", count()) == [("a", 0), ("b", 1), ("c", 2)] @@ -982,9 +960,6 @@ def test_count_from_cpython(): assert take(3, count(3.25)) == [3.25, 4.25, 5.25] -test_count_from_cpython() - - @test def test_count_with_stride_from_cpython(): assert lzip("abc", count(2, 3)) == [("a", 2), ("b", 5), ("c", 8)] @@ -996,9 +971,6 @@ def test_count_with_stride_from_cpython(): assert take(3, count(2.0, 1.25)) == [2.0, 3.25, 4.5] -test_count_with_stride_from_cpython() - - @test def test_cycle_from_cpython(): assert take(10, cycle("abc")) == list("abcabcabca") @@ -1006,9 +978,6 @@ def test_cycle_from_cpython(): assert list(islice(cycle(gen3()), 10)) == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0] -test_cycle_from_cpython() - - @test def test_groupby_from_cpython(): # Check whether it accepts arguments correctly @@ -1075,9 +1044,6 @@ def test_groupby_from_cpython(): assert r == [(5, "a"), (2, "r"), (2, "b")] -test_groupby_from_cpython() - - @test def test_filter_from_cpython(): assert list(filter(isEven, range(6))) == [0, 2, 4] @@ -1086,9 +1052,6 @@ def test_filter_from_cpython(): assert take(4, filter(isEven, count())) == [0, 2, 4, 6] -test_filter_from_cpython() - - @test def test_filterfalse_from_cpython(): assert list(filterfalse(isEven, range(6))) == [1, 3, 5] @@ -1097,9 +1060,6 @@ def test_filterfalse_from_cpython(): assert take(4, filterfalse(isEven, count())) == [1, 3, 5, 7] -test_filterfalse_from_cpython() - - @test def test_zip_from_cpython(): ans = [(x, y) for x, y in zip("abc", count())] @@ -1112,9 +1072,6 @@ def test_zip_from_cpython(): assert [pair for pair in zip("abc", "def")] == lzip("abc", "def") -test_zip_from_cpython() - - @test def test_ziplongest_from_cpython(): for args in ( @@ -1142,9 +1099,6 @@ def test_ziplongest_from_cpython(): ) -test_ziplongest_from_cpython() - - @test def test_product_from_cpython(): for args, result in ( @@ -1175,9 +1129,6 @@ def test_product_from_cpython(): ) -test_product_from_cpython() - - @test def test_repeat_from_cpython(): assert list(repeat(object="a", times=3)) == ["a", "a", "a"] @@ -1188,9 +1139,6 @@ def test_repeat_from_cpython(): assert list(repeat("a", -3)) == [] -test_repeat_from_cpython() - - @test def test_map_from_cpython(): power = lambda a, b: a ** b @@ -1201,9 +1149,6 @@ def test_map_from_cpython(): assert list(map(tupleize, List[int]())) == [] -test_map_from_cpython() - - @test def test_starmap_from_cpython(): power = lambda a, b: a ** b @@ -1213,9 +1158,6 @@ def test_starmap_from_cpython(): assert list(starmap(power, [(4, 5)])) == [4 ** 5] -test_starmap_from_cpython() - - @test def test_islice_from_cpython(): for args in ( # islice(args) should agree with range(args) @@ -1243,9 +1185,6 @@ def test_islice_from_cpython(): assert list(islice(range(10), 1, None, 2)) == list(range(1, 10, 2)) -test_islice_from_cpython() - - @test def test_takewhile_from_cpython(): data = [1, 3, 5, 20, 2, 4, 6, 8] @@ -1255,9 +1194,6 @@ def test_takewhile_from_cpython(): assert list(t) == [1, 1, 1] -test_takewhile_from_cpython() - - @test def test_dropwhile_from_cpython(): data = [1, 3, 5, 20, 2, 4, 6, 8] @@ -1265,9 +1201,6 @@ def test_dropwhile_from_cpython(): assert list(dropwhile(underten, List[int]())) == [] -test_dropwhile_from_cpython() - - @test def test_tee_from_cpython(): import random @@ -1315,5 +1248,27 @@ def test_tee_from_cpython(): assert list(a) == list(range(100, 2000)) assert list(c) == list(range(2, 2000)) - +test_accumulate_from_cpython() +test_chain_from_cpython() +test_chain_from_iterable_from_cpython() +test_combinations_from_cpython() # takes long time to typecheck +test_combinations_with_replacement_from_cpython() +test_permutations_from_cpython() +test_combinatorics_from_cpython() # TODO: takes FOREVER to typecheck +test_compress_from_cpython() +test_count_from_cpython() +test_count_with_stride_from_cpython() +test_cycle_from_cpython() +test_groupby_from_cpython() +test_filter_from_cpython() +test_filterfalse_from_cpython() +test_zip_from_cpython() +test_ziplongest_from_cpython() +test_product_from_cpython() +test_repeat_from_cpython() +test_map_from_cpython() +test_starmap_from_cpython() +test_islice_from_cpython() +test_takewhile_from_cpython() +test_dropwhile_from_cpython() test_tee_from_cpython() diff --git a/test/stdlib/llvm_test.codon b/test/stdlib/llvm_test.codon new file mode 100644 index 00000000..f8c2e633 --- /dev/null +++ b/test/stdlib/llvm_test.codon @@ -0,0 +1,205 @@ +#%% ptr,barebones +import internal.gc as gc +print gc.sizeof(Ptr[int]) #: 8 +print gc.atomic(Ptr[int]) #: False +y = Ptr[int](1) +y[0] = 11 +print y[0] #: 11 +_y = y.as_byte() +print int(_y[0]) #: 11 +y = Ptr[int](5) +y[0] = 1; y[1] = 2; y[2] = 3; y[3] = 4; y[4] = 5 +z = Ptr[int](y) +print y[1], z[2] #: 2 3 +z = Ptr[int](y.as_byte()) +print y[1], z[2] #: 2 3 +print z.__bool__() #: True +z.__int__() # big number... +zz = z.__copy__() # this is not a deep copy! +zz[2] = 10 +print z[2], zz[2] #: 10 10 +print y.__getitem__(1) #: 2 +y.__setitem__(1, 3) +print y[1] #: 3 +print y.__add__(1)[0] #: 3 +print (y + 3).__sub__(y + 1) #: 2 +print y.__eq__(z) #: True +print y.__eq__(zz) #: True +print y.as_byte().__eq__('abc'.ptr) #: False +print y.__ne__(z) #: False +print y.__lt__(y+1) #: True +print y.__gt__(y+1) #: False +print (y+1).__le__(y) #: False +print y.__ge__(y) #: True +y.__prefetch_r0__() +y.__prefetch_r1__() +y.__prefetch_r2__() +y.__prefetch_r3__() +y.__prefetch_w0__() +y.__prefetch_w1__() +y.__prefetch_w2__() +y.__prefetch_w3__() + +#%% int,barebones +a = int() +b = int(5) +c = int(True) +d = int(byte(1)) +e = int(1.1) +print a, b, c, d, e #: 0 5 1 1 1 +print a.__repr__() #: 0 +print b.__copy__() #: 5 +print b.__hash__() #: 5 +print a.__bool__(), b.__bool__() #: False True +print a.__pos__() #: 0 +print b.__neg__() #: -5 +print (-b).__abs__() #: 5 +print c.__truediv__(5) #: 0.2 +print b.__lshift__(1) #: 10 +print b.__rshift__(1) #: 2 +print b.__truediv__(5.15) #: 0.970874 +print a.__add__(1) #: 1 +print a.__add__(1.1) #: 1.1 +print a.__sub__(1) #: -1 +print a.__sub__(1.1) #: -1.1 +print b.__mul__(1) #: 5 +print b.__mul__(1.1) #: 5.5 +print b.__floordiv__(2) #: 2 +print b.__floordiv__(1.1) #: 4 +print b.__mod__(2) #: 1 +print b.__mod__(1.1) #: 0.6 +print a.__eq__(1) #: False +print a.__eq__(1.1) #: False +print a.__ne__(1) #: True +print a.__ne__(1.1) #: True +print a.__lt__(1) #: True +print a.__lt__(1.1) #: True +print a.__le__(1) #: True +print a.__le__(1.1) #: True +print a.__gt__(1) #: False +print a.__gt__(1.1) #: False +print a.__ge__(1) #: False +print a.__ge__(1.1) #: False + +#%% uint,barebones +au = Int[123](15) +a = UInt[123]() +b = UInt[123](a) +a = UInt[123](15) +a = UInt[123](au) +print a.__copy__() #: 15 +print a.__hash__() #: 15 +print a.__bool__() #: True +print a.__pos__() #: 15 +print a.__neg__() #: 10633823966279326983230456482242756593 +print a.__invert__() #: 10633823966279326983230456482242756592 +m = UInt[123](4) +print a.__add__(m), a.__sub__(m), a.__mul__(m), a.__floordiv__(m), a.__truediv__(m) #: 19 11 60 3 3.75 +print a.__mod__(m), a.__lshift__(m), a.__rshift__(m) #: 3 240 0 +print a.__eq__(m), a.__ne__(m), a.__lt__(m), a.__gt__(m), a.__le__(m), a.__ge__(m) #: False True False True False True +print a.__and__(m), a.__or__(m), a.__xor__(m) #: 4 15 11 +print a, a.popcnt() #: 15 4 +ac = Int[128](5) +bc = Int[32](5) +print ac, bc, int(ac), int(bc) #: 5 5 5 5 + +print int(Int[12](12)) #: 12 +print int(Int[122](12)) #: 12 +print int(Int[64](12)) #: 12 +print int(UInt[12](12)) #: 12 +print int(UInt[122](12)) #: 12 +print int(UInt[64](12)) #: 12 + +print Int[32](212) #: 212 +print Int[64](212) #: 212 +print Int[66](212) #: 212 +print UInt[32](112) #: 112 +print UInt[64](112) #: 112 +print UInt[66](112) #: 112 + + +#%% float,barebones +z = float.__new__() +z = 5.5 +print z.__repr__() #: 5.5 +print z.__copy__() #: 5.5 +print z.__bool__(), z.__pos__(), z.__neg__() #: True 5.5 -5.5 +f = 1.3 +print z.__floordiv__(f), z.__floordiv__(2) #: 4 2 +print z.__truediv__(f), z.__truediv__(2) #: 4.23077 2.75 +print z.__pow__(2.2), z.__pow__(2) #: 42.54 30.25 +print z.__add__(2) #: 7.5 +print z.__sub__(2) #: 3.5 +print z.__mul__(2) #: 11 +print z.__truediv__(2) #: 2.75 +print z.__mod__(2) #: 1.5 +print z.__eq__(2) #: False +print z.__ne__(2) #: True +print z.__lt__(2) #: False +print z.__gt__(2) #: True +print z.__le__(2) #: False +print z.__ge__(2) #: True + +#%% bool,barebones +z = bool.__new__() +print z.__repr__() #: False +print z.__copy__() #: False +print z.__bool__(), z.__invert__() #: False True +print z.__eq__(True) #: False +print z.__ne__(True) #: True +print z.__lt__(True) #: True +print z.__gt__(True) #: False +print z.__le__(True) #: True +print z.__ge__(True) #: False +print z.__and__(True), z.__or__(True), z.__xor__(True) #: False True True + +#%% byte,barebones +z = byte.__new__() +z = byte(65) +print z.__repr__() #: byte('A') +print z.__bool__() #: True +print z.__eq__(byte(5)) #: False +print z.__ne__(byte(5)) #: True +print z.__lt__(byte(5)) #: False +print z.__gt__(byte(5)) #: True +print z.__le__(byte(5)) #: False +print z.__ge__(byte(5)) #: True + +#%% array,barebones +a = Array[float](5) +pa = Ptr[float](3) +z = Array[float](pa, 3) +z.__copy__() +print z.__len__() #: 3 +print z.__bool__() #: True +z.__setitem__(0, 1.0) +z.__setitem__(1, 2.0) +z.__setitem__(2, 3.0) +print z.__getitem__(1) #: 2 +print z.slice(0, 2).len #: 2 + +#%% optional,barebones +a = Optional[float]() +print bool(a) #: False +a = Optional[float](0.0) +print bool(a) #: False +a = Optional[float](5.5) +print a.__bool__(), a.__val__() #: True 5.5 + +#%% generator,barebones +def foo(): + yield 1 + yield 2 + yield 3 +z = foo() +y = z.__iter__() +print str(y.__raw__())[:2] #: 0x +print y.__done__() #: False +print y.__promise__()[0] #: 0 +y.__resume__() +print y.__repr__()[:16] #: