import hashlib
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from pip._vendor.packaging.tags import Tag, interpreter_name, interpreter_version
from pip._vendor.packaging.utils import canonicalize_name
from pip._internal.exceptions import InvalidWheelFilename
from pip._internal.models.direct_url import DirectUrl
from pip._internal.models.link import Link
from pip._internal.models.wheel import Wheel
from pip._internal.utils.temp_dir import TempDirectory, tempdir_kinds
from pip._internal.utils.urls import path_to_url
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Name of the JSON file that stores a cache entry's download origin. It is
# looked up in the directory containing a cached file (CacheEntry) and written
# into a cache entry directory (WheelCache.record_download_origin).
ORIGIN_JSON_NAME = "origin.json"
def _hash_dict(d: Dict[str, str]) -> str:
    """Return a stable SHA-224 hex digest for a dict of strings.

    The dict is serialized to compact JSON with sorted keys and ASCII-only
    output, so two dicts with equal contents produce the same digest
    regardless of insertion order.
    """
    canonical = json.dumps(
        d,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=True,
    )
    digest = hashlib.sha224(canonical.encode("ascii"))
    return digest.hexdigest()
class Cache:
    """Base class that maps links to directories under a cache root.

    Subclasses must implement :meth:`get_path_for_link` and :meth:`get`.

    :param cache_dir: Root of the cache. Must be an absolute path when
        non-empty; an empty value is stored as ``None``.
    """

    def __init__(self, cache_dir: str) -> None:
        super().__init__()
        if cache_dir:
            assert os.path.isabs(cache_dir)
        self.cache_dir = cache_dir if cache_dir else None

    def _get_cache_path_parts(self, link: Link) -> List[str]:
        """Return the nested directory names that identify ``link``.

        The key is built from the link URL without its fragment, the link
        hash (when both its name and value are set), the subdirectory
        fragment (when present), and the running interpreter's name and
        version. The key is hashed and the hex digest is split into three
        two-character levels followed by the remainder.
        """
        key = {"url": link.url_without_fragment}
        if link.hash_name is not None and link.hash is not None:
            key[link.hash_name] = link.hash
        if link.subdirectory_fragment:
            key["subdirectory"] = link.subdirectory_fragment
        key["interpreter_name"] = interpreter_name()
        key["interpreter_version"] = interpreter_version()

        digest = _hash_dict(key)
        # Three short prefix levels, then whatever is left of the digest.
        nested = [digest[start : start + 2] for start in (0, 2, 4)]
        nested.append(digest[6:])
        return nested

    def _get_candidates(self, link: Link, canonical_package_name: str) -> List[Any]:
        """List ``(entry_name, directory)`` pairs found for ``link``.

        Returns an empty list when there is no cache directory, no package
        name, a falsy link, or when the link's cache path is not a directory.
        """
        if not (self.cache_dir and canonical_package_name and link):
            return []

        directory = self.get_path_for_link(link)
        if not os.path.isdir(directory):
            return []
        return [(entry, directory) for entry in os.listdir(directory)]

    def get_path_for_link(self, link: Link) -> str:
        """Return the cache directory for ``link``; subclasses implement this."""
        raise NotImplementedError()

    def get(
        self,
        link: Link,
        package_name: Optional[str],
        supported_tags: List[Tag],
    ) -> Link:
        """Return the link to use for ``link``; subclasses implement this."""
        raise NotImplementedError()
class SimpleWheelCache(Cache):
    """Cache that keeps wheels under ``<cache_dir>/wheels/<hashed parts>``."""

    def __init__(self, cache_dir: str) -> None:
        super().__init__(cache_dir)

    def get_path_for_link(self, link: Link) -> str:
        """Return the directory in which wheels for ``link`` are stored.

        Asserts that a cache directory is configured.
        """
        hashed_parts = self._get_cache_path_parts(link)
        assert self.cache_dir
        wheels_root = os.path.join(self.cache_dir, "wheels")
        return os.path.join(wheels_root, *hashed_parts)

    def get(
        self,
        link: Link,
        package_name: Optional[str],
        supported_tags: List[Tag],
    ) -> Link:
        """Return a link to the best cached wheel, or ``link`` itself.

        A cached file is considered only if its name parses as a wheel, its
        canonicalized distribution name equals the canonicalized
        ``package_name``, and it supports one of ``supported_tags``. Among
        those, the smallest ``(support index, file name, directory)`` tuple
        wins. ``link`` is returned unchanged when ``package_name`` is empty
        or nothing qualifies.
        """
        if not package_name:
            return link

        expected_name = canonicalize_name(package_name)
        ranked = []
        for filename, directory in self._get_candidates(link, expected_name):
            try:
                wheel = Wheel(filename)
            except InvalidWheelFilename:
                # Not a wheel file name; skip it.
                continue

            if canonicalize_name(wheel.name) != expected_name:
                logger.debug(
                    "Ignoring cached wheel %s for %s as it "
                    "does not match the expected distribution name %s.",
                    filename,
                    link,
                    package_name,
                )
            elif wheel.supported(supported_tags):
                priority = wheel.support_index_min(supported_tags)
                ranked.append((priority, filename, directory))

        if not ranked:
            return link

        _, best_name, best_dir = min(ranked)
        return Link(path_to_url(os.path.join(best_dir, best_name)))
