Skip to content

Commit

Permalink
extraction works, but for some reasons spec gets mixed with impl
Browse files Browse the repository at this point in the history
  • Loading branch information
AYadrov committed Feb 4, 2025
1 parent 33f5ffd commit e392e31
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
60 changes: 40 additions & 20 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,10 @@

;; Translates a Herbie rule into an egg rule
(define (rule->egg-rule ru)
(define ru-vars (map car (rule-itypes ru)))
(define ru-vars
(if (list? (rule-itypes ru))
(map car (rule-itypes ru))
(hash-keys (rule-itypes ru))))
(struct-copy rule
ru
[input (expr->egg-pattern (rule-input ru) ru-vars)]
Expand Down Expand Up @@ -519,6 +522,7 @@
(sow
(cons #f (make-ffi-rule "drop-hole-of-hole" "($hole ?repr ($hole ?repr ?a))" "($hole ?repr ?a)")))
(sow (cons #f (make-ffi-rule "drop-var" "($var ?repr ?a)" "?a")))
(sow (cons #f (make-ffi-rule "drop-literal" "($literal ?a ?repr)" "?a")))
(for ([rule (in-list rules)])
(define egg&ffi-rules
(hash-ref! (*egg-rule-cache*)
Expand Down Expand Up @@ -566,7 +570,7 @@
(define (enode-type enode ctx lookup)
(match enode
[(? number?) (cons 'real (platform-reprs (*active-platform*)))] ; number
[(? repr-exists?) (get-representation enode)] ; to be changed!
[(? repr-exists?) (get-representation enode)] ; useless but okay
[(? symbol?) ; variable
(define var (egg-var->var enode ctx))
(define repr (context-lookup ctx var))
Expand All @@ -576,7 +580,8 @@
[(eq? f '$approx) (platform-reprs (*active-platform*))]
[(eq? f 'if) (all-reprs/types)]
[(eq? f '$hole)
(define repr (lookup (car ids)))
(define repr-idx (car ids))
(define repr (car (lookup repr-idx))) ; lookup is garaunteed to return eclass '(type)
(get-representation repr)]
[(impl-exists? f) (list (impl-info f 'otype))]
[else (list (operator-info f 'otype))])]))
Expand All @@ -603,7 +608,7 @@
(define val-eclass (vector->list (lookup-eclass (u32vector-ref ids 1))))
(if (or (andmap (λ (x) (number? x)) val-eclass) (andmap (λ (x) (symbol? x)) val-eclass))
(list '$hole repr val) ; it is a hole that pointing to a variable/number
'())] ; it is a hole that pointing to expression - lowering rules did not work!
'())] ; it is a hole that pointing to expression - no need to extract
; it is just a repr node
[(repr-exists? f) f]
[else
Expand Down Expand Up @@ -669,11 +674,11 @@
(vector-set! id->parents child-id (cons idx (vector-ref id->parents child-id)))))]
[(? symbol?) (vector-set! id->leaf? idx #t)]
[(? number?) (vector-set! id->leaf? idx #t)]
; it was a hole expressions that got pruned
#;(printf "PRUNED BULLSHIT: ~a which refers to ~a\n"
enode
(lookup-eclass (u32vector-ref (cdr enode) 1)))
[(list) enode*])
[(list) ; it was a hole expressions that got pruned
(printf "PRUNED HOLE: ~a which refers to ~a\n"
enode
(lookup-eclass (u32vector-ref (cdr enode) 1)))
enode*])
(unless (empty? enode*)
(vector-set! id->eclass idx (cons enode* (vector-ref id->eclass idx)))))
(when (empty? (vector-ref id->eclass idx))
Expand All @@ -699,9 +704,7 @@

;; is the e-node well-typed?
(define (enode-typed? enode)
(or (number? enode)
(symbol? enode)
(and (list? enode) (not (empty? enode)) (andmap eclass-well-typed? (cdr enode)))))
(or (number? enode) (symbol? enode) (and (list? enode) (andmap eclass-well-typed? (cdr enode)))))

(define (check-typed! dirty?-vec)
(define dirty? #f)
Expand Down Expand Up @@ -978,16 +981,17 @@
(define ((typed-egg-batch-extractor batch-extract-to) regraph)
(define cost-proc (if (*egraph-platform-cost*) platform-egg-cost-proc default-egg-cost-proc))
(define eclasses (regraph-eclasses regraph))
(define types (regraph-types regraph))

; debugging
; ------------------------------------
(printf "\nProcessed Egraph: \n")
(for ([eclass (in-vector eclasses)]
[type (in-vector types)]
[n (in-naturals)])
(printf "Eclass ~a: ~a\n" n eclass))
(printf "Eclass ~a: ~a, type:~a\n" n eclass type))
; ------------------------------------

(define types (regraph-types regraph))
(define n (vector-length eclasses))

; e-class costs
Expand Down Expand Up @@ -1048,6 +1052,8 @@

(define id->spec (regraph-specs regraph))

(printf "Costs: ~a\n\n" costs)

(define ctx (regraph-ctx regraph))
(define-values (add-id add-enode finalize-batch)
(egg-nodes->batch costs id->spec batch-extract-to ctx))
Expand Down Expand Up @@ -1099,6 +1105,13 @@
(if (string-prefix? (symbol->string enode) "$var")
(egg-var->var enode ctx)
enode)]
[(list '$hole
(app eggref repr)
(app eggref val)) ; hole contains a var or number at this point
(match val
[(? number?) (literal val repr)]
[(? symbol?) (egg-var->var enode ctx)]
[_ (error (format "$hole contains unknown value ~a!" val))])]
[(list '$approx spec (app eggref impl))
(define spec* (vector-ref id->spec spec))
(unless spec*
Expand Down Expand Up @@ -1161,6 +1174,7 @@
(match node
[(? number?) 1]
[(? symbol?) 1]
[(list '$hole repr val-or-symbol) 1] ; hole is basically a number or symbol at this point, cost=1
; approx node
[(list '$approx _ impl) (rec impl)]
[(list 'if cond ift iff) (+ 1 (rec cond) (rec ift) (rec iff))]
Expand Down Expand Up @@ -1191,6 +1205,8 @@
[(? symbol?) ; variables
(define repr (context-lookup ctx (egg-var->var node ctx)))
((node-cost-proc node repr))]
; hole is basically a number or symbol at this point
[(list '$hole repr val-idx) 1] ; What to do here???
; approx node
[(list '$approx _ impl) (rec impl)]
[(list 'if cond ift iff) ; if expression
Expand All @@ -1208,11 +1224,10 @@
; Extract functions to extract exprs from egraph
(match-define (list extract-id _ _) extract)
; extract expr
(define key (cons id type))
(cond
; at least one extractable expression
[(hash-has-key? canon key)
(define id* (hash-ref canon key))
[(hash-has-key? canon id)
(define id* (hash-ref canon id))
(list (extract-id id* type))]
; no extractable expressions
[else (list)]))
Expand All @@ -1226,11 +1241,10 @@
; Functions for egg-extraction
(match-define (list _ extract-enode _) extract)
; extract expressions
(define key (cons id type))
(cond
; at least one extractable expression
[(hash-has-key? canon key)
(define id* (hash-ref canon key))
[(hash-has-key? canon id)
(define id* (hash-ref canon id))

(remove-duplicates (for/list ([enode (vector-ref eclasses id*)])
(extract-enode enode type))
Expand Down Expand Up @@ -1429,8 +1443,14 @@
(for/list ([id (in-list root-ids)]
[repr (in-list reprs)])
(regraph-extract-variants regraph extract-id id repr)))

; commit changes to the batch
(finalize-batch)

(printf "\nExpressions extracted: ...\n")
(for* ([rewrites out]
[rewrite rewrites])
(printf "~a\n" (debatchref rewrite)))
out)

(module+ test
Expand Down
2 changes: 1 addition & 1 deletion src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

; egg schedule (3-phases for mathematical rewrites and implementation selection)
(define schedule
`((,lifting-rules . ((scheduler . simple))) (,rules . ((node . ,50)))
`((,lifting-rules . ((scheduler . simple))) (,rules . ((node . ,(*node-limit*))))
(,lowering-rules . ((scheduler . simple)))))

; run egg
Expand Down

0 comments on commit e392e31

Please sign in to comment.