diff --git a/Cargo.toml b/Cargo.toml index db72e3a..d6ad6e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,12 @@ tracing = "0.1.37" elsa = "1.9.0" bitvec = "1.0.1" serde = { version = "1.0", features = ["derive"], optional = true } +futures = { version = "0.3.30", default-features = false, features = ["alloc"] } +tokio = { version = "1.35.1", features = ["rt", "sync"] } [dev-dependencies] insta = "1.31.0" indexmap = "2.0.0" proptest = "1.2.0" tracing-test = { version = "0.2.4", features = ["no-env-filter"] } +tokio = { version = "1.35.1", features = ["time"] } diff --git a/src/lib.rs b/src/lib.rs index 5985145..ad766f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ use std::{ any::Any, fmt::{Debug, Display}, hash::Hash, + rc::Rc, }; /// The solver is based around the fact that for for every package name we are trying to find a @@ -61,21 +62,37 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash { /// packages that are available in the system. pub trait DependencyProvider: Sized { /// Returns the [`Pool`] that is used to allocate the Ids returned from this instance - fn pool(&self) -> &Pool; + fn pool(&self) -> Rc>; /// Sort the specified solvables based on which solvable to try first. The solver will /// iteratively try to select the highest version. If a conflict is found with the highest /// version the next version is tried. This continues until a solution is found. fn sort_candidates(&self, solver: &SolverCache, solvables: &mut [SolvableId]); - /// Returns a list of solvables that should be considered when a package with the given name is + /// Obtains a list of solvables that should be considered when a package with the given name is /// requested. /// - /// Returns `None` if no such package exist. - fn get_candidates(&self, name: NameId) -> Option; + /// # Async + /// + /// The returned future will be awaited by a tokio runtime blocking the main thread. You are + /// free to use other runtimes in your implementation, as long as the runtime-specific code runs + /// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can + /// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to + /// retrieve necessary information from the network, and `await` the returned task handle. + #[allow(async_fn_in_trait)] + async fn get_candidates(&self, name: NameId) -> Option; /// Returns the dependencies for the specified solvable. - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; + /// + /// # Async + /// + /// The returned future will be awaited by a tokio runtime blocking the main thread. You are + /// free to use other runtimes in your implementation, as long as the runtime-specific code runs + /// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can + /// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to + /// retrieve necessary information from the network, and `await` the returned task handle. + #[allow(async_fn_in_trait)] + async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; /// Whether the solver should stop the dependency resolution algorithm. /// diff --git a/src/problem.rs b/src/problem.rs index 311c7bd..7dc1870 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -4,7 +4,6 @@ use std::collections::{HashMap, HashSet}; use std::fmt; use std::fmt::{Display, Formatter}; use std::hash::Hash; - use std::rc::Rc; use itertools::Itertools; @@ -52,7 +51,7 @@ impl Problem { let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency); for clause_id in &self.clauses { - let clause = &solver.clauses[*clause_id].kind; + let clause = &solver.clauses.borrow()[*clause_id].kind; match clause { Clause::InstallRoot => (), Clause::Excluded(solvable, reason) => { @@ -65,7 +64,7 @@ impl Problem { &Clause::Requires(package_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); - let candidates = solver.cache.get_or_cache_sorted_candidates(version_set_id).unwrap_or_else(|_| { + let candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(version_set_id)).unwrap_or_else(|_| { unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`") }); if candidates.is_empty() { @@ -162,10 +161,11 @@ impl Problem { >( &self, solver: &'a Solver, + pool: Rc>, merged_solvable_display: &'a M, ) -> DisplayUnsat<'a, VS, N, M> { let graph = self.graph(solver); - DisplayUnsat::new(graph, solver.pool(), merged_solvable_display) + DisplayUnsat::new(graph, pool, merged_solvable_display) } } @@ -512,7 +512,7 @@ pub struct DisplayUnsat<'pool, VS: VersionSet, N: PackageName + Display, M: Solv merged_candidates: HashMap>, installable_set: HashSet, missing_set: HashSet, - pool: &'pool Pool, + pool: Rc>, merged_solvable_display: &'pool M, } @@ -521,10 +521,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay> { pub(crate) fn new( graph: ProblemGraph, - pool: &'pool Pool, + pool: Rc>, merged_solvable_display: &'pool M, ) -> Self { - let merged_candidates = graph.simplify(pool); + let merged_candidates = graph.simplify(&pool); let installable_set = graph.get_installable_set(); let missing_set = graph.get_missing_set(); @@ -666,10 +666,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay> let version = if let Some(merged) = self.merged_candidates.get(&solvable_id) { reported.extend(merged.ids.iter().cloned()); self.merged_solvable_display - .display_candidates(self.pool, &merged.ids) + .display_candidates(&self.pool, &merged.ids) } else { self.merged_solvable_display - .display_candidates(self.pool, &[solvable_id]) + .display_candidates(&self.pool, &[solvable_id]) }; let excluded = graph @@ -790,9 +790,9 @@ impl> fmt::D writeln!( f, "{indent}{} {} is locked, but another version is required as reported above", - locked.name.display(self.pool), + locked.name.display(&self.pool), self.merged_solvable_display - .display_candidates(self.pool, &[solvable_id]) + .display_candidates(&self.pool, &[solvable_id]) )?; } ConflictCause::Excluded(_, _) => continue, diff --git a/src/solver/cache.rs b/src/solver/cache.rs index 8d25d3d..072543d 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -13,6 +13,7 @@ use elsa::FrozenMap; use std::any::Any; use std::cell::RefCell; use std::marker::PhantomData; +use std::rc::Rc; /// Keeps a cache of previously computed and/or requested information about solvables and version /// sets. @@ -65,7 +66,7 @@ impl> SolverCache &Pool { + pub fn pool(&self) -> Rc> { self.provider.pool() } @@ -74,7 +75,7 @@ impl> SolverCache Result<&Candidates, Box> { @@ -93,6 +94,7 @@ impl> SolverCache> SolverCache Result<&[SolvableId], Box> { match self.version_set_candidates.get(&version_set_id) { Some(candidates) => Ok(candidates), None => { - let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let pool = self.pool(); + let package_name = pool.resolve_version_set_package_name(version_set_id); + let version_set = pool.resolve_version_set(version_set_id); + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates .iter() .copied() .filter(|&p| { - let version = self.pool().resolve_internal_solvable(p).solvable().inner(); + let version = pool.resolve_internal_solvable(p).solvable().inner(); version_set.contains(version) }) .collect(); @@ -158,23 +161,24 @@ impl> SolverCache Result<&[SolvableId], Box> { match self.version_set_inverse_candidates.get(&version_set_id) { Some(candidates) => Ok(candidates), None => { - let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let pool = self.pool(); + let package_name = pool.resolve_version_set_package_name(version_set_id); + let version_set = pool.resolve_version_set(version_set_id); + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates .iter() .copied() .filter(|&p| { - let version = self.pool().resolve_internal_solvable(p).solvable().inner(); + let version = pool.resolve_internal_solvable(p).solvable().inner(); !version_set.contains(version) }) .collect(); @@ -191,7 +195,7 @@ impl> SolverCache Result<&[SolvableId], Box> { @@ -199,8 +203,10 @@ impl> SolverCache Ok(candidates), None => { let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let matching_candidates = self.get_or_cache_matching_candidates(version_set_id)?; - let candidates = self.get_or_cache_candidates(package_name)?; + let matching_candidates = self + .get_or_cache_matching_candidates(version_set_id) + .await?; + let candidates = self.get_or_cache_candidates(package_name).await?; // Sort all the candidates in order in which they should be tried by the solver. let mut sorted_candidates = Vec::new(); @@ -228,7 +234,7 @@ impl> SolverCache Result<&Dependencies, Box> { @@ -242,7 +248,7 @@ impl> SolverCache, + conflicting_clauses: Vec, + negative_assertions: Vec<(SolvableId, ClauseId)>, + clauses_to_watch: Vec, +} + /// Drives the SAT solving process pub struct Solver> { + /// The [Pool] used by the solver + pub pool: Rc>, + pub(crate) async_runtime: tokio::runtime::Runtime, pub(crate) cache: SolverCache, - pub(crate) clauses: Arena, + pub(crate) clauses: RefCell>, requires_clauses: Vec<(SolvableId, VersionSetId, ClauseId)>, watches: WatchMap, @@ -43,8 +56,8 @@ pub struct Solver> learnt_why: Mapping>, learnt_clause_ids: Vec, - clauses_added_for_package: HashSet, - clauses_added_for_solvable: HashSet, + clauses_added_for_package: RefCell>>>, + clauses_added_for_solvable: RefCell>>>, decision_tracker: DecisionTracker, @@ -53,14 +66,25 @@ pub struct Solver> } impl> Solver { - /// Create a solver, using the provided pool - pub fn new(provider: D) -> Self { + /// Create a solver, using the provided pool and async runtime. + /// + /// # Async runtime + /// + /// The solver uses tokio to await the results of async methods in [DependencyProvider]. It will + /// run them concurrently, but blocking the main thread. That means that a single-threaded tokio + /// runtime is usually enough. It is also possible to use a different runtime, as long as you + /// avoid mixing incompatible futures. For details, check out the documentation for the async + /// methods of [DependencyProvider]. + pub fn new(provider: D, async_runtime: tokio::runtime::Runtime) -> Self { + let pool = provider.pool(); Self { cache: SolverCache::new(provider), - clauses: Arena::new(), + pool, + async_runtime, + clauses: RefCell::new(Arena::new()), requires_clauses: Default::default(), watches: WatchMap::new(), - negative_assertions: Vec::new(), + negative_assertions: Default::default(), learnt_clauses: Arena::new(), learnt_why: Mapping::new(), learnt_clause_ids: Vec::new(), @@ -71,9 +95,17 @@ impl> Solver &Pool { - self.cache.pool() + /// Create a solver, using the provided pool and the default async runtime. + /// + /// The default is a single-threaded tokio runtime without any features enabled. If you need + /// something more advanced, consider providing your own runtime through [Self::new]. + pub fn new_with_default_runtime(provider: D) -> Self { + Self::new( + provider, + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(), + ) } } @@ -123,7 +155,7 @@ impl> Sol // The first clause will always be the install root clause. Here we verify that this is // indeed the case. - let root_clause = self.clauses.alloc(ClauseState::root()); + let root_clause = self.clauses.borrow_mut().alloc(ClauseState::root()); assert_eq!(root_clause, ClauseId::install_root()); // Run SAT @@ -145,26 +177,6 @@ impl> Sol Ok(steps) } - /// Adds a clause to the solver and immediately starts watching its literals. - fn add_and_watch_clause(&mut self, clause: ClauseState) -> ClauseId { - let clause_id = self.clauses.alloc(clause); - let clause = &self.clauses[clause_id]; - - // Add in requires clause lookup - if let &Clause::Requires(solvable_id, version_set_id) = &clause.kind { - self.requires_clauses - .push((solvable_id, version_set_id, clause_id)); - } - - // Start watching the literals of the clause - let clause = &mut self.clauses[clause_id]; - if clause.has_watches() { - self.watches.start_watching(clause, clause_id); - } - - clause_id - } - /// Adds clauses for a solvable. These clauses include requirements and constrains on other /// solvables. /// @@ -172,25 +184,36 @@ impl> Sol /// /// If the provider has requested the solving process to be cancelled, the cancellation value /// will be returned as an `Err(...)`. - fn add_clauses_for_solvable( - &mut self, + async fn add_clauses_for_solvable( + &self, solvable_id: SolvableId, - ) -> Result<(Vec, Vec), Box> { - if self.clauses_added_for_solvable.contains(&solvable_id) { - return Ok((Vec::new(), Vec::new())); - } - - let mut new_clauses = Vec::new(); - let mut conflicting_clauses = Vec::new(); + ) -> Result> { + let mut output = AddClauseOutput::default(); let mut queue = vec![solvable_id]; let mut seen = HashSet::new(); seen.insert(solvable_id); while let Some(solvable_id) = queue.pop() { - let solvable = self.pool().resolve_internal_solvable(solvable_id); + let mutex = { + let mut clauses = self.clauses_added_for_solvable.borrow_mut(); + let mutex = clauses + .entry(solvable_id) + .or_insert_with(|| Rc::new(tokio::sync::Mutex::new(false))); + mutex.clone() + }; + + // This prevents concurrent requests to add clauses for a solvable from racing. Only the + // first request for that solvable will go through, and others will wait till it + // completes. + let mut clauses_added = mutex.lock().await; + if *clauses_added { + continue; + } + + let solvable = self.pool.resolve_internal_solvable(solvable_id); tracing::trace!( "┝━ adding clauses for dependencies of {}", - solvable.display(self.pool()) + solvable.display(&self.pool) ); // Determine the dependencies of the current solvable. There are two cases here: @@ -200,7 +223,7 @@ impl> Sol let (requirements, constrains) = match solvable.inner { SolvableInner::Root => (self.root_requirements.clone(), Vec::new()), SolvableInner::Package(_) => { - let deps = self.cache.get_or_cache_dependencies(solvable_id)?; + let deps = self.cache.get_or_cache_dependencies(solvable_id).await?; match deps { Dependencies::Known(deps) => { (deps.requirements.clone(), deps.constrains.clone()) @@ -210,15 +233,15 @@ impl> Sol // an exclusion clause for it let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::exclude(solvable_id, *reason)); // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable_id, clause_id)); - - new_clauses.push(clause_id); + output.negative_assertions.push((solvable_id, clause_id)); + // There might be a conflict now if self.decision_tracker.assigned_value(solvable_id) == Some(true) { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } continue; @@ -229,18 +252,25 @@ impl> Sol // Add clauses for the requirements for version_set_id in requirements { - let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; + let dependency_name = self.pool.resolve_version_set_package_name(version_set_id); + self.add_clauses_for_package( + &mut output.negative_assertions, + &mut output.clauses_to_watch, + dependency_name, + ) + .await?; // Find all the solvables that match for the given version set - let candidates = self.cache.get_or_cache_sorted_candidates(version_set_id)?; + let candidates = self + .cache + .get_or_cache_sorted_candidates(version_set_id) + .await?; // Queue requesting the dependencies of the candidates as well if they are cheaply // available from the dependency provider. for &candidate in candidates { if seen.insert(candidate) && self.cache.are_dependencies_available_for(candidate) - && !self.clauses_added_for_solvable.contains(&candidate) { queue.push(candidate); } @@ -255,27 +285,44 @@ impl> Sol &self.decision_tracker, ); - let clause_id = self.add_and_watch_clause(clause); + let clause_id = self.clauses.borrow_mut().alloc(clause); + let clause = &self.clauses.borrow()[clause_id]; + + let &Clause::Requires(solvable_id, version_set_id) = &clause.kind else { + unreachable!(); + }; + + if clause.has_watches() { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((solvable_id, version_set_id, clause_id)); if conflict { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } else if no_candidates { // Add assertions for unit clauses (i.e. those with no matching candidates) - self.negative_assertions.push((solvable_id, clause_id)); + output.negative_assertions.push((solvable_id, clause_id)); } - - new_clauses.push(clause_id); } // Add clauses for the constraints for version_set_id in constrains { - let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; + let dependency_name = self.pool.resolve_version_set_package_name(version_set_id); + self.add_clauses_for_package( + &mut output.negative_assertions, + &mut output.clauses_to_watch, + dependency_name, + ) + .await?; // Find all the solvables that match for the given version set let constrained_candidates = self .cache - .get_or_cache_non_matching_candidates(version_set_id)?; + .get_or_cache_non_matching_candidates(version_set_id) + .await?; // Add forbidden clauses for the candidates for forbidden_candidate in constrained_candidates.iter().copied().collect_vec() { @@ -286,21 +333,19 @@ impl> Sol &self.decision_tracker, ); - let clause_id = self.add_and_watch_clause(clause); + let clause_id = self.clauses.borrow_mut().alloc(clause); + output.clauses_to_watch.push(clause_id); if conflict { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } - - new_clauses.push(clause_id) } } - // Start by stating the clauses have been added. - self.clauses_added_for_solvable.insert(solvable_id); + *clauses_added = true; } - Ok((new_clauses, conflicting_clauses)) + Ok(output) } /// Adds all clauses for a specific package name. @@ -319,17 +364,33 @@ impl> Sol /// /// If the provider has requested the solving process to be cancelled, the cancellation value /// will be returned as an `Err(...)`. - fn add_clauses_for_package(&mut self, package_name: NameId) -> Result<(), Box> { - if self.clauses_added_for_package.contains(&package_name) { + async fn add_clauses_for_package( + &self, + negative_assertions: &mut Vec<(SolvableId, ClauseId)>, + clauses_to_watch: &mut Vec, + package_name: NameId, + ) -> Result<(), Box> { + let mutex = { + let mut clauses = self.clauses_added_for_package.borrow_mut(); + let mutex = clauses + .entry(package_name) + .or_insert_with(|| Rc::new(tokio::sync::Mutex::new(false))); + mutex.clone() + }; + + // This prevents concurrent calls to `add_clauses_for_package` from racing. Only the first + // call for a given package will go through, and others will wait till it completes. + let mut clauses_added = mutex.lock().await; + if *clauses_added { return Ok(()); } tracing::trace!( "┝━ adding clauses for package '{}'", - self.pool().resolve_package_name(package_name) + self.pool.resolve_package_name(package_name) ); - let package_candidates = self.cache.get_or_cache_candidates(package_name)?; + let package_candidates = self.cache.get_or_cache_candidates(package_name).await?; let locked_solvable_id = package_candidates.locked; let candidates = &package_candidates.candidates; @@ -346,11 +407,11 @@ impl> Sol for &other_candidate in &candidates[i + 1..] { let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); - let clause = &mut self.clauses[clause_id]; - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + clauses_to_watch.push(clause_id); } } @@ -360,28 +421,30 @@ impl> Sol if other_candidate != locked_solvable_id { let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); - let clause = &mut self.clauses[clause_id]; - - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + clauses_to_watch.push(clause_id); } } } // Add a clause for solvables that are externally excluded. for (solvable, reason) in package_candidates.excluded.iter().copied() { - let clause_id = self.clauses.alloc(ClauseState::exclude(solvable, reason)); + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::exclude(solvable, reason)); // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable, clause_id)); + negative_assertions.push((solvable, clause_id)); // Conflicts should be impossible here debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true)); } - self.clauses_added_for_package.insert(package_name); + *clauses_added = true; Ok(()) } @@ -410,7 +473,6 @@ impl> Sol assert!(self.decision_tracker.is_empty()); let mut level = 0; - let mut new_clauses = Vec::new(); loop { // A level of 0 means the decision loop has been completely reset because a partial // solution was invalidated by newly added clauses. @@ -424,7 +486,7 @@ impl> Sol // solution that satisfies the user requirements. tracing::info!( "╤══ install {} at level {level}", - SolvableId::root().display(self.pool()) + SolvableId::root().display(&self.pool) ); self.decision_tracker .try_add_decision( @@ -434,14 +496,14 @@ impl> Sol .expect("already decided"); // Add the clauses for the root solvable. - let (mut clauses, conflicting_clauses) = - self.add_clauses_for_solvable(SolvableId::root())?; - if let Some(clause_id) = conflicting_clauses.into_iter().next() { + let output = self + .async_runtime + .block_on(self.add_clauses_for_solvable(SolvableId::root()))?; + if let Err(clause_id) = self.process_add_clause_output(output) { return Err(UnsolvableOrCancelled::Unsolvable( self.analyze_unsolvable(clause_id), )); } - new_clauses.append(&mut clauses); } // Propagate decisions from assignments above @@ -459,7 +521,7 @@ impl> Sol // The conflict was caused because new clauses have been added dynamically. // We need to start over. tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); + clause=self.clauses.borrow()[clause_id].debug(&self.pool)); level = 0; self.decision_tracker.clear(); continue; @@ -486,7 +548,12 @@ impl> Sol // Filter only decisions that led to a positive assignment .filter(|d| d.value) // Select solvables for which we do not yet have dependencies - .filter(|d| !self.clauses_added_for_solvable.contains(&d.solvable_id)) + .filter(|d| { + !self + .clauses_added_for_solvable + .borrow() + .contains_key(&d.solvable_id) + }) .map(|d| (d.solvable_id, d.derived_from)) .collect(); @@ -502,29 +569,67 @@ impl> Sol .copied() .format_with("\n- ", |(id, derived_from), f| f(&format_args!( "{} (derived from {:?})", - id.display(self.pool()), - self.clauses[derived_from].debug(self.pool()), + id.display(&self.pool), + self.clauses.borrow()[derived_from].debug(&self.pool), ))) ); - for (solvable, _) in new_solvables { - // Add the clauses for this particular solvable. - let (mut clauses_for_solvable, conflicting_causes) = - self.add_clauses_for_solvable(solvable)?; - new_clauses.append(&mut clauses_for_solvable); - - for &clause_id in &conflicting_causes { - // Backtrack in the case of conflicts + // Concurrently get the solvable's clauses + let async_outputs = new_solvables.iter().map(|(solvable, _)| async { + let output = self.add_clauses_for_solvable(*solvable).await?; + Ok::<_, Box>(output) + }); + let outputs = self + .async_runtime + .block_on(futures::future::join_all(async_outputs)); + + // Serially process the outputs, to reduce the need for synchronization + let mut reset_solver = false; + for output in outputs { + let output = output?; + for &clause_id in &output.conflicting_clauses { tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); + clause=self.clauses.borrow()[clause_id].debug(&self.pool)); } - if !conflicting_causes.is_empty() { - self.decision_tracker.clear(); - level = 0; + if let Err(_first_conflicting_clause_id) = self.process_add_clause_output(output) { + // There is a conflict, so make sure we backtrack + reset_solver = true; + + // We still need to process the output from other tasks, because they might add + // more clauses + continue; } } + + if reset_solver { + self.decision_tracker.clear(); + level = 0; + } + } + } + + fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { + let mut clauses = self.clauses.borrow_mut(); + for clause_id in output.clauses_to_watch { + debug_assert!( + clauses[clause_id].has_watches(), + "attempting to watch a clause without watches!" + ); + self.watches + .start_watching(&mut clauses[clause_id], clause_id); + } + + self.requires_clauses + .append(&mut output.new_requires_clauses); + self.negative_assertions + .append(&mut output.negative_assertions); + + if let Some(&clause_id) = output.conflicting_clauses.first() { + return Err(clause_id); } + + Ok(()) } /// Resolves all dependencies @@ -557,7 +662,7 @@ impl> Sol /// ensures that if there are conflicts they are delt with as early as possible. fn decide(&mut self) -> Option<(SolvableId, SolvableId, ClauseId)> { let mut best_decision = None; - for &(solvable_id, deps, clause_id) in self.requires_clauses.iter() { + for &(solvable_id, deps, clause_id) in &self.requires_clauses { // Consider only clauses in which we have decided to install the solvable if self.decision_tracker.assigned_value(solvable_id) != Some(true) { continue; @@ -605,8 +710,8 @@ impl> Sol if let Some((count, (candidate, _solvable_id, clause_id))) = best_decision { tracing::info!( "deciding to assign {}, ({:?}, {} possible candidates)", - candidate.display(self.pool()), - self.clauses[clause_id].debug(self.pool()), + candidate.display(&self.pool), + self.clauses.borrow()[clause_id].debug(&self.pool), count, ); } @@ -637,8 +742,8 @@ impl> Sol tracing::info!( "╤══ Install {} at level {level} (required by {})", - solvable.display(self.pool()), - required_by.display(self.pool()), + solvable.display(&self.pool), + required_by.display(&self.pool), ); // Add the decision to the tracker @@ -688,28 +793,28 @@ impl> Sol { tracing::info!( "├─ Propagation conflicted: could not set {solvable} to {attempted_value}", - solvable = conflicting_solvable.display(self.pool()) + solvable = conflicting_solvable.display(&self.pool) ); tracing::info!( "│ During unit propagation for clause: {:?}", - self.clauses[conflicting_clause].debug(self.pool()) + self.clauses.borrow()[conflicting_clause].debug(&self.pool) ); tracing::info!( "│ Previously decided value: {}. Derived from: {:?}", !attempted_value, - self.clauses[self + self.clauses.borrow()[self .decision_tracker .find_clause_for_assignment(conflicting_solvable) .unwrap()] - .debug(self.pool()), + .debug(&self.pool), ); } if level == 1 { tracing::info!("╘══ UNSOLVABLE"); for decision in self.decision_tracker.stack() { - let clause = &self.clauses[decision.derived_from]; + let clause = &self.clauses.borrow()[decision.derived_from]; let level = self.decision_tracker.level(decision.solvable_id); let action = if decision.value { "install" } else { "forbid" }; @@ -720,8 +825,8 @@ impl> Sol tracing::info!( "* ({level}) {action} {}. Reason: {:?}", - decision.solvable_id.display(self.pool()), - clause.debug(self.pool()), + decision.solvable_id.display(&self.pool), + clause.debug(&self.pool), ); } @@ -744,7 +849,7 @@ impl> Sol .expect("bug: solvable was already decided!"); tracing::debug!( "├─ Propagate after learn: {} = {decision}", - literal.solvable_id.display(self.pool()) + literal.solvable_id.display(&self.pool) ); Ok(level) @@ -774,7 +879,7 @@ impl> Sol if decided { tracing::trace!( "├─ Propagate assertion {} = {}", - solvable_id.display(self.pool()), + solvable_id.display(&self.pool), value ); } @@ -783,7 +888,7 @@ impl> Sol // Assertions derived from learnt rules for learn_clause_idx in 0..self.learnt_clause_ids.len() { let clause_id = self.learnt_clause_ids[learn_clause_idx]; - let clause = &self.clauses[clause_id]; + let clause = &self.clauses.borrow()[clause_id]; let Clause::Learnt(learnt_index) = clause.kind else { unreachable!(); }; @@ -811,7 +916,7 @@ impl> Sol if decided { tracing::trace!( "├─ Propagate assertion {} = {}", - literal.solvable_id.display(self.pool()), + literal.solvable_id.display(&self.pool), decision ); } @@ -831,13 +936,14 @@ impl> Sol } // Get mutable access to both clauses. + let mut clauses = self.clauses.borrow_mut(); let (predecessor_clause, clause) = if let Some(prev_clause_id) = predecessor_clause_id { let (predecessor_clause, clause) = - self.clauses.get_two_mut(prev_clause_id, clause_id); + clauses.get_two_mut(prev_clause_id, clause_id); (Some(predecessor_clause), clause) } else { - (None, &mut self.clauses[clause_id]) + (None, &mut clauses[clause_id]) }; // Update the prev_clause_id for the next run @@ -909,9 +1015,9 @@ impl> Sol _ => { tracing::debug!( "├─ Propagate {} = {}. {:?}", - remaining_watch.solvable_id.display(self.cache.pool()), + remaining_watch.solvable_id.display(&self.cache.pool()), remaining_watch.satisfying_value(), - clause.debug(self.cache.pool()), + clause.debug(&self.cache.pool()), ); } } @@ -964,7 +1070,7 @@ impl> Sol tracing::info!("=== ANALYZE UNSOLVABLE"); let mut involved = HashSet::new(); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -974,7 +1080,7 @@ impl> Sol let mut seen = HashSet::new(); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, clause_id, &mut problem, @@ -995,14 +1101,14 @@ impl> Sol assert_ne!(why, ClauseId::install_root()); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, why, &mut problem, &mut seen, ); - self.clauses[why].kind.visit_literals( + self.clauses.borrow()[why].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1046,7 +1152,7 @@ impl> Sol loop { learnt_why.push(clause_id); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1115,10 +1221,13 @@ impl> Sol let learnt_id = self.learnt_clauses.alloc(learnt.clone()); self.learnt_why.insert(learnt_id, learnt_why); - let clause_id = self.clauses.alloc(ClauseState::learnt(learnt_id, &learnt)); + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::learnt(learnt_id, &learnt)); self.learnt_clause_ids.push(clause_id); - let clause = &mut self.clauses[clause_id]; + let clause = &mut self.clauses.borrow_mut()[clause_id]; if clause.has_watches() { self.watches.start_watching(clause, clause_id); } @@ -1128,7 +1237,7 @@ impl> Sol tracing::debug!( "│ - {}{}", if lit.negate { "NOT " } else { "" }, - lit.solvable_id.display(self.pool()) + lit.solvable_id.display(&self.pool) ); } diff --git a/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap new file mode 100644 index 0000000..5365a68 --- /dev/null +++ b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap @@ -0,0 +1,8 @@ +--- +source: tests/solver.rs +expression: result +--- +child1=3 +child2=2 +parent=4 + diff --git a/tests/solver.rs b/tests/solver.rs index e9fc230..d9b95ee 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -5,6 +5,10 @@ use resolvo::{ KnownDependencies, NameId, Pool, SolvableId, Solver, SolverCache, UnsolvableOrCancelled, VersionSet, VersionSetId, }; +use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; use std::{ any::Any, cell::Cell, @@ -139,12 +143,16 @@ impl FromStr for Spec { /// This provides sorting functionality for our `BundleBox` packaging system #[derive(Default)] struct BundleBoxProvider { - pool: Pool>, + pool: Rc>>, packages: IndexMap>, favored: HashMap, locked: HashMap, excluded: HashMap>, cancel_solving: Cell, + // TODO: simplify? + concurrent_requests: Arc, + concurrent_requests_max: Rc>, + sleep_before_return: bool, } struct BundleBoxPackageDependencies { @@ -224,11 +232,23 @@ impl BundleBoxProvider { }, ); } + + // Sends a value from the dependency provider to the solver, introducing a minimal delay to force + // concurrency to be used (unless there is no async runtime available) + async fn maybe_delay(&self, value: T) -> T { + if self.sleep_before_return { + tokio::time::sleep(Duration::from_millis(10)).await; + self.concurrent_requests.fetch_sub(1, Ordering::SeqCst); + return value; + } else { + value + } + } } impl DependencyProvider> for BundleBoxProvider { - fn pool(&self) -> &Pool> { - &self.pool + fn pool(&self) -> Rc>> { + self.pool.clone() } fn sort_candidates( @@ -244,9 +264,18 @@ impl DependencyProvider> for BundleBoxProvider { }); } - fn get_candidates(&self, name: NameId) -> Option { + async fn get_candidates(&self, name: NameId) -> Option { + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + let package_name = self.pool.resolve_package_name(name); - let package = self.packages.get(package_name)?; + let Some(package) = self.packages.get(package_name) else { + return self.maybe_delay(None).await; + }; let mut candidates = Candidates { candidates: Vec::with_capacity(package.len()), @@ -271,10 +300,17 @@ impl DependencyProvider> for BundleBoxProvider { } } - Some(candidates) + self.maybe_delay(Some(candidates)).await } - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + let candidate = self.pool.resolve_solvable(solvable); let package_name = self.pool.resolve_package_name(candidate.name_id()); let pack = candidate.inner(); @@ -282,16 +318,18 @@ impl DependencyProvider> for BundleBoxProvider { if pack.cancel_during_get_dependencies { self.cancel_solving.set(true); let reason = self.pool.intern_string("cancelled"); - return Dependencies::Unknown(reason); + return self.maybe_delay(Dependencies::Unknown(reason)).await; } if pack.unknown_deps { let reason = self.pool.intern_string("could not retrieve deps"); - return Dependencies::Unknown(reason); + return self.maybe_delay(Dependencies::Unknown(reason)).await; } let Some(deps) = self.packages.get(package_name).and_then(|v| v.get(pack)) else { - return Dependencies::Known(Default::default()); + return self + .maybe_delay(Dependencies::Known(Default::default())) + .await; }; let mut result = KnownDependencies { @@ -310,7 +348,7 @@ impl DependencyProvider> for BundleBoxProvider { result.constrains.push(dep_spec); } - Dependencies::Known(result) + self.maybe_delay(Dependencies::Known(result)).await } fn should_cancel_with_value(&self) -> Option> { @@ -341,7 +379,8 @@ fn transaction_to_string(pool: &Pool, solvables: &Vec String { let requirements = provider.requirements(specs); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); match solver.solve(requirements) { Ok(_) => panic!("expected unsat, but a solution was found"), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { @@ -349,12 +388,12 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { let graph = problem.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); - graph.graphviz(&mut output, solver.pool(), true).unwrap(); + graph.graphviz(&mut output, &pool, true).unwrap(); writeln!(output, "\n").unwrap(); // Format a user friendly error message problem - .display_user_friendly(&solver, &DefaultSolvableDisplay) + .display_user_friendly(&solver, pool, &DefaultSolvableDisplay) .to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), @@ -362,22 +401,31 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { } /// Solve the problem and returns either a solution represented as a string or an error string. -fn solve_snapshot(provider: BundleBoxProvider, specs: &[&str]) -> String { +fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { + // The test dependency provider requires time support for sleeping + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + + provider.sleep_before_return = true; + let requirements = provider.requirements(specs); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new(provider, runtime); match solver.solve(requirements) { - Ok(solvables) => transaction_to_string(solver.pool(), &solvables), + Ok(solvables) => transaction_to_string(&pool, &solvables), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { // Write the problem graphviz to stderr let graph = problem.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); - graph.graphviz(&mut output, solver.pool(), true).unwrap(); + graph.graphviz(&mut output, &pool, true).unwrap(); writeln!(output, "\n").unwrap(); // Format a user friendly error message problem - .display_user_friendly(&solver, &DefaultSolvableDisplay) + .display_user_friendly(&solver, pool, &DefaultSolvableDisplay) .to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), @@ -389,16 +437,14 @@ fn solve_snapshot(provider: BundleBoxProvider, specs: &[&str]) -> String { fn test_unit_propagation_1() { let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]); let root_requirements = provider.requirements(&["asdf"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(root_requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 1); } @@ -411,25 +457,20 @@ fn test_unit_propagation_nested() { ("dummy", 6u32, vec![]), ]); let requirements = provider.requirements(&["asdf"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 1); - let solvable = solver.pool().resolve_solvable(solved[1]); + let solvable = pool.resolve_solvable(solved[1]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "efgh" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "efgh"); assert_eq!(solvable.inner().version, 4); } @@ -443,28 +484,39 @@ fn test_resolve_multiple() { ("efgh", 5, vec![]), ]); let requirements = provider.requirements(&["asdf", "efgh"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 2); - let solvable = solver.pool().resolve_solvable(solved[1]); + let solvable = pool.resolve_solvable(solved[1]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "efgh" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "efgh"); assert_eq!(solvable.inner().version, 5); } +#[test] +fn test_resolve_with_concurrent_metadata_fetching() { + let provider = BundleBoxProvider::from_packages(&[ + ("parent", 4, vec!["child1", "child2"]), + ("child1", 3, vec![]), + ("child2", 2, vec![]), + ]); + + let max_concurrent_requests = provider.concurrent_requests_max.clone(); + + let result = solve_snapshot(provider, &["parent"]); + insta::assert_snapshot!(result); + + assert_eq!(2, max_concurrent_requests.get()); +} + /// In case of a conflict the version should not be selected with the conflict #[test] fn test_resolve_with_conflict() { @@ -490,17 +542,15 @@ fn test_resolve_with_nonexisting() { ("b", 1, vec!["idontexist"]), ]); let requirements = provider.requirements(&["asdf"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 3); } @@ -526,15 +576,16 @@ fn test_resolve_with_nested_deps() { ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"]), ]); let requirements = provider.requirements(&["apache-airflow"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), + pool.resolve_package_name(solvable.name_id()), "apache-airflow" ); assert_eq!(solvable.inner().version, 1); @@ -552,15 +603,16 @@ fn test_resolve_with_unknown_deps() { ); provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), + pool.resolve_package_name(solvable.name_id()), "opentelemetry-api" ); assert_eq!(solvable.inner().version, 2); @@ -596,15 +648,13 @@ fn test_resolve_locked_top_level() { let requirements = provider.requirements(&["asdf"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); let solvable_id = solved[0]; - assert_eq!( - solver.pool().resolve_solvable(solvable_id).inner().version, - 3 - ); + assert_eq!(pool.resolve_solvable(solvable_id).inner().version, 3); } /// Should ignore lock when it is not a top level package and a newer version exists without it @@ -619,16 +669,14 @@ fn test_resolve_ignored_locked_top_level() { provider.set_locked("fgh", 1); let requirements = provider.requirements(&["asdf"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 4); } @@ -679,10 +727,11 @@ fn test_resolve_cyclic() { let provider = BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); let requirements = provider.requirements(&["a 0..100"]); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new_with_default_runtime(provider); let solved = solver.solve(requirements).unwrap(); - let result = transaction_to_string(&solver.pool(), &solved); + let result = transaction_to_string(&pool, &solved); insta::assert_snapshot!(result, @r###" a=2 b=5