Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: ComposablePass trait allowing sequencing and validation #1895

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions hugr-passes/src/composable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//! Compiler passes and utilities for composing them

use std::{error::Error, marker::PhantomData};

use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
use hugr_core::HugrView;
use itertools::Either;

pub trait ComposablePass: Sized {
type Err: Error;
fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err>;
fn map_err<E2: Error>(self, f: impl Fn(Self::Err) -> E2) -> impl ComposablePass<Err = E2> {
ErrMapper::new(self, f)
}
fn sequence(
self,
other: impl ComposablePass<Err = Self::Err>,
) -> impl ComposablePass<Err = Self::Err> {
(self, other) // SequencePass::new(self, other) ?
}
fn sequence_either<P: ComposablePass>(
self,
other: P,
) -> impl ComposablePass<Err = Either<Self::Err, P::Err>> {
self.map_err(Either::Left)
.sequence(other.map_err(Either::Right))
}
}

struct ErrMapper<P, E, F>(P, F, PhantomData<E>);

impl<P: ComposablePass, E: Error, F: Fn(P::Err) -> E> ErrMapper<P, E, F> {
fn new(pass: P, err_fn: F) -> Self {
Self(pass, err_fn, PhantomData)
}
}

impl<P: ComposablePass, E: Error, F: Fn(P::Err) -> E> ComposablePass for ErrMapper<P, E, F> {
type Err = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.0.run(hugr).map_err(&self.1)
}
}

impl<E: Error, P1: ComposablePass<Err = E>, P2: ComposablePass<Err = E>> ComposablePass
for (P1, P2)
{
type Err = E;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.0.run(hugr)?;
self.1.run(hugr)
}
}

#[derive(thiserror::Error, Debug)]
#[allow(missing_docs)]
pub enum ValidatePassError<E> {
#[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")]
Input {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")]
Output {
#[source]
err: ValidationError,
pretty_hugr: String,
},
#[error(transparent)]
Underlying(E),
}

/// Runs another, underlying, pass, with validation of the Hugr
/// both before and afterwards.
pub struct ValidatingPass<P>(P, bool);

impl<P: ComposablePass> ValidatingPass<P> {
pub fn new_default(underlying: P) -> Self {
// Self(underlying, cfg!(feature = "extension_inference"))
// Sadly, many tests fail with extension inference, hence:
Self(underlying, false)
}

pub fn new_validating_extensions(underlying: P) -> Self {
Self(underlying, true)
}

pub fn new(underlying: P, validate_extensions: bool) -> Self {
Self(underlying, validate_extensions)
}

fn validation_impl<E>(
&self,
hugr: &impl HugrView,
mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError<E>,
) -> Result<(), ValidatePassError<E>> {
match self.1 {
false => hugr.validate_no_extensions(),
true => hugr.validate(),
}
.map_err(|err| mk_err(err, hugr.mermaid_string()))
}
}

impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
type Err = ValidatePassError<P::Err>;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
err,
pretty_hugr,
})?;
self.0.run(hugr).map_err(ValidatePassError::Underlying)?;
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output {
err,
pretty_hugr,
})
}
}

pub fn validate_if_test<P: ComposablePass>(
pass: P,
hugr: &mut impl HugrMut,
) -> Result<(), ValidatePassError<P::Err>> {
if cfg!(test) {
ValidatingPass::new_default(pass).run(hugr)
} else {
pass.run(hugr).map_err(ValidatePassError::Underlying)
}
}
140 changes: 61 additions & 79 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

pub mod value_handle;
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;
use std::convert::Infallible;

