Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating one egraph in generate-candidates #1143

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@

; node -> natural
; inserts an expression into the e-graph, returning its e-class id.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal but better to avoid making unnecessary whitespace changes like this.

(define (insert-node! node root?)
(match node
[(list op ids ...) (egraph_add_node ptr (symbol->string op) (list->u32vec ids) root?)]
[(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)]
[(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)]))

(define insert-batch (batch-remove-zombie batch roots))

(define mappings (build-vector (batch-length insert-batch) values))
(define (remap x)
(vector-ref mappings x))
Expand All @@ -131,7 +131,6 @@
[(hole prec spec) (remap spec)] ; "hole" terms currently disappear
[(approx spec impl) (insert-node! (list '$approx (remap spec) (remap impl)) root?)]
[(list op (app remap args) ...) (insert-node! (cons op args) root?)]))

(vector-set! mappings n idx))

(for ([node (in-vector (batch-nodes insert-batch))]
Expand Down
56 changes: 5 additions & 51 deletions src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -16,52 +16,6 @@

(provide generate-candidates)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Simplify ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define (lower-approximations approxs global-batch)
(timeline-event! 'simplify)

(define reprs
(for/list ([approx (in-list approxs)])
(define prev (car (alt-prevs approx)))
(repr-of (debatchref (alt-expr prev)) (*context*))))

; generate real rules
(define rules (*simplify-rules*))
(define lowering-rules (platform-lowering-rules))

; egg runner
(define schedule
(if (flag-set? 'generate 'simplify)
; if simplify enabled, 2-phases for real rewrites and implementation selection
`((,rules . ((node . ,(*node-limit*))))
(,lowering-rules . ((iteration . 1) (scheduler . simple))))
; if disabled, only implementation selection
`((,lowering-rules . ((iteration . 1) (scheduler . simple))))))

(define roots
(for/vector ([approx (in-list approxs)])
(batchref-idx (alt-expr approx))))

; run egg
(define runner (make-egraph global-batch roots reprs schedule))
(define simplification-options (simplify-batch runner global-batch))

; convert to altns
(define simplified
(reap [sow]
(define global-batch-mutable (batch->mutable-batch global-batch)) ; Create mutable batch
(for ([altn (in-list approxs)]
[outputs (in-list simplification-options)])
(match-define (cons _ simplified) outputs)
(define prev (car (alt-prevs altn)))
(for ([bref (in-list simplified)])
(sow (alt bref `(simplify ,runner #f) (list altn) '()))))
(batch-copy-mutable-nodes! global-batch global-batch-mutable))) ; Update global-batch

(timeline-push! 'count (length approxs) (length simplified))
simplified)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Taylor ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define transforms-to-try
Expand Down Expand Up @@ -111,7 +65,7 @@
(timeline-push! 'outputs (map ~a (map (compose debatchref alt-expr) approxs)))
(timeline-push! 'count (length altns) (length approxs))

(lower-approximations approxs global-batch))
approxs)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Recursive Rewrite ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand All @@ -134,7 +88,6 @@
(define roots (list->vector (map (compose batchref-idx alt-expr) altns)))
(define reprs (map (curryr repr-of (*context*)) exprs))
(timeline-push! 'inputs (map ~a exprs))

(define runner (make-egraph global-batch roots reprs schedule))
; batchrefss is a (listof (listof batchref))
(define batchrefss (egraph-variations runner global-batch))
Expand All @@ -161,7 +114,7 @@
; Starting alternatives
(define start-altns
(for/list ([expr (in-list exprs)]
[root (batch-roots global-batch)])
[root (in-vector (batch-roots global-batch))])
(define repr (repr-of expr (*context*)))
(alt (batchref global-batch root) (list 'patch expr repr) '() '())))

Expand All @@ -170,10 +123,11 @@
(if (flag-set? 'generate 'taylor)
(run-taylor exprs start-altns global-batch)
'()))

; Recursive rewrite
(define rewritten
(if (flag-set? 'generate 'rr)
(run-rr start-altns global-batch)
(run-rr (append start-altns approximations) global-batch)
'()))

(remove-duplicates (append approximations rewritten) #:key (λ (x) (batchref-idx (alt-expr x)))))
(remove-duplicates rewritten #:key (λ (x) (batchref-idx (alt-expr x)))))
1 change: 1 addition & 0 deletions src/core/programs.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
[(literal val precision) (get-representation precision)]
[(? variable?) (context-lookup ctx expr)]
[(approx _ impl) (repr-of impl ctx)]
[(hole precision spec) (get-representation precision)]
[(list 'if cond ift iff) (repr-of ift ctx)]
[(list op args ...) (impl-info op 'otype)]))

Expand Down
Loading