Skip to content

Commit

Permalink
Use a set for TargetPython.get_tags for performance (#12204)
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored Aug 6, 2023
1 parent d311e6e commit 901db9c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 16 deletions.
1 change: 1 addition & 0 deletions news/12204.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve use of datastructures to make candidate selection 1.6x faster
2 changes: 1 addition & 1 deletion src/pip/_internal/commands/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def show_tags(options: Values) -> None:
tag_limit = 10

target_python = make_target_python(options)
tags = target_python.get_tags()
tags = target_python.get_sorted_tags()

# Display the target options that were explicitly provided.
formatted_target = target_python.format_given()
Expand Down
4 changes: 2 additions & 2 deletions src/pip/_internal/index/package_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def evaluate_link(self, link: Link) -> Tuple[LinkType, str]:
reason = f"wrong project name (not {self.project_name})"
return (LinkType.different_project, reason)

supported_tags = self._target_python.get_tags()
supported_tags = self._target_python.get_unsorted_tags()
if not wheel.supported(supported_tags):
# Include the wheel's tags in the reason string to
# simplify troubleshooting compatibility issues.
Expand Down Expand Up @@ -414,7 +414,7 @@ def create(
if specifier is None:
specifier = specifiers.SpecifierSet()

supported_tags = target_python.get_tags()
supported_tags = target_python.get_sorted_tags()

return cls(
project_name=project_name,
Expand Down
18 changes: 15 additions & 3 deletions src/pip/_internal/models/target_python.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

from pip._vendor.packaging.tags import Tag

Expand All @@ -22,6 +22,7 @@ class TargetPython:
"py_version",
"py_version_info",
"_valid_tags",
"_valid_tags_set",
]

def __init__(
Expand Down Expand Up @@ -61,8 +62,9 @@ def __init__(
self.py_version = py_version
self.py_version_info = py_version_info

# This is used to cache the return value of get_tags().
# This is used to cache the return value of get_(un)sorted_tags.
self._valid_tags: Optional[List[Tag]] = None
self._valid_tags_set: Optional[Set[Tag]] = None

def format_given(self) -> str:
"""
Expand All @@ -84,7 +86,7 @@ def format_given(self) -> str:
f"{key}={value!r}" for key, value in key_values if value is not None
)

def get_tags(self) -> List[Tag]:
def get_sorted_tags(self) -> List[Tag]:
"""
Return the supported PEP 425 tags to check wheel candidates against.
Expand All @@ -108,3 +110,13 @@ def get_tags(self) -> List[Tag]:
self._valid_tags = tags

return self._valid_tags

def get_unsorted_tags(self) -> Set[Tag]:
"""Exactly the same as get_sorted_tags, but returns a set.
This is important for performance.
"""
if self._valid_tags_set is None:
self._valid_tags_set = set(self.get_sorted_tags())

return self._valid_tags_set
2 changes: 1 addition & 1 deletion src/pip/_internal/resolution/resolvelib/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _fail_if_link_is_unsupported_wheel(self, link: Link) -> None:
if not link.is_wheel:
return
wheel = Wheel(link.filename)
if wheel.supported(self._finder.target_python.get_tags()):
if wheel.supported(self._finder.target_python.get_unsorted_tags()):
return
msg = f"{link.filename} is not a supported wheel on this platform."
raise UnsupportedWheel(msg)
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/test_target_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def test_format_given(self, kwargs: Dict[str, Any], expected: str) -> None:
((3, 7, 3), "37"),
# Check a minor version with two digits.
((3, 10, 1), "310"),
# Check that versions=None is passed to get_tags().
# Check that versions=None is passed to get_sorted_tags().
(None, None),
],
)
@mock.patch("pip._internal.models.target_python.get_supported")
def test_get_tags(
def test_get_sorted_tags(
self,
mock_get_supported: mock.Mock,
py_version_info: Optional[Tuple[int, ...]],
Expand All @@ -102,7 +102,7 @@ def test_get_tags(
mock_get_supported.return_value = ["tag-1", "tag-2"]

target_python = TargetPython(py_version_info=py_version_info)
actual = target_python.get_tags()
actual = target_python.get_sorted_tags()
assert actual == ["tag-1", "tag-2"]

actual = mock_get_supported.call_args[1]["version"]
Expand All @@ -111,14 +111,14 @@ def test_get_tags(
# Check that the value was cached.
assert target_python._valid_tags == ["tag-1", "tag-2"]

def test_get_tags__uses_cached_value(self) -> None:
def test_get_unsorted_tags__uses_cached_value(self) -> None:
"""
Test that get_tags() uses the cached value.
Test that get_unsorted_tags() uses the cached value.
"""
target_python = TargetPython(py_version_info=None)
target_python._valid_tags = [
target_python._valid_tags_set = {
Tag("py2", "none", "any"),
Tag("py3", "none", "any"),
]
actual = target_python.get_tags()
assert actual == [Tag("py2", "none", "any"), Tag("py3", "none", "any")]
}
actual = target_python.get_unsorted_tags()
assert actual == {Tag("py2", "none", "any"), Tag("py3", "none", "any")}

0 comments on commit 901db9c

Please sign in to comment.