diff --git a/Cargo.toml b/Cargo.toml index d5a72b5..1341156 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,12 @@ tracing = "0.1.37" elsa = "1.9.0" bitvec = "1.0.1" serde = { version = "1.0", features = ["derive"], optional = true } +smol = "2.0.0" [dev-dependencies] insta = "1.31.0" indexmap = "2.0.0" proptest = "1.2.0" tracing-test = { version = "0.2.4", features = ["no-env-filter"] } +static_assertions = "1.1.0" +tokio = { version = "1.35.1", features = ["rt-multi-thread", "time"] } diff --git a/src/lib.rs b/src/lib.rs index 5985145..a0f2281 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,6 +57,27 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash { fn contains(&self, v: &Self::V) -> bool; } +/// Used to send a value, possibly asynchronously, from the dependency provider to the solver. +pub struct OneShotSender { + tx: Option>, +} + +impl OneShotSender { + pub(crate) fn new() -> (Self, smol::channel::Receiver) { + let (tx, receiver) = smol::channel::unbounded(); + (Self { tx: Some(tx) }, receiver) + } + + /// Send a value, possibly asynchronously, from the dependency provider to the solver. + pub fn send(mut self, value: T) -> Result<(), T> { + self.tx + .take() + .unwrap() + .send_blocking(value) + .map_err(|error| error.0) + } +} + /// Defines implementation specific behavior for the solver and a way for the solver to access the /// packages that are available in the system. pub trait DependencyProvider: Sized { @@ -68,14 +89,28 @@ pub trait DependencyProvider: Sized { /// version the next version is tried. This continues until a solution is found. fn sort_candidates(&self, solver: &SolverCache, solvables: &mut [SolvableId]); - /// Returns a list of solvables that should be considered when a package with the given name is + /// Obtains a list of solvables that should be considered when a package with the given name is /// requested. /// - /// Returns `None` if no such package exist. - fn get_candidates(&self, name: NameId) -> Option; + /// The result should be submitted through the provided sender, which gives the trait + /// implementor the freedom to obtain the metadata in an asynchronous way if desired. In that + /// case, the trait implementor is responsible for rate-limiting requests (e.g. by using an + /// internal semaphore) and for triggering timeouts when they take too long. + /// + /// Important: if no package exists, you still need to send a value (`None`), to notify the + /// solver. Otherwise the solver will crash when [OneShotSender] is dropped. + fn get_candidates(&self, name: NameId, sender: OneShotSender>); /// Returns the dependencies for the specified solvable. - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; + /// + /// The result should be submitted through the provided sender, which gives the trait + /// implementor the freedom to obtain the metadata in an asynchronous way if desired. In that + /// case, the trait implementor is responsible for rate-limiting requests (e.g. using an + /// internal semaphore) and for triggering timeouts when they take too long. + /// + /// Important: make sure to always send a value. Otherwise the solver will crash when + /// [OneShotSender] is dropped. + fn get_dependencies(&self, solvable: SolvableId, sender: OneShotSender); /// Whether the solver should stop the dependency resolution algorithm. /// @@ -180,3 +215,12 @@ where .join(" | ") } } + +#[cfg(test)] +mod test { + use super::*; + use static_assertions::assert_impl_any; + + assert_impl_any!(OneShotSender>: Send, Sync); + assert_impl_any!(OneShotSender: Send, Sync); +} diff --git a/src/problem.rs b/src/problem.rs index 311c7bd..d1116a8 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -52,7 +52,7 @@ impl Problem { let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency); for clause_id in &self.clauses { - let clause = &solver.clauses[*clause_id].kind; + let clause = &solver.clauses.borrow()[*clause_id].kind; match clause { Clause::InstallRoot => (), Clause::Excluded(solvable, reason) => { @@ -65,7 +65,7 @@ impl Problem { &Clause::Requires(package_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); - let candidates = solver.cache.get_or_cache_sorted_candidates(version_set_id).unwrap_or_else(|_| { + let candidates = smol::block_on(solver.cache.get_or_cache_sorted_candidates(version_set_id)).unwrap_or_else(|_| { unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`") }); if candidates.is_empty() { diff --git a/src/solver/cache.rs b/src/solver/cache.rs index 8d25d3d..a844f9f 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -5,8 +5,8 @@ use crate::{ frozen_copy_map::FrozenCopyMap, id::{CandidatesId, DependenciesId}, }, - Candidates, Dependencies, DependencyProvider, NameId, PackageName, Pool, SolvableId, - VersionSet, VersionSetId, + Candidates, Dependencies, DependencyProvider, NameId, OneShotSender, PackageName, Pool, + SolvableId, VersionSet, VersionSetId, }; use bitvec::vec::BitVec; use elsa::FrozenMap; @@ -74,7 +74,7 @@ impl> SolverCache Result<&Candidates, Box> { @@ -90,9 +90,12 @@ impl> SolverCache> SolverCache Result<&[SolvableId], Box> { @@ -135,7 +138,7 @@ impl> SolverCache { let package_name = self.pool().resolve_version_set_package_name(version_set_id); let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates @@ -158,7 +161,7 @@ impl> SolverCache Result<&[SolvableId], Box> { @@ -167,7 +170,7 @@ impl> SolverCache { let package_name = self.pool().resolve_version_set_package_name(version_set_id); let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates @@ -191,7 +194,7 @@ impl> SolverCache Result<&[SolvableId], Box> { @@ -199,8 +202,10 @@ impl> SolverCache Ok(candidates), None => { let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let matching_candidates = self.get_or_cache_matching_candidates(version_set_id)?; - let candidates = self.get_or_cache_candidates(package_name)?; + let matching_candidates = self + .get_or_cache_matching_candidates(version_set_id) + .await?; + let candidates = self.get_or_cache_candidates(package_name).await?; // Sort all the candidates in order in which they should be tried by the solver. let mut sorted_candidates = Vec::new(); @@ -228,7 +233,7 @@ impl> SolverCache Result<&Dependencies, Box> { @@ -242,7 +247,12 @@ impl> SolverCache, + conflicting_clauses: Vec, + negative_assertions: Vec<(SolvableId, ClauseId)>, + clauses_to_watch: Vec, +} + /// Drives the SAT solving process pub struct Solver> { pub(crate) cache: SolverCache, - pub(crate) clauses: Arena, + pub(crate) clauses: RefCell>, requires_clauses: Vec<(SolvableId, VersionSetId, ClauseId)>, watches: WatchMap, @@ -43,8 +53,8 @@ pub struct Solver> learnt_why: Mapping>, learnt_clause_ids: Vec, - clauses_added_for_package: HashSet, - clauses_added_for_solvable: HashSet, + clauses_added_for_package: RefCell>>>, + clauses_added_for_solvable: RefCell>>>, decision_tracker: DecisionTracker, @@ -57,10 +67,10 @@ impl> Solver Self { Self { cache: SolverCache::new(provider), - clauses: Arena::new(), + clauses: RefCell::new(Arena::new()), requires_clauses: Default::default(), watches: WatchMap::new(), - negative_assertions: Vec::new(), + negative_assertions: Default::default(), learnt_clauses: Arena::new(), learnt_why: Mapping::new(), learnt_clause_ids: Vec::new(), @@ -123,7 +133,7 @@ impl> Sol // The first clause will always be the install root clause. Here we verify that this is // indeed the case. - let root_clause = self.clauses.alloc(ClauseState::root()); + let root_clause = self.clauses.borrow_mut().alloc(ClauseState::root()); assert_eq!(root_clause, ClauseId::install_root()); // Run SAT @@ -145,26 +155,6 @@ impl> Sol Ok(steps) } - /// Adds a clause to the solver and immediately starts watching its literals. - fn add_and_watch_clause(&mut self, clause: ClauseState) -> ClauseId { - let clause_id = self.clauses.alloc(clause); - let clause = &self.clauses[clause_id]; - - // Add in requires clause lookup - if let &Clause::Requires(solvable_id, version_set_id) = &clause.kind { - self.requires_clauses - .push((solvable_id, version_set_id, clause_id)); - } - - // Start watching the literals of the clause - let clause = &mut self.clauses[clause_id]; - if clause.has_watches() { - self.watches.start_watching(clause, clause_id); - } - - clause_id - } - /// Adds clauses for a solvable. These clauses include requirements and constrains on other /// solvables. /// @@ -172,21 +162,32 @@ impl> Sol /// /// If the provider has requested the solving process to be cancelled, the cancellation value /// will be returned as an `Err(...)`. - fn add_clauses_for_solvable( - &mut self, + async fn add_clauses_for_solvable( + &self, solvable_id: SolvableId, - ) -> Result<(Vec, Vec), Box> { - if self.clauses_added_for_solvable.contains(&solvable_id) { - return Ok((Vec::new(), Vec::new())); - } - - let mut new_clauses = Vec::new(); - let mut conflicting_clauses = Vec::new(); + ) -> Result> { + let mut output = AddClauseOutput::default(); let mut queue = vec![solvable_id]; let mut seen = HashSet::new(); seen.insert(solvable_id); while let Some(solvable_id) = queue.pop() { + let mutex = { + let mut clauses = self.clauses_added_for_solvable.borrow_mut(); + let mutex = clauses + .entry(solvable_id) + .or_insert_with(|| Rc::new(smol::lock::Mutex::new(false))); + mutex.clone() + }; + + // This prevents concurrent requests to add clauses for a solvable from racing. Only the + // first request for that solvable will go through, and others will wait till it + // completes. + let mut clauses_added = mutex.lock().await; + if *clauses_added { + continue; + } + let solvable = self.pool().resolve_internal_solvable(solvable_id); tracing::trace!( "┝━ adding clauses for dependencies of {}", @@ -200,7 +201,7 @@ impl> Sol let (requirements, constrains) = match solvable.inner { SolvableInner::Root => (self.root_requirements.clone(), Vec::new()), SolvableInner::Package(_) => { - let deps = self.cache.get_or_cache_dependencies(solvable_id)?; + let deps = self.cache.get_or_cache_dependencies(solvable_id).await?; match deps { Dependencies::Known(deps) => { (deps.requirements.clone(), deps.constrains.clone()) @@ -210,24 +211,18 @@ impl> Sol // an exclusion clause for it let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::exclude(solvable_id, *reason)); // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable_id, clause_id)); + output.negative_assertions.push((solvable_id, clause_id)); // There might be a conflict now - let conflicts = if self.decision_tracker.assigned_value(solvable_id) - == Some(true) - { - vec![clause_id] - } else { - Vec::new() - }; - - // The new assertion should be kept in all cases (it is returned in the - // lhs of the tuple), and a conflicts should be reported if present (rhs - // of the tuple) - return Ok((vec![clause_id], conflicts)); + if self.decision_tracker.assigned_value(solvable_id) == Some(true) { + output.conflicting_clauses.push(clause_id); + } + + continue; } } } @@ -236,17 +231,24 @@ impl> Sol // Add clauses for the requirements for version_set_id in requirements { let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; + self.add_clauses_for_package( + &mut output.negative_assertions, + &mut output.clauses_to_watch, + dependency_name, + ) + .await?; // Find all the solvables that match for the given version set - let candidates = self.cache.get_or_cache_sorted_candidates(version_set_id)?; + let candidates = self + .cache + .get_or_cache_sorted_candidates(version_set_id) + .await?; // Queue requesting the dependencies of the candidates as well if they are cheaply // available from the dependency provider. for &candidate in candidates { if seen.insert(candidate) && self.cache.are_dependencies_available_for(candidate) - && !self.clauses_added_for_solvable.contains(&candidate) { queue.push(candidate); } @@ -261,27 +263,44 @@ impl> Sol &self.decision_tracker, ); - let clause_id = self.add_and_watch_clause(clause); + let clause_id = self.clauses.borrow_mut().alloc(clause); + let clause = &self.clauses.borrow()[clause_id]; + + let &Clause::Requires(solvable_id, version_set_id) = &clause.kind else { + unreachable!(); + }; + + if clause.has_watches() { + output.clauses_to_watch.push(clause_id); + } + + output + .new_requires_clauses + .push((solvable_id, version_set_id, clause_id)); if conflict { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } else if no_candidates { // Add assertions for unit clauses (i.e. those with no matching candidates) - self.negative_assertions.push((solvable_id, clause_id)); + output.negative_assertions.push((solvable_id, clause_id)); } - - new_clauses.push(clause_id); } // Add clauses for the constraints for version_set_id in constrains { let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; + self.add_clauses_for_package( + &mut output.negative_assertions, + &mut output.clauses_to_watch, + dependency_name, + ) + .await?; // Find all the solvables that match for the given version set let constrained_candidates = self .cache - .get_or_cache_non_matching_candidates(version_set_id)?; + .get_or_cache_non_matching_candidates(version_set_id) + .await?; // Add forbidden clauses for the candidates for forbidden_candidate in constrained_candidates.iter().copied().collect_vec() { @@ -292,21 +311,19 @@ impl> Sol &self.decision_tracker, ); - let clause_id = self.add_and_watch_clause(clause); + let clause_id = self.clauses.borrow_mut().alloc(clause); + output.clauses_to_watch.push(clause_id); if conflict { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } - - new_clauses.push(clause_id) } } - // Start by stating the clauses have been added. - self.clauses_added_for_solvable.insert(solvable_id); + *clauses_added = true; } - Ok((new_clauses, conflicting_clauses)) + Ok(output) } /// Adds all clauses for a specific package name. @@ -325,8 +342,24 @@ impl> Sol /// /// If the provider has requested the solving process to be cancelled, the cancellation value /// will be returned as an `Err(...)`. - fn add_clauses_for_package(&mut self, package_name: NameId) -> Result<(), Box> { - if self.clauses_added_for_package.contains(&package_name) { + async fn add_clauses_for_package( + &self, + negative_assertions: &mut Vec<(SolvableId, ClauseId)>, + clauses_to_watch: &mut Vec, + package_name: NameId, + ) -> Result<(), Box> { + let mutex = { + let mut clauses = self.clauses_added_for_package.borrow_mut(); + let mutex = clauses + .entry(package_name) + .or_insert_with(|| Rc::new(smol::lock::Mutex::new(false))); + mutex.clone() + }; + + // This prevents concurrent calls to `add_clauses_for_package` from racing. Only the first + // call for a given package will go through, and others will wait till it completes. + let mut clauses_added = mutex.lock().await; + if *clauses_added { return Ok(()); } @@ -335,7 +368,7 @@ impl> Sol self.pool().resolve_package_name(package_name) ); - let package_candidates = self.cache.get_or_cache_candidates(package_name)?; + let package_candidates = self.cache.get_or_cache_candidates(package_name).await?; let locked_solvable_id = package_candidates.locked; let candidates = &package_candidates.candidates; @@ -352,11 +385,11 @@ impl> Sol for &other_candidate in &candidates[i + 1..] { let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); - let clause = &mut self.clauses[clause_id]; - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + clauses_to_watch.push(clause_id); } } @@ -366,28 +399,30 @@ impl> Sol if other_candidate != locked_solvable_id { let clause_id = self .clauses + .borrow_mut() .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); - let clause = &mut self.clauses[clause_id]; - - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + clauses_to_watch.push(clause_id); } } } // Add a clause for solvables that are externally excluded. for (solvable, reason) in package_candidates.excluded.iter().copied() { - let clause_id = self.clauses.alloc(ClauseState::exclude(solvable, reason)); + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::exclude(solvable, reason)); // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable, clause_id)); + negative_assertions.push((solvable, clause_id)); // Conflicts should be impossible here debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true)); } - self.clauses_added_for_package.insert(package_name); + *clauses_added = true; Ok(()) } @@ -416,7 +451,6 @@ impl> Sol assert!(self.decision_tracker.is_empty()); let mut level = 0; - let mut new_clauses = Vec::new(); loop { // A level of 0 means the decision loop has been completely reset because a partial // solution was invalidated by newly added clauses. @@ -440,14 +474,15 @@ impl> Sol .expect("already decided"); // Add the clauses for the root solvable. - let (mut clauses, conflicting_clauses) = - self.add_clauses_for_solvable(SolvableId::root())?; - if let Some(clause_id) = conflicting_clauses.into_iter().next() { + let executor = smol::LocalExecutor::new(); + let output = smol::block_on( + executor.run(self.add_clauses_for_solvable(SolvableId::root())), + )?; + if let Err(clause_id) = self.process_add_clause_output(output) { return Err(UnsolvableOrCancelled::Unsolvable( self.analyze_unsolvable(clause_id), )); } - new_clauses.append(&mut clauses); } // Propagate decisions from assignments above @@ -465,7 +500,7 @@ impl> Sol // The conflict was caused because new clauses have been added dynamically. // We need to start over. tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); + clause=self.clauses.borrow()[clause_id].debug(self.pool())); level = 0; self.decision_tracker.clear(); continue; @@ -492,7 +527,12 @@ impl> Sol // Filter only decisions that led to a positive assignment .filter(|d| d.value) // Select solvables for which we do not yet have dependencies - .filter(|d| !self.clauses_added_for_solvable.contains(&d.solvable_id)) + .filter(|d| { + !self + .clauses_added_for_solvable + .borrow() + .contains_key(&d.solvable_id) + }) .map(|d| (d.solvable_id, d.derived_from)) .collect(); @@ -509,30 +549,81 @@ impl> Sol .format_with("\n- ", |(id, derived_from), f| f(&format_args!( "{} (derived from {:?})", id.display(self.pool()), - self.clauses[derived_from].debug(self.pool()), + self.clauses.borrow()[derived_from].debug(self.pool()), ))) ); - for (solvable, _) in new_solvables { - // Add the clauses for this particular solvable. - let (mut clauses_for_solvable, conflicting_causes) = - self.add_clauses_for_solvable(solvable)?; - new_clauses.append(&mut clauses_for_solvable); + // Concurrently get the solvable's clauses + let outputs = { + let executor = smol::LocalExecutor::new(); + let async_outputs = new_solvables + .iter() + .map(|(solvable, _)| { + executor.spawn(async { + let output = self.add_clauses_for_solvable(*solvable).await?; + Ok::<_, Box>(output) + }) + }) + .collect::>(); + + let mut outputs = Vec::with_capacity(async_outputs.len()); + smol::block_on(executor.run(async { + for async_output in async_outputs { + outputs.push(async_output.await?); + } - for &clause_id in &conflicting_causes { - // Backtrack in the case of conflicts + Ok::<_, Box>(outputs) + })) + }; + + // Serially process the outputs, to reduce the need for synchronization + let mut reset_solver = false; + for output in outputs? { + for &clause_id in &output.conflicting_clauses { tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); + clause=self.clauses.borrow()[clause_id].debug(self.pool())); } - if !conflicting_causes.is_empty() { - self.decision_tracker.clear(); - level = 0; + if let Err(_first_conflicting_clause_id) = self.process_add_clause_output(output) { + // There is a conflict, so make sure we backtrack + reset_solver = true; + + // We still need to process the output from other tasks, because they might add + // more clauses + continue; } } + + if reset_solver { + self.decision_tracker.clear(); + level = 0; + } } } + fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { + let mut clauses = self.clauses.borrow_mut(); + for clause_id in output.clauses_to_watch { + debug_assert!( + clauses[clause_id].has_watches(), + "attempting to watch a clause without watches!" + ); + self.watches + .start_watching(&mut clauses[clause_id], clause_id); + } + + self.requires_clauses + .append(&mut output.new_requires_clauses); + self.negative_assertions + .append(&mut output.negative_assertions); + + if let Some(&clause_id) = output.conflicting_clauses.first() { + return Err(clause_id); + } + + Ok(()) + } + /// Resolves all dependencies /// /// Repeatedly chooses the next variable to assign, and calls [`Solver::set_propagate_learn`] to @@ -563,7 +654,7 @@ impl> Sol /// ensures that if there are conflicts they are delt with as early as possible. fn decide(&mut self) -> Option<(SolvableId, SolvableId, ClauseId)> { let mut best_decision = None; - for &(solvable_id, deps, clause_id) in self.requires_clauses.iter() { + for &(solvable_id, deps, clause_id) in &self.requires_clauses { // Consider only clauses in which we have decided to install the solvable if self.decision_tracker.assigned_value(solvable_id) != Some(true) { continue; @@ -612,7 +703,7 @@ impl> Sol tracing::info!( "deciding to assign {}, ({:?}, {} possible candidates)", candidate.display(self.pool()), - self.clauses[clause_id].debug(self.pool()), + self.clauses.borrow()[clause_id].debug(self.pool()), count, ); } @@ -698,13 +789,13 @@ impl> Sol ); tracing::info!( "│ During unit propagation for clause: {:?}", - self.clauses[conflicting_clause].debug(self.pool()) + self.clauses.borrow()[conflicting_clause].debug(self.pool()) ); tracing::info!( "│ Previously decided value: {}. Derived from: {:?}", !attempted_value, - self.clauses[self + self.clauses.borrow()[self .decision_tracker .find_clause_for_assignment(conflicting_solvable) .unwrap()] @@ -715,7 +806,7 @@ impl> Sol if level == 1 { tracing::info!("╘══ UNSOLVABLE"); for decision in self.decision_tracker.stack() { - let clause = &self.clauses[decision.derived_from]; + let clause = &self.clauses.borrow()[decision.derived_from]; let level = self.decision_tracker.level(decision.solvable_id); let action = if decision.value { "install" } else { "forbid" }; @@ -789,7 +880,7 @@ impl> Sol // Assertions derived from learnt rules for learn_clause_idx in 0..self.learnt_clause_ids.len() { let clause_id = self.learnt_clause_ids[learn_clause_idx]; - let clause = &self.clauses[clause_id]; + let clause = &self.clauses.borrow()[clause_id]; let Clause::Learnt(learnt_index) = clause.kind else { unreachable!(); }; @@ -837,13 +928,14 @@ impl> Sol } // Get mutable access to both clauses. + let mut clauses = self.clauses.borrow_mut(); let (predecessor_clause, clause) = if let Some(prev_clause_id) = predecessor_clause_id { let (predecessor_clause, clause) = - self.clauses.get_two_mut(prev_clause_id, clause_id); + clauses.get_two_mut(prev_clause_id, clause_id); (Some(predecessor_clause), clause) } else { - (None, &mut self.clauses[clause_id]) + (None, &mut clauses[clause_id]) }; // Update the prev_clause_id for the next run @@ -970,7 +1062,7 @@ impl> Sol tracing::info!("=== ANALYZE UNSOLVABLE"); let mut involved = HashSet::new(); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -980,7 +1072,7 @@ impl> Sol let mut seen = HashSet::new(); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, clause_id, &mut problem, @@ -1001,14 +1093,14 @@ impl> Sol assert_ne!(why, ClauseId::install_root()); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, why, &mut problem, &mut seen, ); - self.clauses[why].kind.visit_literals( + self.clauses.borrow()[why].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1052,7 +1144,7 @@ impl> Sol loop { learnt_why.push(clause_id); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1121,10 +1213,13 @@ impl> Sol let learnt_id = self.learnt_clauses.alloc(learnt.clone()); self.learnt_why.insert(learnt_id, learnt_why); - let clause_id = self.clauses.alloc(ClauseState::learnt(learnt_id, &learnt)); + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::learnt(learnt_id, &learnt)); self.learnt_clause_ids.push(clause_id); - let clause = &mut self.clauses[clause_id]; + let clause = &mut self.clauses.borrow_mut()[clause_id]; if clause.has_watches() { self.watches.start_watching(clause, clause_id); } diff --git a/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap new file mode 100644 index 0000000..5365a68 --- /dev/null +++ b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap @@ -0,0 +1,8 @@ +--- +source: tests/solver.rs +expression: result +--- +child1=3 +child2=2 +parent=4 + diff --git a/tests/solver.rs b/tests/solver.rs index e9fc230..08278b9 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -2,9 +2,13 @@ use indexmap::IndexMap; use itertools::Itertools; use resolvo::{ range::Range, Candidates, DefaultSolvableDisplay, Dependencies, DependencyProvider, - KnownDependencies, NameId, Pool, SolvableId, Solver, SolverCache, UnsolvableOrCancelled, - VersionSet, VersionSetId, + KnownDependencies, NameId, OneShotSender, Pool, SolvableId, Solver, SolverCache, + UnsolvableOrCancelled, VersionSet, VersionSetId, }; +use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; use std::{ any::Any, cell::Cell, @@ -15,6 +19,7 @@ use std::{ num::ParseIntError, str::FromStr, }; +use tokio::runtime::Runtime; use tracing_test::traced_test; // Let's define our own packaging version system and dependency specification. @@ -145,6 +150,9 @@ struct BundleBoxProvider { locked: HashMap, excluded: HashMap>, cancel_solving: Cell, + runtime: Option, + concurrent_requests: Arc, + concurrent_requests_max: Rc>, } struct BundleBoxPackageDependencies { @@ -224,6 +232,25 @@ impl BundleBoxProvider { }, ); } + + // Sends a value from the dependency provider to the solver, introducing a minimal delay to force + // concurrency to be used (unless there is no async runtime available) + fn send_with_delay(&self, sender: OneShotSender, value: T) { + if let Some(runtime) = &self.runtime { + let concurrent_requests = self.concurrent_requests.clone(); + runtime.spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + sender + .send(value) + .unwrap_or_else(|_| panic!("receiver end of channel was dropped")); + concurrent_requests.fetch_sub(1, Ordering::SeqCst); + }); + } else { + sender + .send(value) + .unwrap_or_else(|_| panic!("receiver end of channel was dropped")); + } + } } impl DependencyProvider> for BundleBoxProvider { @@ -244,9 +271,19 @@ impl DependencyProvider> for BundleBoxProvider { }); } - fn get_candidates(&self, name: NameId) -> Option { + fn get_candidates(&self, name: NameId, sender: OneShotSender>) { + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + let package_name = self.pool.resolve_package_name(name); - let package = self.packages.get(package_name)?; + let Some(package) = self.packages.get(package_name) else { + self.send_with_delay(sender, None); + return; + }; let mut candidates = Candidates { candidates: Vec::with_capacity(package.len()), @@ -271,10 +308,17 @@ impl DependencyProvider> for BundleBoxProvider { } } - Some(candidates) + self.send_with_delay(sender, Some(candidates)); } - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + fn get_dependencies(&self, solvable: SolvableId, sender: OneShotSender) { + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + let candidate = self.pool.resolve_solvable(solvable); let package_name = self.pool.resolve_package_name(candidate.name_id()); let pack = candidate.inner(); @@ -282,16 +326,19 @@ impl DependencyProvider> for BundleBoxProvider { if pack.cancel_during_get_dependencies { self.cancel_solving.set(true); let reason = self.pool.intern_string("cancelled"); - return Dependencies::Unknown(reason); + self.send_with_delay(sender, Dependencies::Unknown(reason)); + return; } if pack.unknown_deps { let reason = self.pool.intern_string("could not retrieve deps"); - return Dependencies::Unknown(reason); + self.send_with_delay(sender, Dependencies::Unknown(reason)); + return; } let Some(deps) = self.packages.get(package_name).and_then(|v| v.get(pack)) else { - return Dependencies::Known(Default::default()); + self.send_with_delay(sender, Dependencies::Known(Default::default())); + return; }; let mut result = KnownDependencies { @@ -310,7 +357,7 @@ impl DependencyProvider> for BundleBoxProvider { result.constrains.push(dep_spec); } - Dependencies::Known(result) + self.send_with_delay(sender, Dependencies::Known(result)); } fn should_cancel_with_value(&self) -> Option> { @@ -362,7 +409,15 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { } /// Solve the problem and returns either a solution represented as a string or an error string. -fn solve_snapshot(provider: BundleBoxProvider, specs: &[&str]) -> String { +fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { + // The test dependency provider uses tokio for sleeping + let executor = tokio::runtime::Builder::new_multi_thread() + .enable_time() + .build() + .unwrap(); + + provider.runtime = Some(executor); + let requirements = provider.requirements(specs); let mut solver = Solver::new(provider); match solver.solve(requirements) { @@ -465,6 +520,22 @@ fn test_resolve_multiple() { assert_eq!(solvable.inner().version, 5); } +#[test] +fn test_resolve_with_concurrent_metadata_fetching() { + let provider = BundleBoxProvider::from_packages(&[ + ("parent", 4, vec!["child1", "child2"]), + ("child1", 3, vec![]), + ("child2", 2, vec![]), + ]); + + let max_concurrent_requests = provider.concurrent_requests_max.clone(); + + let result = solve_snapshot(provider, &["parent"]); + insta::assert_snapshot!(result); + + assert_eq!(2, max_concurrent_requests.get()); +} + /// In case of a conflict the version should not be selected with the conflict #[test] fn test_resolve_with_conflict() {