class EphemWheelCache(SimpleWheelCache):
    """A SimpleWheelCache rooted in a globally managed temporary directory."""

    def __init__(self) -> None:
        temp_dir = TempDirectory(
            kind=tempdir_kinds.EPHEM_WHEEL_CACHE,
            globally_managed=True,
        )
        # Keep a reference to the temporary directory on the instance.
        self._temp_dir = temp_dir
        super().__init__(temp_dir.path)
class CacheEntry:
    """A link returned by a cache lookup, with its recorded origin if any.

    :param link: Link to the cached file.
    :param persistent: Flag stored as given on the entry.

    ``origin`` is loaded from the ``ORIGIN_JSON_NAME`` file that sits next to
    the linked file. It stays ``None`` when that file does not exist or
    cannot be parsed; a parse failure is logged as a warning, not raised.
    """

    def __init__(
        self,
        link: Link,
        persistent: bool,
    ):
        self.link = link
        self.persistent = persistent
        self.origin: Optional[DirectUrl] = None

        origin_file = Path(link.file_path).parent / ORIGIN_JSON_NAME
        if not origin_file.exists():
            return

        try:
            parsed = DirectUrl.from_json(origin_file.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(
                "Ignoring invalid cache entry origin file %s for %s (%s)",
                origin_file,
                link.filename,
                exc,
            )
        else:
            self.origin = parsed
class WheelCache(Cache):
    """Combine a SimpleWheelCache under ``cache_dir`` with an EphemWheelCache.

    Lookups try the SimpleWheelCache first and fall back to the
    EphemWheelCache.
    """

    def __init__(self, cache_dir: str) -> None:
        super().__init__(cache_dir)
        self._wheel_cache = SimpleWheelCache(cache_dir)
        self._ephem_cache = EphemWheelCache()

    def get_path_for_link(self, link: Link) -> str:
        """Return the SimpleWheelCache directory for ``link``."""
        return self._wheel_cache.get_path_for_link(link)

    def get_ephem_path_for_link(self, link: Link) -> str:
        """Return the EphemWheelCache directory for ``link``."""
        return self._ephem_cache.get_path_for_link(link)

    def get(
        self,
        link: Link,
        package_name: Optional[str],
        supported_tags: List[Tag],
    ) -> Link:
        """Return the cached link for ``link``, or ``link`` if none is found."""
        entry = self.get_cache_entry(link, package_name, supported_tags)
        return link if entry is None else entry.link

    def get_cache_entry(
        self,
        link: Link,
        package_name: Optional[str],
        supported_tags: List[Tag],
    ) -> Optional[CacheEntry]:
        """Return a CacheEntry for ``link``, or ``None`` if nothing is cached.

        A hit in the SimpleWheelCache gives ``persistent=True``; a hit in the
        EphemWheelCache gives ``persistent=False``. A lookup counts as a hit
        when the cache returns an object other than ``link`` itself.
        """
        lookups = (
            (self._wheel_cache, True),
            (self._ephem_cache, False),
        )
        for cache, persistent in lookups:
            found = cache.get(
                link=link,
                package_name=package_name,
                supported_tags=supported_tags,
            )
            if found is not link:
                return CacheEntry(found, persistent=persistent)
        return None

    @staticmethod
    def record_download_origin(cache_dir: str, download_info: DirectUrl) -> None:
        """Write ``download_info`` as the origin file inside ``cache_dir``.

        If an origin file already exists, it is read first. A warning is
        logged when it cannot be parsed or when its URL differs from
        ``download_info.url``. In every case the file is then (over)written
        with ``download_info`` serialized as JSON.
        """
        origin_file = Path(cache_dir) / ORIGIN_JSON_NAME
        if origin_file.exists():
            try:
                recorded = DirectUrl.from_json(
                    origin_file.read_text(encoding="utf-8")
                )
            except Exception as exc:
                logger.warning(
                    "Could not read origin file %s in cache entry (%s). "
                    "Will attempt to overwrite it.",
                    origin_file,
                    exc,
                )
            else:
                if recorded.url != download_info.url:
                    logger.warning(
                        "Origin URL %s in cache entry %s does not match download URL "
                        "%s. This is likely a pip bug or a cache corruption issue. "
                        "Will overwrite it with the new value.",
                        recorded.url,
                        cache_dir,
                        download_info.url,
                    )
        serialized = download_info.to_json()
        origin_file.write_text(serialized, encoding="utf-8")