use hugr_core::{
hugr::{
Expand All @@ -20,36 +20,21 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::composable::validate_if_test;
use crate::dataflow::{
partial_from_const, AbstractValue, AnalysisResults, ConstLoader, ConstLocation, DFContext,
Machine, PartialValue, TailLoopTermination,
};
use crate::validation::{ValidatePassError, ValidationLevel};
use crate::ComposablePass;

#[derive(Debug, Clone, Default)]
/// A configuration for the Constant Folding pass.
pub struct ConstantFoldPass {
validation: ValidationLevel,
allow_increase_termination: bool,
inputs: HashMap<IncomingPort, Value>,
}

#[derive(Debug, Error)]
#[non_exhaustive]
/// Errors produced by [ConstantFoldPass].
pub enum ConstFoldError {
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
}

impl ConstantFoldPass {
/// Sets the validation level used before and after the pass is run
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their
/// result (if/when they do terminate) is either known or not needed.
///
Expand All @@ -71,8 +56,63 @@ impl ConstantFoldPass {
self
}

/// Run the Constant Folding pass.
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> {
fn find_needed_nodes<H: HugrView>(
&self,
results: &AnalysisResults<ValueHandle, H>,
needed: &mut HashSet<Node>,
) {
let mut q = VecDeque::new();
let h = results.hugr();
q.push_back(h.root());
while let Some(n) = q.pop_front() {
if !needed.insert(n) {
continue;
};

if h.get_optype(n).is_cfg() {
for bb in h.children(n) {
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
q.push_back(bb);
}
} else if let Some(inout) = h.get_io(n) {
// Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges.
q.extend(inout); // Input also necessary for legality even if unreachable

if !self.allow_increase_termination {
// Also add on anything that might not terminate (even if results not required -
// if its results are required we'll add it by following dataflow, below.)
for ch in h.children(n) {
if might_diverge(results, ch) {
q.push_back(ch);
}
}
}
}
// Also follow dataflow demand
for (src, op) in h.all_linked_outputs(n) {
let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() {
EdgeKind::Value(_) => {
h.get_optype(src).is_load_constant()
|| results
.try_read_wire_concrete::<Value, _, _>(Wire::new(src, op))
.is_err()
}
EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true,
EdgeKind::ControlFlow => false, // we always include all children of a CFG above
_ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst
};
if needs_predecessor {
q.push_back(src);
}
}
}
}
}

impl ComposablePass for ConstantFoldPass {
type Err = Infallible;

fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Self::Err> {
let fresh_node = Node::from(portgraph::NodeIndex::new(
hugr.nodes().max().map_or(0, |n| n.index() + 1),
));
Expand Down Expand Up @@ -135,64 +175,6 @@ impl ConstantFoldPass {
}
Ok(())
}

/// Run the pass using this configuration
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}

fn find_needed_nodes<H: HugrView>(
&self,
results: &AnalysisResults<ValueHandle, H>,
needed: &mut HashSet<Node>,
) {
let mut q = VecDeque::new();
let h = results.hugr();
q.push_back(h.root());
while let Some(n) = q.pop_front() {
if !needed.insert(n) {
continue;
};

if h.get_optype(n).is_cfg() {
for bb in h.children(n) {
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
q.push_back(bb);
}
} else if let Some(inout) = h.get_io(n) {
// Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges.
q.extend(inout); // Input also necessary for legality even if unreachable

if !self.allow_increase_termination {
// Also add on anything that might not terminate (even if results not required -
// if its results are required we'll add it by following dataflow, below.)
for ch in h.children(n) {
if might_diverge(results, ch) {
q.push_back(ch);
}
}
}
}
// Also follow dataflow demand
for (src, op) in h.all_linked_outputs(n) {
let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() {
EdgeKind::Value(_) => {
h.get_optype(src).is_load_constant()
|| results
.try_read_wire_concrete::<Value, _, _>(Wire::new(src, op))
.is_err()
}
EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true,
EdgeKind::ControlFlow => false, // we always include all children of a CFG above
_ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst
};
if needs_predecessor {
q.push_back(src);
}
}
}
}
}

// "Diverge" aka "never-terminate"
Expand All @@ -219,7 +201,7 @@ fn might_diverge<V: AbstractValue>(results: &AnalysisResults<V, impl HugrView>,

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass<H: HugrMut>(h: &mut H) {
ConstantFoldPass::default().run(h).unwrap()
validate_if_test(ConstantFoldPass::default(), h).unwrap()
}

struct ConstFoldContext<'a, H>(&'a H);
Expand Down
1 change: 1 addition & 0 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV};
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};

use crate::dataflow::{partial_from_const, DFContext, PartialValue};
use crate::ComposablePass as _;

use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle};

Expand Down
Loading
Loading