diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index ecc522ec39d12..70222f4acabe1 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -30,14 +30,6 @@ pub enum DiffMode { Forward, /// The target function, to be created using reverse mode AD. Reverse, - /// The target function, to be created using forward mode AD. - /// This target function will also be used as a source for higher order derivatives, - /// so compute it before all Forward/Reverse targets and optimize it through llvm. - ForwardFirst, - /// The target function, to be created using reverse mode AD. - /// This target function will also be used as a source for higher order derivatives, - /// so compute it before all Forward/Reverse targets and optimize it through llvm. - ReverseFirst, } /// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. @@ -92,10 +84,10 @@ pub struct AutoDiffAttrs { impl DiffMode { pub fn is_rev(&self) -> bool { - matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst) + matches!(self, DiffMode::Reverse) } pub fn is_fwd(&self) -> bool { - matches!(self, DiffMode::Forward | DiffMode::ForwardFirst) + matches!(self, DiffMode::Forward) } } @@ -106,8 +98,6 @@ impl Display for DiffMode { DiffMode::Source => write!(f, "Source"), DiffMode::Forward => write!(f, "Forward"), DiffMode::Reverse => write!(f, "Reverse"), - DiffMode::ForwardFirst => write!(f, "ForwardFirst"), - DiffMode::ReverseFirst => write!(f, "ReverseFirst"), } } } @@ -125,12 +115,12 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { match mode { DiffMode::Error => false, DiffMode::Source => false, - DiffMode::Forward | DiffMode::ForwardFirst => { + DiffMode::Forward => { activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const } - DiffMode::Reverse | DiffMode::ReverseFirst => { + DiffMode::Reverse => { activity == DiffActivity::Const || activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly @@ -166,10 +156,10 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { return match mode { DiffMode::Error => false, DiffMode::Source => false, - DiffMode::Forward | DiffMode::ForwardFirst => { + DiffMode::Forward => { matches!(activity, Dual | DualOnly | Const) } - DiffMode::Reverse | DiffMode::ReverseFirst => { + DiffMode::Reverse => { matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) } }; @@ -200,8 +190,6 @@ impl FromStr for DiffMode { "Source" => Ok(DiffMode::Source), "Forward" => Ok(DiffMode::Forward), "Reverse" => Ok(DiffMode::Reverse), - "ForwardFirst" => Ok(DiffMode::ForwardFirst), - "ReverseFirst" => Ok(DiffMode::ReverseFirst), _ => Err(()), } } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 78c759bbe8c03..8bad437eeb716 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -606,10 +606,31 @@ pub(crate) fn run_pass_manager( // If this rustc version was build with enzyme/autodiff enabled, and if users applied the // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time. - let first_run = true; debug!("running llvm pm opt pipeline"); unsafe { - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; + write::llvm_optimize( + cgcx, + dcx, + module, + config, + opt_level, + opt_stage, + write::AutodiffStage::DuringAD, + )?; + } + // FIXME(ZuseZ4): Make this more granular + if cfg!(llvm_enzyme) && !thin { + unsafe { + write::llvm_optimize( + cgcx, + dcx, + module, + config, + opt_level, + llvm::OptStage::FatLTO, + write::AutodiffStage::PostAD, + )?; + } } debug!("lto done"); Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 4706744f35307..ae4c4d5876e2b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -530,6 +530,16 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option { config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned()) } +// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling) +// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes. +// PostAD will run all opts, including size increasing opts. +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum AutodiffStage { + PreAD, + DuringAD, + PostAD, +} + pub(crate) unsafe fn llvm_optimize( cgcx: &CodegenContext, dcx: DiagCtxtHandle<'_>, @@ -537,7 +547,7 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, - skip_size_increasing_opts: bool, + autodiff_stage: AutodiffStage, ) -> Result<(), FatalError> { // Enzyme: // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized @@ -550,12 +560,16 @@ pub(crate) unsafe fn llvm_optimize( let unroll_loops; let vectorize_slp; let vectorize_loop; + let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD; // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing - // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly, + // optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt). + // We therefore have two calls to llvm_optimize, if autodiff is used. + // + // FIXME(ZuseZ4): Before shipping on nightly, // we should make this more granular, or at least check that the user has at least one autodiff // call in their code, to justify altering the compilation pipeline. - if skip_size_increasing_opts && cfg!(llvm_enzyme) { + if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD { unroll_loops = false; vectorize_slp = false; vectorize_loop = false; @@ -565,7 +579,7 @@ pub(crate) unsafe fn llvm_optimize( vectorize_slp = config.vectorize_slp; vectorize_loop = config.vectorize_loop; } - trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop); + trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -633,6 +647,7 @@ pub(crate) unsafe fn llvm_optimize( vectorize_loop, config.no_builtins, config.emit_lifetime_markers, + run_enzyme, sanitizer_options.as_ref(), pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), @@ -684,18 +699,14 @@ pub(crate) unsafe fn optimize( _ => llvm::OptStage::PreLinkNoLTO, }; - // If we know that we will later run AD, then we disable vectorization and loop unrolling - let skip_size_increasing_opts = cfg!(llvm_enzyme); + // If we know that we will later run AD, then we disable vectorization and loop unrolling. + // Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD). + // FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff + // usages, not just if we build rustc with autodiff support. + let autodiff_stage = + if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD }; return unsafe { - llvm_optimize( - cgcx, - dcx, - module, - config, - opt_level, - opt_stage, - skip_size_increasing_opts, - ) + llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage) }; } Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index dd5e726160d48..b2c1088e3fc02 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; use rustc_errors::FatalError; -use rustc_session::config::Lto; use tracing::{debug, trace}; -use crate::back::write::{llvm_err, llvm_optimize}; +use crate::back::write::llvm_err; use crate::builder::SBuilder; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; @@ -53,8 +52,6 @@ fn generate_enzyme_call<'ll>( let mut ad_name: String = match attrs.mode { DiffMode::Forward => "__enzyme_fwddiff", DiffMode::Reverse => "__enzyme_autodiff", - DiffMode::ForwardFirst => "__enzyme_fwddiff", - DiffMode::ReverseFirst => "__enzyme_autodiff", _ => panic!("logic bug in autodiff, unrecognized mode"), } .to_string(); @@ -153,7 +150,7 @@ fn generate_enzyme_call<'ll>( _ => {} } - trace!("matching autodiff arguments"); + debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on // llvm-ir level. The number of activities matches the number of Rust level arguments, so we @@ -164,10 +161,10 @@ fn generate_enzyme_call<'ll>( let mut activity_pos = 0; let outer_args: Vec<&llvm::Value> = get_params(outer_fn); while activity_pos < inputs.len() { - let activity = inputs[activity_pos as usize]; + let diff_activity = inputs[activity_pos as usize]; // Duplicated arguments received a shadow argument, into which enzyme will write the // gradient. - let (activity, duplicated): (&Metadata, bool) = match activity { + let (activity, duplicated): (&Metadata, bool) = match diff_activity { DiffActivity::None => panic!("not a valid input activity"), DiffActivity::Const => (enzyme_const, false), DiffActivity::Active => (enzyme_out, false), @@ -222,7 +219,15 @@ fn generate_enzyme_call<'ll>( // A duplicated pointer will have the following two outer_fn arguments: // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: // (..., metadata! enzyme_dup, ptr, ptr, ...). - assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + if matches!( + diff_activity, + DiffActivity::Duplicated | DiffActivity::DuplicatedOnly + ) { + assert!( + llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer + ); + } + // In the case of Dual we don't have assumptions, e.g. f32 would be valid. args.push(next_outer_arg); outer_pos += 2; activity_pos += 1; @@ -277,7 +282,7 @@ pub(crate) fn differentiate<'ll>( module: &'ll ModuleCodegen, cgcx: &CodegenContext, diff_items: Vec, - config: &ModuleConfig, + _config: &ModuleConfig, ) -> Result<(), FatalError> { for item in &diff_items { trace!("{}", item); @@ -318,29 +323,6 @@ pub(crate) fn differentiate<'ll>( // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts - if let Some(opt_level) = config.opt_level { - let opt_stage = match cgcx.lto { - Lto::Fat => llvm::OptStage::PreLinkFatLTO, - Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, - _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, - _ => llvm::OptStage::PreLinkNoLTO, - }; - // This is our second opt call, so now we run all opts, - // to make sure we get the best performance. - let skip_size_increasing_opts = false; - trace!("running Module Optimization after differentiation"); - unsafe { - llvm_optimize( - cgcx, - diag_handler.handle(), - module, - config, - opt_level, - opt_stage, - skip_size_increasing_opts, - )? - }; - } trace!("done with differentiate()"); Ok(()) diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 50a40c9c30927..4d6a76b23ea1a 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2382,6 +2382,7 @@ unsafe extern "C" { LoopVectorize: bool, DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, + RunEnzyme: bool, SanitizerOptions: Option<&SanitizerOptions>, PGOGenPath: *const c_char, PGOUsePath: *const c_char, diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 7acdbd19993de..2c38fb5658f22 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -914,8 +914,6 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { let mode = match mode.as_str() { "Forward" => DiffMode::Forward, "Reverse" => DiffMode::Reverse, - "ForwardFirst" => DiffMode::ForwardFirst, - "ReverseFirst" => DiffMode::ReverseFirst, _ => { span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode"); } diff --git a/compiler/rustc_llvm/build.rs b/compiler/rustc_llvm/build.rs index d9d28299413b1..48806888b43df 100644 --- a/compiler/rustc_llvm/build.rs +++ b/compiler/rustc_llvm/build.rs @@ -193,6 +193,10 @@ fn main() { cfg.define(&flag, None); } + if tracked_env_var_os("LLVM_ENZYME").is_some() { + cfg.define("ENZYME", None); + } + if tracked_env_var_os("LLVM_RUSTLLVM").is_some() { cfg.define("LLVM_RUSTLLVM", None); } diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index 6447a9362b3ab..a6b2384f2d7b2 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -688,14 +688,20 @@ struct LLVMRustSanitizerOptions { bool SanitizeKernelAddressRecover; }; +// This symbol won't be available or used when Enzyme is not enabled +#ifdef ENZYME +extern "C" void registerEnzyme(llvm::PassBuilder &PB); +#endif + extern "C" LLVMRustResult LLVMRustOptimize( LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef, LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage, bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR, bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, LLVMRustSanitizerOptions *SanitizerOptions, - const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, + bool EmitLifetimeMarkers, bool RunEnzyme, + LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, + const char *PGOUsePath, bool InstrumentCoverage, const char *InstrProfileOutput, const char *PGOSampleUsePath, bool DebugInfoForProfiling, void *LlvmSelfProfiler, LLVMRustSelfProfileBeforePassCallback BeforePassCallback, @@ -1010,6 +1016,18 @@ extern "C" LLVMRustResult LLVMRustOptimize( MPM.addPass(NameAnonGlobalPass()); } + // now load "-enzyme" pass: +#ifdef ENZYME + if (RunEnzyme) { + registerEnzyme(PB); + if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { + std::string ErrMsg = toString(std::move(Err)); + LLVMRustSetLastError(ErrMsg.c_str()); + return LLVMRustResult::Failure; + } + } +#endif + // Upgrade all calls to old intrinsics first. for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;) UpgradeCallsToIntrinsic(&*I++); // must be post-increment, as we remove diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index cd3558ac6a49b..d22fad1840635 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1049,12 +1049,12 @@ pub fn rustc_cargo( // . cargo.rustflag("-Zon-broken-pipe=kill"); - // We temporarily disable linking here as part of some refactoring. - // This way, people can manually use -Z llvm-plugins and -C passes=enzyme for now. - // In a follow-up PR, we will re-enable linking here and load the pass for them. - //if builder.config.llvm_enzyme { - // cargo.rustflag("-l").rustflag("Enzyme-19"); - //} + // We want to link against registerEnzyme and in the future we want to use additional + // functionality from Enzyme core. For that we need to link against Enzyme. + // FIXME(ZuseZ4): Get the LLVM version number automatically instead of hardcoding it. + if builder.config.llvm_enzyme { + cargo.rustflag("-l").rustflag("Enzyme-19"); + } // Building with protected visibility reduces the number of dynamic relocations needed, giving // us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object @@ -1234,6 +1234,9 @@ fn rustc_llvm_env(builder: &Builder<'_>, cargo: &mut Cargo, target: TargetSelect if builder.is_rust_llvm(target) { cargo.env("LLVM_RUSTLLVM", "1"); } + if builder.config.llvm_enzyme { + cargo.env("LLVM_ENZYME", "1"); + } let llvm::LlvmResult { llvm_config, .. } = builder.ensure(llvm::Llvm { target }); cargo.env("LLVM_CONFIG", &llvm_config); diff --git a/src/bootstrap/src/core/build_steps/llvm.rs b/src/bootstrap/src/core/build_steps/llvm.rs index 49bf04356d5bc..46b98a8e7babe 100644 --- a/src/bootstrap/src/core/build_steps/llvm.rs +++ b/src/bootstrap/src/core/build_steps/llvm.rs @@ -968,6 +968,7 @@ impl Step for Enzyme { .env("LLVM_CONFIG_REAL", &llvm_config) .define("LLVM_ENABLE_ASSERTIONS", "ON") .define("ENZYME_EXTERNAL_SHARED_LIB", "ON") + .define("ENZYME_RUNPASS", "ON") .define("LLVM_DIR", builder.llvm_out(target)); cfg.build(); diff --git a/src/tools/enzyme b/src/tools/enzyme index 0e5fa4a3d475f..7f3b207c4413c 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 0e5fa4a3d475f4dece489c9e06b11164f83789f5 +Subproject commit 7f3b207c4413c9d715fd54b36b8a8fd3179e0b67 diff --git a/tests/codegen/autodiff.rs b/tests/codegen/autodiff.rs new file mode 100644 index 0000000000000..abf7fcf3e4bcd --- /dev/null +++ b/tests/codegen/autodiff.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square, Reverse, Duplicated, Active)] +#[no_mangle] +fn square(x: &f64) -> f64 { + x * x +} + +// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'" +// CHECK-NEXT:invertstart: +// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val +// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val +// CHECK-NEXT: %1 = load double, ptr %"x'", align 8 +// CHECK-NEXT: %2 = fadd fast double %1, %0 +// CHECK-NEXT: store double %2, ptr %"x'", align 8 +// CHECK-NEXT: ret double %_0 +// CHECK-NEXT:} + +fn main() { + let x = 3.0; + let output = square(&x); + assert_eq!(9.0, output); + + let mut df_dx = 0.0; + let output_ = d_square(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(6.0, df_dx); +}