Skip to content

Commit 3cd2c8e

Browse files
rd4398claude
andcommitted
refactor(resolver): use provider pattern directly in resolution layer
Replace intermediate resolution functions with direct provider usage: - RequirementResolver now calls resolve_from_provider() directly - sources.resolve_source() uses get_resolver_provider plugin hook - wheels.resolve_prebuilt_wheel() uses get_resolver_provider plugin hook - Removed resolve_all() methods per architect guidance This refactoring simplifies the resolution architecture by eliminating the intermediate *_all() function layer while maintaining all existing functionality. All resolution now goes through the provider pattern with overrides.find_and_invoke() + resolver.resolve_from_provider(). Tests updated to match new implementation. Co-Authored-By: Claude Sonnet 4.5 <[email protected]> Signed-off-by: Rohan Devasthale <[email protected]>
1 parent aec9c9c commit 3cd2c8e

7 files changed

Lines changed: 222 additions & 199 deletions

File tree

src/fromager/bootstrapper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def resolve_version(
176176
) -> tuple[str, Version]:
177177
"""Resolve the version of a requirement.
178178
179-
Returns the source URL and the version of the requirement.
179+
Returns the source URL and the version of the requirement (highest matching version).
180180
181181
Git URL resolution stays in Bootstrapper because it requires
182182
build orchestration (BuildEnvironment, build dependencies).
@@ -193,19 +193,22 @@ def resolve_version(
193193
cached_result = self._resolver.get_cached_resolution(req, pre_built=False)
194194
if cached_result is not None:
195195
logger.debug(f"resolved {req} from cache")
196-
return cached_result
196+
# Pick highest version from cached list
197+
return cached_result[0]
197198

198199
logger.info("resolving source via URL, ignoring any plugins")
199200
source_url, resolved_version = self._resolve_version_from_git_url(req=req)
200201
# Cache the git URL resolution (always source, not prebuilt)
202+
# Store as list for consistency with cache structure
201203
self._resolver.cache_resolution(
202-
req, pre_built=False, result=(source_url, resolved_version)
204+
req, pre_built=False, result=[(source_url, resolved_version)]
203205
)
204206
return source_url, resolved_version
205207

206208
# Delegate to RequirementResolver
207209
parent_req = self.why[-1][1] if self.why else None
208210

211+
# Returns the highest matching version
209212
return self._resolver.resolve(
210213
req=req,
211214
req_type=req_type,

src/fromager/requirement_resolver.py

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from packaging.requirements import Requirement
1313
from packaging.version import Version
1414

15-
from . import resolver, sources, wheels
15+
from . import overrides, resolver
1616
from .dependency_graph import DependencyGraph
1717
from .requirements_file import RequirementType
1818

@@ -51,7 +51,10 @@ def __init__(
5151
self.prev_graph = prev_graph
5252
# Session-level resolution cache to avoid re-resolving same requirements
5353
# Key: (requirement_string, pre_built) to distinguish source vs prebuilt
54-
self._resolved_requirements: dict[tuple[str, bool], tuple[str, Version]] = {}
54+
# Value: list of (url, version) tuples sorted by version (highest first)
55+
self._resolved_requirements: dict[
56+
tuple[str, bool], list[tuple[str, Version]]
57+
] = {}
5558

5659
def resolve(
5760
self,
@@ -60,7 +63,7 @@ def resolve(
6063
parent_req: Requirement | None = None,
6164
pre_built: bool | None = None,
6265
) -> tuple[str, Version]:
63-
"""Resolve package requirement.
66+
"""Resolve package requirement to the best matching version.
6467
6568
Tries resolution strategies in order:
6669
1. Session cache (if previously resolved)
@@ -75,7 +78,7 @@ def resolve(
7578
If None (default), uses package build info to determine.
7679
7780
Returns:
78-
Tuple of (url, resolved_version)
81+
(url, version) tuple for the highest matching version
7982
8083
Raises:
8184
ValueError: If req contains a git URL and pre_built is False
@@ -98,23 +101,22 @@ def resolve(
98101
cached_result = self.get_cached_resolution(req, pre_built)
99102
if cached_result is not None:
100103
logger.debug(f"resolved {req} from cache")
101-
return cached_result
104+
return cached_result[0]
102105

103106
# Resolve using strategies
104-
url, resolved_version = self._resolve(req, req_type, parent_req, pre_built)
107+
results = self._resolve(req, req_type, parent_req, pre_built)
105108

106109
# Cache the result
107-
result = (url, resolved_version)
108-
self.cache_resolution(req, pre_built, result)
109-
return url, resolved_version
110+
self.cache_resolution(req, pre_built, results)
111+
return results[0]
110112

111113
def _resolve(
112114
self,
113115
req: Requirement,
114116
req_type: RequirementType,
115117
parent_req: Requirement | None,
116118
pre_built: bool,
117-
) -> tuple[str, Version]:
119+
) -> list[tuple[str, Version]]:
118120
"""Internal resolution logic without caching.
119121
120122
Tries resolution strategies in order:
@@ -128,7 +130,7 @@ def _resolve(
128130
pre_built: Whether to resolve prebuilt (True) or source (False)
129131
130132
Returns:
131-
Tuple of (url, resolved_version)
133+
List of (url, version) tuples sorted by version (highest first)
132134
"""
133135
# Try graph
134136
cached_resolution = self._resolve_from_graph(
@@ -139,43 +141,81 @@ def _resolve(
139141
)
140142

141143
if cached_resolution and not req.url:
142-
url, resolved_version = cached_resolution
143-
logger.debug(f"resolved from previous bootstrap to {resolved_version}")
144-
return url, resolved_version
144+
logger.debug(
145+
f"resolved from previous bootstrap: {len(cached_resolution)} version(s)"
146+
)
147+
return cached_resolution
148+
149+
# Fallback to PyPI using provider pattern
150+
pbi = self.ctx.package_build_info(req)
145151

146-
# Fallback to PyPI
147152
if pre_built:
148153
# Resolve prebuilt wheel
149-
servers = wheels.get_wheel_server_urls(
150-
self.ctx, req, cache_wheel_server_url=resolver.PYPI_SERVER_URL
151-
)
152-
url, resolved_version = wheels.resolve_prebuilt_wheel(
153-
ctx=self.ctx, req=req, wheel_server_urls=servers, req_type=req_type
154+
# Get wheel server URLs
155+
wheel_server_urls: list[str] = []
156+
if pbi.wheel_server_url:
157+
wheel_server_urls.append(pbi.wheel_server_url)
158+
else:
159+
if self.ctx.wheel_server_url:
160+
wheel_server_urls.append(self.ctx.wheel_server_url)
161+
wheel_server_urls.append(resolver.PYPI_SERVER_URL)
162+
163+
# Try each wheel server until one succeeds
164+
for url in wheel_server_urls:
165+
try:
166+
provider = overrides.find_and_invoke(
167+
req.name,
168+
"get_resolver_provider",
169+
resolver.default_resolver_provider,
170+
ctx=self.ctx,
171+
req=req,
172+
include_sdists=False,
173+
include_wheels=True,
174+
sdist_server_url=url,
175+
req_type=req_type,
176+
ignore_platform=False,
177+
)
178+
results = resolver.resolve_from_provider(provider, req)
179+
if results:
180+
return results
181+
except Exception:
182+
continue
183+
# If we get here, no wheel server succeeded
184+
raise ValueError(
185+
f"Could not find a prebuilt wheel for {req} on {' or '.join(wheel_server_urls)}"
154186
)
155187
else:
156188
# Resolve source (sdist)
157-
url, resolved_version = sources.resolve_source(
189+
override_sdist_server_url = pbi.resolver_sdist_server_url(
190+
resolver.PYPI_SERVER_URL
191+
)
192+
provider = overrides.find_and_invoke(
193+
req.name,
194+
"get_resolver_provider",
195+
resolver.default_resolver_provider,
158196
ctx=self.ctx,
159197
req=req,
160-
sdist_server_url=resolver.PYPI_SERVER_URL,
198+
include_sdists=pbi.resolver_include_sdists,
199+
include_wheels=pbi.resolver_include_wheels,
200+
sdist_server_url=override_sdist_server_url,
161201
req_type=req_type,
202+
ignore_platform=pbi.resolver_ignore_platform,
162203
)
163-
164-
return url, resolved_version
204+
return resolver.resolve_from_provider(provider, req)
165205

166206
def get_cached_resolution(
167207
self,
168208
req: Requirement,
169209
pre_built: bool,
170-
) -> tuple[str, Version] | None:
210+
) -> list[tuple[str, Version]] | None:
171211
"""Get a cached resolution result if it exists.
172212
173213
Args:
174214
req: Package requirement to look up in cache
175215
pre_built: Whether looking for prebuilt or source resolution
176216
177217
Returns:
178-
Tuple of (source_url, resolved_version) if cached, None otherwise
218+
List of (url, version) tuples if cached, None otherwise
179219
"""
180220
cache_key = (str(req), pre_built)
181221
return self._resolved_requirements.get(cache_key)
@@ -184,7 +224,7 @@ def cache_resolution(
184224
self,
185225
req: Requirement,
186226
pre_built: bool,
187-
result: tuple[str, Version],
227+
result: list[tuple[str, Version]],
188228
) -> None:
189229
"""Cache a resolution result.
190230
@@ -194,7 +234,7 @@ def cache_resolution(
194234
Args:
195235
req: Package requirement to cache
196236
pre_built: Whether this is a prebuilt or source resolution
197-
result: Tuple of (source_url, resolved_version)
237+
result: List of (url, version) tuples
198238
"""
199239
cache_key = (str(req), pre_built)
200240
self._resolved_requirements[cache_key] = result
@@ -205,7 +245,7 @@ def _resolve_from_graph(
205245
req_type: RequirementType,
206246
pre_built: bool,
207247
parent_req: Requirement | None,
208-
) -> tuple[str, Version] | None:
248+
) -> list[tuple[str, Version]] | None:
209249
"""Resolve from previous dependency graph.
210250
211251
Extracted from Bootstrapper._resolve_from_graph().
@@ -217,7 +257,7 @@ def _resolve_from_graph(
217257
parent_req: Parent requirement for graph traversal
218258
219259
Returns:
220-
Tuple of (url, version) if found in graph, None otherwise
260+
List of (url, version) tuples if found in graph, None otherwise
221261
"""
222262
if not self.prev_graph:
223263
return None
@@ -307,8 +347,8 @@ def _resolve_from_version_source(
307347
self,
308348
version_source: list[tuple[str, Version]],
309349
req: Requirement,
310-
) -> tuple[str, Version] | None:
311-
"""Select best version from candidates.
350+
) -> list[tuple[str, Version]] | None:
351+
"""Filter and return all matching versions from candidates.
312352
313353
Extracted from Bootstrapper._resolve_from_version_source().
314354
@@ -317,7 +357,7 @@ def _resolve_from_version_source(
317357
req: Package requirement with version specifier
318358
319359
Returns:
320-
Tuple of (url, version) for best match, None if no match
360+
List of (url, version) tuples for all matches, None if no matches
321361
"""
322362
if not version_source:
323363
return None
@@ -329,6 +369,7 @@ def _resolve_from_version_source(
329369
constraints=self.ctx.constraints,
330370
use_resolver_cache=False,
331371
)
372+
# resolve_from_provider now returns all matching candidates
332373
return resolver.resolve_from_provider(provider, req)
333374
except Exception as err:
334375
logger.debug(f"could not resolve {req} from {version_source}: {err}")

src/fromager/resolver.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def resolve(
8686
req_type: RequirementType | None = None,
8787
ignore_platform: bool = False,
8888
) -> tuple[str, Version]:
89+
"""Resolve requirement and return the best matching version.
90+
91+
Returns (url, version) tuple for the highest matching version.
92+
"""
8993
# Create the (reusable) resolver.
9094
provider = overrides.find_and_invoke(
9195
req.name,
@@ -99,7 +103,8 @@ def resolve(
99103
req_type=req_type,
100104
ignore_platform=ignore_platform,
101105
)
102-
return resolve_from_provider(provider, req)
106+
results = resolve_from_provider(provider, req)
107+
return results[0]
103108

104109

105110
def default_resolver_provider(
@@ -165,26 +170,31 @@ def ending(self, state: typing.Any) -> None:
165170

166171
def resolve_from_provider(
167172
provider: BaseProvider, req: Requirement
168-
) -> tuple[str, Version]:
169-
reporter = LogReporter(req)
170-
rslvr: resolvelib.Resolver = resolvelib.Resolver(provider, reporter)
173+
) -> list[tuple[str, Version]]:
174+
"""Resolve requirement and return all matching candidates.
175+
176+
Returns list of (url, version) tuples sorted by version (highest first).
177+
"""
178+
# Get all matching candidates directly from provider
179+
# instead of using resolvelib's resolver which picks just one
180+
identifier = provider.identify(req)
171181
try:
172-
result = rslvr.resolve([req])
182+
candidates = provider.find_matches(
183+
identifier=identifier,
184+
requirements={identifier: [req]},
185+
incompatibilities={},
186+
)
173187
except resolvelib.resolvers.ResolverException as err:
174188
constraint = provider.constraints.get_constraint(req.name)
175189
provider_desc = provider.get_provider_description()
176-
# Include the original error message to preserve detailed information
177-
# (e.g., file types, pre-release info from PyPIProvider)
178190
original_msg = str(err)
179191
raise resolvelib.resolvers.ResolverException(
180192
f"Unable to resolve requirement specifier {req} with constraint {constraint} using {provider_desc}: {original_msg}"
181193
) from err
182-
# resolvelib actually just returns one candidate per requirement.
183-
# result.mapping is map from an identifier to its resolved candidate
184-
candidate: Candidate
185-
for candidate in result.mapping.values():
186-
return candidate.url, candidate.version
187-
raise ValueError(f"Unable to resolve {req}")
194+
195+
# Convert candidates to list of (url, version) tuples
196+
# Candidates are already sorted by version (highest first)
197+
return [(candidate.url, candidate.version) for candidate in candidates]
188198

189199

190200
def get_project_from_pypi(
@@ -468,8 +478,8 @@ def validate_candidate(
468478
incompatibilities: CandidatesMap,
469479
candidate: Candidate,
470480
) -> bool:
471-
identifier_reqs = list(requirements[identifier])
472-
bad_versions = {c.version for c in incompatibilities[identifier]}
481+
identifier_reqs = list(requirements.get(identifier, []))
482+
bad_versions = {c.version for c in incompatibilities.get(identifier, [])}
473483
# Skip versions that are known bad
474484
if candidate.version in bad_versions:
475485
if DEBUG_RESOLVER:
@@ -573,8 +583,11 @@ def _get_no_match_error_message(
573583
574584
Subclasses should override this to provide provider-specific error details.
575585
"""
576-
r = next(iter(requirements[identifier]))
577-
return f"found no match for {r} using {self.get_provider_description()}"
586+
reqs = requirements.get(identifier, [])
587+
if reqs:
588+
r = next(iter(reqs))
589+
return f"found no match for {r} using {self.get_provider_description()}"
590+
return f"found no match for identifier {identifier} using {self.get_provider_description()}"
578591

579592
def find_matches(
580593
self,

0 commit comments

Comments
 (0)