diff --git a/server/backend/Rsh.h b/server/backend/Rsh.h index b8f465ad..d242df25 100644 --- a/server/backend/Rsh.h +++ b/server/backend/Rsh.h @@ -172,8 +172,11 @@ static INLINE Value sexp_as_val(SEXP s) { } } +// FIXME: should be in the test runtime or extern static Value Rsh_NilValue; static SEXP NOT_OP; +static SEXP DOTEXTERNAL2_SYM; +static SEXP RSH_CALL_TRAMPOLINE_SXP; // BINDING CELLS (bcell) implementation // ------------------------------------ @@ -350,6 +353,8 @@ static INLINE Rboolean bcell_set_value(BCell cell, SEXP value) { #define Rsh_const_int(env, idx) INT_TO_VAL(INTEGER(Rsh_const(env, idx))[0]) #define Rsh_const_lgl(env, idx) LGL_TO_VAL(INTEGER(Rsh_const(env, idx))[0]) +typedef SEXP (*Rsh_closure)(SEXP, SEXP); + // RUNTIME INITIALIZATION // ---------------------- @@ -362,6 +367,7 @@ static INLINE Rboolean bcell_set_value(BCell cell, SEXP value) { } while (0) #ifdef RSH_TESTS + SEXP Rsh_initialize_runtime(void) { #define X(a, b) LOAD_R_BUILTIN(R_ARITH_OPS[b], #a); X_ARITH_OPS @@ -391,11 +397,40 @@ SEXP Rsh_initialize_runtime(void) { Rsh_NilValue = SXP_TO_VAL(R_NilValue); LOAD_R_BUILTIN(NOT_OP, "!"); + DOTEXTERNAL2_SYM = install(".External2"); + + RSH_CALL_TRAMPOLINE_SXP = Rf_mkString("Rsh_call_trampoline"); + R_PreserveObject(RSH_CALL_TRAMPOLINE_SXP); return R_NilValue; } #endif +SEXP Rsh_call_trampoline(SEXP call, SEXP op, SEXP args, SEXP rho) { + SEXP closure = CADR(args); + if (TYPEOF(closure) != CLOSXP) { + Rf_error("Expected a closure"); + } + + SEXP body = BODY(closure); + if (TYPEOF(body) != BCODESXP) { + Rf_error("Expected a compiled closure"); + } + + SEXP cp = BCODE_CONSTS(body); + if (XLENGTH(cp) != 6) { + Rf_error("Expected a constant pool with 6 elements"); + } + + SEXP c_cp = VECTOR_ELT(cp, LENGTH(cp) - 2); + // cf. https://stackoverflow.com/a/19487645 + Rsh_closure fun; + *(void **)(&fun) = R_ExternalPtrAddr(VECTOR_ELT(c_cp, 0)); + SEXP res = fun(rho, VECTOR_ELT(c_cp, 1)); + + return res; +} + // INSTRUCTIONS // ------------ @@ -915,29 +950,45 @@ static INLINE Value Rsh_logic(SEXP call, RshLogic2Op op, Value lhs, Value rhs, return res; } -#define LDCONST_OP 16 -#define DOTCALL_OP 119 +#define PUSHCONSTARG_OP 34 +#define BASEGUARD_OP 123 +#define GETBUILTIN_OP 26 +#define CALLBUILTIN_OP 39 #define RETURN_OP 1 #define BCODE_CODE(x) CAR(x) #define BCODE_CONSTS(x) CDR(x) #define IS_BYTECODE(x) (TYPEOF(x) == BCODESXP) -static INLINE SEXP create_wrapper_body(SEXP original_body, - const char *native_fun_name, SEXP rho, - SEXP c_cp) { +static INLINE SEXP create_constant_pool(Rsh_closure fun_ptr, SEXP c_cp) { + SEXP pool = PROTECT(Rf_allocVector(VECSXP, 2)); + SEXP p = R_MakeExternalPtr((void *)fun_ptr, R_NilValue, R_NilValue); + + // slot 0: the pointer to the compiled function + SET_VECTOR_ELT(pool, 0, p); + + // slot 1: the contants used by the compiled function + SET_VECTOR_ELT(pool, 1, c_cp); + + UNPROTECT(1); // consts + + return pool; +} + +static INLINE SEXP create_wrapper_body(SEXP closure, SEXP rho, SEXP c_cp) { // clang-format off static i32 CALL_FUN_BC[] = { 12, - LDCONST_OP, 2, - LDCONST_OP, 3, - LDCONST_OP, 4, - DOTCALL_OP, 1, 2, + GETBUILTIN_OP, 1, + PUSHCONSTARG_OP, 2, + PUSHCONSTARG_OP, 3, + CALLBUILTIN_OP, 0, RETURN_OP }; // clang-format on + SEXP original_body = BODY(closure); assert(IS_BYTECODE(original_body)); SEXP original_cp = BCODE_CONSTS(original_body); @@ -952,15 +1003,14 @@ static INLINE SEXP create_wrapper_body(SEXP original_body, INTEGER(expr_index)[0] = NA_INTEGER; memset(INTEGER(expr_index) + 1, 0, (bc_size - 1) * sizeof(i32)); - SEXP natfun_sxp = Rf_mkString(native_fun_name); SEXP cp = PROTECT(Rf_allocVector(VECSXP, 6)); int i = 0; // store the original AST (consequently it will not correspond to the AST) SET_VECTOR_ELT(cp, i++, VECTOR_ELT(original_cp, 0)); - SET_VECTOR_ELT(cp, i++, Rf_lang4(install(".Call"), natfun_sxp, rho, c_cp)); - SET_VECTOR_ELT(cp, i++, natfun_sxp); - SET_VECTOR_ELT(cp, i++, rho); + SET_VECTOR_ELT(cp, i++, DOTEXTERNAL2_SYM); + SET_VECTOR_ELT(cp, i++, RSH_CALL_TRAMPOLINE_SXP); + SET_VECTOR_ELT(cp, i++, closure); SET_VECTOR_ELT(cp, i++, c_cp); SET_VECTOR_ELT(cp, i++, expr_index); @@ -979,14 +1029,16 @@ static INLINE SEXP create_wrapper_body(SEXP original_body, return body; } -static INLINE Value Rsh_native_closure(SEXP mkclos_arg, - const char *native_fun_name, SEXP c_cp, - SEXP rho) { +static INLINE Value Rsh_native_closure(SEXP mkclos_arg, Rsh_closure fun_ptr, + SEXP consts, SEXP rho) { SEXP forms = VECTOR_ELT(mkclos_arg, 0); SEXP original_body = VECTOR_ELT(mkclos_arg, 1); - SEXP body = create_wrapper_body(original_body, native_fun_name, rho, c_cp); - SEXP closure = Rf_mkCLOSXP(forms, body, rho); + SEXP closure = PROTECT(Rf_mkCLOSXP(forms, original_body, rho)); + + SEXP c_cp = PROTECT(create_constant_pool(fun_ptr, consts)); + SEXP body = PROTECT(create_wrapper_body(closure, rho, c_cp)); + SET_BODY(closure, body); if (LENGTH(mkclos_arg) > 2) { SEXP srcref = VECTOR_ELT(mkclos_arg, 2); @@ -996,6 +1048,7 @@ static INLINE Value Rsh_native_closure(SEXP mkclos_arg, } R_Visible = TRUE; + UNPROTECT(3); return SXP_TO_VAL(closure); } diff --git a/server/src/main/java/org/prlprg/bc2c/BC2CCompiler.java b/server/src/main/java/org/prlprg/bc2c/BC2CCompiler.java index d6862af4..49d91dca 100644 --- a/server/src/main/java/org/prlprg/bc2c/BC2CCompiler.java +++ b/server/src/main/java/org/prlprg/bc2c/BC2CCompiler.java @@ -137,8 +137,8 @@ public CompiledModule finish() { var compiledClosure = module.compileClosure(bc); var file = new CFile(); - file.setPreamble("#include "); - module.funs().forEach(file::add); + file.addInclude("Rsh.h"); + module.funs().forEach(fun -> file.addFun(fun, true)); return new CompiledModule(file, compiledClosure.name(), compiledClosure.constantPool()); } @@ -258,8 +258,9 @@ private void compileMakeClosure(ConstPool.Idx idx) { var compiledClosure = module.compileClosure(body.bc()); var cpId = constants.size(); // new body for the closure itself + // FIXME: the cpId has to be greater than any existing constpool.idx -- need to keep an extra one around constants.put(cpId, new Constant(cpId, compiledClosure.constantPool())); - push("Rsh_native_closure(%s, \"%s\", %s, %s)".formatted(constantSXP(idx), compiledClosure.name(), constantSXP(cpId), NAME_ENV), false); + push("Rsh_native_closure(%s, &%s, %s, %s)".formatted(constantSXP(idx), compiledClosure.name(), constantSXP(cpId), NAME_ENV), false); } else { throw new UnsupportedOperationException("Unsupported body: " + body); } diff --git a/server/src/main/java/org/prlprg/bc2c/CFile.java b/server/src/main/java/org/prlprg/bc2c/CFile.java index 0ee643c0..4a924a57 100644 --- a/server/src/main/java/org/prlprg/bc2c/CFile.java +++ b/server/src/main/java/org/prlprg/bc2c/CFile.java @@ -7,20 +7,16 @@ import java.util.List; public class CFile { - private String preamble; private final List funs = new ArrayList<>(); - - public void setPreamble(String preamble) { - this.preamble = preamble; - } + private final List forwards = new ArrayList<>(); + private final List includes = new ArrayList<>(); public void writeTo(Writer w) { var pw = new PrintWriter(w); - if (preamble != null) { - pw.println(preamble); - pw.println(); - } + includes.forEach(x -> pw.println("#include <" + x + ">")); + pw.println(); + forwards.stream().map(CFunction::getDeclaration).forEach(x -> pw.println(x + ";")); funs.forEach(x -> x.writeTo(pw)); } @@ -31,7 +27,14 @@ public String toString() { return w.toString(); } - public void add(CFunction fun) { + public void addFun(CFunction fun, boolean forwardDeclare) { funs.add(fun); + if (forwardDeclare) { + forwards.add(fun); + } + } + + public void addInclude(String include) { + includes.add(include); } } diff --git a/server/src/main/java/org/prlprg/bc2c/CFunction.java b/server/src/main/java/org/prlprg/bc2c/CFunction.java index cf7bc164..479e676b 100644 --- a/server/src/main/java/org/prlprg/bc2c/CFunction.java +++ b/server/src/main/java/org/prlprg/bc2c/CFunction.java @@ -6,44 +6,48 @@ import java.util.List; public class CFunction { - private final String returnType; - private final String name; - private final String parameters; - private final List sections = new ArrayList<>(); + private final String returnType; + private final String name; + private final String parameters; + private final List sections = new ArrayList<>(); - CFunction(String returnType, String name, String parameters) { - this.returnType = returnType; - this.name = name; - this.parameters = parameters; - } + CFunction(String returnType, String name, String parameters) { + this.returnType = returnType; + this.name = name; + this.parameters = parameters; + } - public CCode add() { - var s = new CCode(); - sections.add(s); - return s; - } + public CCode add() { + var s = new CCode(); + sections.add(s); + return s; + } - public void writeTo(Writer w) { - var pw = new PrintWriter(w); - pw.format("%s %s(%s) {", returnType, name, parameters); - pw.println(); - for (int i = 0; i < sections.size(); i++) { - sections.get(i).writeTo(w); - if (i < sections.size() - 1) { + public void writeTo(Writer w) { + var pw = new PrintWriter(w); + pw.format("%s {", getDeclaration()); pw.println(); - } + for (int i = 0; i < sections.size(); i++) { + sections.get(i).writeTo(w); + if (i < sections.size() - 1) { + pw.println(); + } + } + pw.println("}"); + pw.flush(); + } + + public CCode insertAbove(CCode sec) { + var i = sections.indexOf(sec); + if (i == -1) { + throw new IllegalArgumentException("Section " + sec + " does not exist in fun " + this); + } + var s = new CCode(); + sections.add(i, s); + return s; } - pw.println("}"); - pw.flush(); - } - public CCode insertAbove(CCode sec) { - var i = sections.indexOf(sec); - if (i == -1) { - throw new IllegalArgumentException("Section " + sec + " does not exist in fun " + this); + public String getDeclaration() { + return String.format("%s %s(%s)", returnType, name, parameters); } - var s = new CCode(); - sections.add(i, s); - return s; - } } diff --git a/server/src/test/java/org/prlprg/bc2c/BC2CCompilerTest.java b/server/src/test/java/org/prlprg/bc2c/BC2CCompilerTest.java index efe64f2f..3963b4e0 100644 --- a/server/src/test/java/org/prlprg/bc2c/BC2CCompilerTest.java +++ b/server/src/test/java/org/prlprg/bc2c/BC2CCompilerTest.java @@ -148,10 +148,11 @@ public void testBooleanOperators() throws Exception { public void testClosure() throws Exception { verify( """ - f <- function (x) { x + 1 } + y <- 21 + f <- function (x) { x + y } f(42) """, - (RealSXP v) -> assertEquals(43.0, v.asReal(0))); + (RealSXP v) -> assertEquals(63.0, v.asReal(0))); } @Test