diff --git a/src/resolvelib/providers.py b/src/resolvelib/providers.py index 965cf9c..a7f4fe9 100644 --- a/src/resolvelib/providers.py +++ b/src/resolvelib/providers.py @@ -94,6 +94,21 @@ def get_dependencies(self, candidate): """ raise NotImplementedError + def match_identically(self, requirements_a, requirements_b): + """Whether the two given requirement sets find the same candidates. + + This is used by the resolver to perform tree-pruning. If the two + requirement sets provide the same candidates, the resolver can avoid + visiting the subtree again when it's encountered, and directly mark it + as a dead end instead. + + Both arguments are iterators yielding requirement objects. A boolean + should be returned to indicate whether the two sets should be treated + as matching. + """ + return False # TODO: Remove this and implement the method in tests. + raise NotImplementedError + class AbstractResolver(object): """The thing that performs the actual resolution work.""" diff --git a/src/resolvelib/resolvers.py b/src/resolvelib/resolvers.py index 621267d..dc63338 100644 --- a/src/resolvelib/resolvers.py +++ b/src/resolvelib/resolvers.py @@ -1,4 +1,5 @@ import collections +import itertools from .providers import AbstractResolver from .structs import DirectedGraph, build_iter_view @@ -143,6 +144,7 @@ def __init__(self, provider, reporter): self._p = provider self._r = reporter self._states = [] + self._known_failures = [] @property def state(self): @@ -199,6 +201,22 @@ def _get_criteria_to_update(self, candidate): criteria[name] = crit return criteria + def _match_known_failure_causes(self, updating_criteria): + criteria = self.state.criteria.copy() + criteria.update(updating_criteria) + for state in self._known_failures: + identical = self._p.match_identically( + itertools.chain.from_iterable( + crit.iter_requirement() for crit in criteria.values() + ), + itertools.chain.from_iterable( + crit.iter_requirement() for crit in state.criteria.values() + ), + ) + if identical: + return True + return False + def _attempt_to_pin_criterion(self, name, criterion): causes = [] for candidate in criterion.candidates: @@ -208,6 +226,9 @@ def _attempt_to_pin_criterion(self, name, criterion): causes.append(e.criterion) continue + if self._match_known_failure_causes(criteria): + continue + # Check the newly-pinned candidate actually works. This should # always pass under normal circumstances, but in the case of a # faulty provider, we will raise an error to notify the implementer @@ -226,7 +247,7 @@ def _attempt_to_pin_criterion(self, name, criterion): self.state.mapping[name] = candidate self.state.criteria.update(criteria) - return [] + return None # All candidates tried, nothing works. This criterion is a dead # end, signal for backtracking. @@ -260,7 +281,7 @@ def _backtrack(self): """ while len(self._states) >= 3: # Remove the state that triggered backtracking. - del self._states[-1] + self._known_failures.append(self._states.pop()) # Retrieve the last candidate pin and known incompatibilities. broken_state = self._states.pop() @@ -345,7 +366,7 @@ def resolve(self, requirements, max_rounds): ) failure_causes = self._attempt_to_pin_criterion(name, criterion) - if failure_causes: + if failure_causes is not None: # Backtrack if pinning fails. The backtrack process puts us in # an unpinned state, so we can work on it in the next round. success = self._backtrack() diff --git a/tests/conftest.py b/tests/conftest.py index c9fb2d0..263c651 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,3 +27,8 @@ def reporter_cls(): @pytest.fixture() def reporter(reporter_cls): return reporter_cls() + + +@pytest.fixture() +def base_reporter(): + return BaseReporter() diff --git a/tests/test_resolvers.py b/tests/test_resolvers.py index b0c2d38..909bc6a 100644 --- a/tests/test_resolvers.py +++ b/tests/test_resolvers.py @@ -1,3 +1,6 @@ +import collections +import operator + import pytest from resolvelib import ( @@ -40,3 +43,75 @@ def is_satisfied_by(self, r, c): assert str(ctx.value) == "Provided candidate 'bar' does not satisfy 'foo'" assert ctx.value.candidate is candidate assert list(ctx.value.criterion.iter_requirement()) == [requirement] + + +def test_criteria_pruning(reporter_cls, base_reporter): + C = collections.namedtuple("C", "name version dependencies") + R = collections.namedtuple("R", "name versions") + + # Both C versions have the same dependencies. The resolver should be start + # enough to not pin C1 after C2 fails. + candidate_definitions = [ + C("a", 1, []), + C("a", 2, []), + C("b", 1, [R("a", [2])]), + C("c", 1, [R("b", [1]), R("a", [2])]), + C("c", 2, [R("b", [1]), R("a", [1])]), + C("c", 3, [R("b", [1]), R("a", [1])]), + ] + + class Provider(AbstractProvider): + def identify(self, d): + return d.name + + def get_preference(self, resolution, candidates, information): + # Order by name for reproducibility. + return next(iter(candidates)).name + + def find_matches(self, requirements): + if not requirements: + return () + matches = ( + c + for c in candidate_definitions + if all(self.is_satisfied_by(r, c) for r in requirements) + ) + return sorted( + matches, + key=operator.attrgetter("version"), + reverse=True, + ) + + def is_satisfied_by(self, requirement, candidate): + return ( + candidate.name == requirement.name + and candidate.version in requirement.versions + ) + + def match_identically(self, reqs1, reqs2): + vers1 = collections.defaultdict(set) + vers2 = collections.defaultdict(set) + for rs, vs in [(reqs1, vers1), (reqs2, vers2)]: + for r in rs: + vs[r.name] = vs[r.name].union(r.versions) + return vers1 == vers2 + + def get_dependencies(self, candidate): + return candidate.dependencies + + class Reporter(reporter_cls): + def __init__(self): + super(Reporter, self).__init__() + self.pinned_c = [] + + def pinning(self, candidate): + super(Reporter, self).pinning(candidate) + if candidate.name == "c": + self.pinned_c.append(candidate.version) + + reporter = Reporter() + result = Resolver(Provider(), reporter).resolve([R("c", [1, 2, 3])]) + + pinned_versions = {c.name: c.version for c in result.mapping.values()} + assert pinned_versions == {"a": 2, "b": 1, "c": 1} + assert reporter.pinned_c == [3, 1], "should be smart enough to skip c==2"