Skip to content

Commit

Permalink
Moved to .External2 - WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
fikovnik committed Aug 21, 2024
1 parent c4049d1 commit aa20814
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 66 deletions.
89 changes: 71 additions & 18 deletions server/backend/Rsh.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,11 @@ static INLINE Value sexp_as_val(SEXP s) {
}
}

// FIXME: should be in the test runtime or extern
static Value Rsh_NilValue;
static SEXP NOT_OP;
static SEXP DOTEXTERNAL2_SYM;
static SEXP RSH_CALL_TRAMPOLINE_SXP;

// BINDING CELLS (bcell) implementation
// ------------------------------------
Expand Down Expand Up @@ -350,6 +353,8 @@ static INLINE Rboolean bcell_set_value(BCell cell, SEXP value) {
#define Rsh_const_int(env, idx) INT_TO_VAL(INTEGER(Rsh_const(env, idx))[0])
#define Rsh_const_lgl(env, idx) LGL_TO_VAL(INTEGER(Rsh_const(env, idx))[0])

typedef SEXP (*Rsh_closure)(SEXP, SEXP);

// RUNTIME INITIALIZATION
// ----------------------

Expand All @@ -362,6 +367,7 @@ static INLINE Rboolean bcell_set_value(BCell cell, SEXP value) {
} while (0)

#ifdef RSH_TESTS

SEXP Rsh_initialize_runtime(void) {
#define X(a, b) LOAD_R_BUILTIN(R_ARITH_OPS[b], #a);
X_ARITH_OPS
Expand Down Expand Up @@ -391,11 +397,40 @@ SEXP Rsh_initialize_runtime(void) {

Rsh_NilValue = SXP_TO_VAL(R_NilValue);
LOAD_R_BUILTIN(NOT_OP, "!");
DOTEXTERNAL2_SYM = install(".External2");

RSH_CALL_TRAMPOLINE_SXP = Rf_mkString("Rsh_call_trampoline");
R_PreserveObject(RSH_CALL_TRAMPOLINE_SXP);

return R_NilValue;
}
#endif

SEXP Rsh_call_trampoline(SEXP call, SEXP op, SEXP args, SEXP rho) {
SEXP closure = CADR(args);
if (TYPEOF(closure) != CLOSXP) {
Rf_error("Expected a closure");
}

SEXP body = BODY(closure);
if (TYPEOF(body) != BCODESXP) {
Rf_error("Expected a compiled closure");
}

SEXP cp = BCODE_CONSTS(body);
if (XLENGTH(cp) != 6) {
Rf_error("Expected a constant pool with 6 elements");
}

SEXP c_cp = VECTOR_ELT(cp, LENGTH(cp) - 2);
// cf. https://stackoverflow.com/a/19487645
Rsh_closure fun;
*(void **)(&fun) = R_ExternalPtrAddr(VECTOR_ELT(c_cp, 0));
SEXP res = fun(rho, VECTOR_ELT(c_cp, 1));

return res;
}

// INSTRUCTIONS
// ------------

Expand Down Expand Up @@ -915,29 +950,45 @@ static INLINE Value Rsh_logic(SEXP call, RshLogic2Op op, Value lhs, Value rhs,
return res;
}

#define LDCONST_OP 16
#define DOTCALL_OP 119
#define PUSHCONSTARG_OP 34
#define BASEGUARD_OP 123
#define GETBUILTIN_OP 26
#define CALLBUILTIN_OP 39
#define RETURN_OP 1

#define BCODE_CODE(x) CAR(x)
#define BCODE_CONSTS(x) CDR(x)
#define IS_BYTECODE(x) (TYPEOF(x) == BCODESXP)

static INLINE SEXP create_wrapper_body(SEXP original_body,
const char *native_fun_name, SEXP rho,
SEXP c_cp) {
static INLINE SEXP create_constant_pool(Rsh_closure fun_ptr, SEXP c_cp) {
SEXP pool = PROTECT(Rf_allocVector(VECSXP, 2));
SEXP p = R_MakeExternalPtr((void *)fun_ptr, R_NilValue, R_NilValue);

// slot 0: the pointer to the compiled function
SET_VECTOR_ELT(pool, 0, p);

// slot 1: the contants used by the compiled function
SET_VECTOR_ELT(pool, 1, c_cp);

UNPROTECT(1); // consts

return pool;
}

static INLINE SEXP create_wrapper_body(SEXP closure, SEXP rho, SEXP c_cp) {

// clang-format off
static i32 CALL_FUN_BC[] = {
12,
LDCONST_OP, 2,
LDCONST_OP, 3,
LDCONST_OP, 4,
DOTCALL_OP, 1, 2,
GETBUILTIN_OP, 1,
PUSHCONSTARG_OP, 2,
PUSHCONSTARG_OP, 3,
CALLBUILTIN_OP, 0,
RETURN_OP
};
// clang-format on

SEXP original_body = BODY(closure);
assert(IS_BYTECODE(original_body));

SEXP original_cp = BCODE_CONSTS(original_body);
Expand All @@ -952,15 +1003,14 @@ static INLINE SEXP create_wrapper_body(SEXP original_body,
INTEGER(expr_index)[0] = NA_INTEGER;
memset(INTEGER(expr_index) + 1, 0, (bc_size - 1) * sizeof(i32));

SEXP natfun_sxp = Rf_mkString(native_fun_name);
SEXP cp = PROTECT(Rf_allocVector(VECSXP, 6));
int i = 0;

// store the original AST (consequently it will not correspond to the AST)
SET_VECTOR_ELT(cp, i++, VECTOR_ELT(original_cp, 0));
SET_VECTOR_ELT(cp, i++, Rf_lang4(install(".Call"), natfun_sxp, rho, c_cp));
SET_VECTOR_ELT(cp, i++, natfun_sxp);
SET_VECTOR_ELT(cp, i++, rho);
SET_VECTOR_ELT(cp, i++, DOTEXTERNAL2_SYM);
SET_VECTOR_ELT(cp, i++, RSH_CALL_TRAMPOLINE_SXP);
SET_VECTOR_ELT(cp, i++, closure);
SET_VECTOR_ELT(cp, i++, c_cp);
SET_VECTOR_ELT(cp, i++, expr_index);

Expand All @@ -979,14 +1029,16 @@ static INLINE SEXP create_wrapper_body(SEXP original_body,
return body;
}

static INLINE Value Rsh_native_closure(SEXP mkclos_arg,
const char *native_fun_name, SEXP c_cp,
SEXP rho) {
static INLINE Value Rsh_native_closure(SEXP mkclos_arg, Rsh_closure fun_ptr,
SEXP consts, SEXP rho) {

SEXP forms = VECTOR_ELT(mkclos_arg, 0);
SEXP original_body = VECTOR_ELT(mkclos_arg, 1);
SEXP body = create_wrapper_body(original_body, native_fun_name, rho, c_cp);
SEXP closure = Rf_mkCLOSXP(forms, body, rho);
SEXP closure = PROTECT(Rf_mkCLOSXP(forms, original_body, rho));

SEXP c_cp = PROTECT(create_constant_pool(fun_ptr, consts));
SEXP body = PROTECT(create_wrapper_body(closure, rho, c_cp));
SET_BODY(closure, body);

if (LENGTH(mkclos_arg) > 2) {
SEXP srcref = VECTOR_ELT(mkclos_arg, 2);
Expand All @@ -996,6 +1048,7 @@ static INLINE Value Rsh_native_closure(SEXP mkclos_arg,
}
R_Visible = TRUE;

UNPROTECT(3);
return SXP_TO_VAL(closure);
}

Expand Down
7 changes: 4 additions & 3 deletions server/src/main/java/org/prlprg/bc2c/BC2CCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ public CompiledModule finish() {
var compiledClosure = module.compileClosure(bc);

var file = new CFile();
file.setPreamble("#include <Rsh.h>");
module.funs().forEach(file::add);
file.addInclude("Rsh.h");
module.funs().forEach(fun -> file.addFun(fun, true));

return new CompiledModule(file, compiledClosure.name(), compiledClosure.constantPool());
}
Expand Down Expand Up @@ -258,8 +258,9 @@ private void compileMakeClosure(ConstPool.Idx<VecSXP> idx) {
var compiledClosure = module.compileClosure(body.bc());
var cpId = constants.size();
// new body for the closure itself
// FIXME: the cpId has to be greater than any existing constpool.idx -- need to keep an extra one around
constants.put(cpId, new Constant(cpId, compiledClosure.constantPool()));
push("Rsh_native_closure(%s, \"%s\", %s, %s)".formatted(constantSXP(idx), compiledClosure.name(), constantSXP(cpId), NAME_ENV), false);
push("Rsh_native_closure(%s, &%s, %s, %s)".formatted(constantSXP(idx), compiledClosure.name(), constantSXP(cpId), NAME_ENV), false);
} else {
throw new UnsupportedOperationException("Unsupported body: " + body);
}
Expand Down
23 changes: 13 additions & 10 deletions server/src/main/java/org/prlprg/bc2c/CFile.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,16 @@
import java.util.List;

public class CFile {
private String preamble;
private final List<CFunction> funs = new ArrayList<>();

public void setPreamble(String preamble) {
this.preamble = preamble;
}
private final List<CFunction> forwards = new ArrayList<>();
private final List<String> includes = new ArrayList<>();

public void writeTo(Writer w) {
var pw = new PrintWriter(w);

if (preamble != null) {
pw.println(preamble);
pw.println();
}
includes.forEach(x -> pw.println("#include <" + x + ">"));
pw.println();
forwards.stream().map(CFunction::getDeclaration).forEach(x -> pw.println(x + ";"));

funs.forEach(x -> x.writeTo(pw));
}
Expand All @@ -31,7 +27,14 @@ public String toString() {
return w.toString();
}

public void add(CFunction fun) {
public void addFun(CFunction fun, boolean forwardDeclare) {
funs.add(fun);
if (forwardDeclare) {
forwards.add(fun);
}
}

public void addInclude(String include) {
includes.add(include);
}
}
70 changes: 37 additions & 33 deletions server/src/main/java/org/prlprg/bc2c/CFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,48 @@
import java.util.List;

public class CFunction {
private final String returnType;
private final String name;
private final String parameters;
private final List<CCode> sections = new ArrayList<>();
private final String returnType;
private final String name;
private final String parameters;
private final List<CCode> sections = new ArrayList<>();

CFunction(String returnType, String name, String parameters) {
this.returnType = returnType;
this.name = name;
this.parameters = parameters;
}
CFunction(String returnType, String name, String parameters) {
this.returnType = returnType;
this.name = name;
this.parameters = parameters;
}

public CCode add() {
var s = new CCode();
sections.add(s);
return s;
}
public CCode add() {
var s = new CCode();
sections.add(s);
return s;
}

public void writeTo(Writer w) {
var pw = new PrintWriter(w);
pw.format("%s %s(%s) {", returnType, name, parameters);
pw.println();
for (int i = 0; i < sections.size(); i++) {
sections.get(i).writeTo(w);
if (i < sections.size() - 1) {
public void writeTo(Writer w) {
var pw = new PrintWriter(w);
pw.format("%s {", getDeclaration());
pw.println();
}
for (int i = 0; i < sections.size(); i++) {
sections.get(i).writeTo(w);
if (i < sections.size() - 1) {
pw.println();
}
}
pw.println("}");
pw.flush();
}

public CCode insertAbove(CCode sec) {
var i = sections.indexOf(sec);
if (i == -1) {
throw new IllegalArgumentException("Section " + sec + " does not exist in fun " + this);
}
var s = new CCode();
sections.add(i, s);
return s;
}
pw.println("}");
pw.flush();
}

public CCode insertAbove(CCode sec) {
var i = sections.indexOf(sec);
if (i == -1) {
throw new IllegalArgumentException("Section " + sec + " does not exist in fun " + this);
public String getDeclaration() {
return String.format("%s %s(%s)", returnType, name, parameters);
}
var s = new CCode();
sections.add(i, s);
return s;
}
}
5 changes: 3 additions & 2 deletions server/src/test/java/org/prlprg/bc2c/BC2CCompilerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ public void testBooleanOperators() throws Exception {
public void testClosure() throws Exception {
verify(
"""
f <- function (x) { x + 1 }
y <- 21
f <- function (x) { x + y }
f(42)
""",
(RealSXP v) -> assertEquals(43.0, v.asReal(0)));
(RealSXP v) -> assertEquals(63.0, v.asReal(0)));
}

@Test
Expand Down

0 comments on commit aa20814

Please sign in to comment.