import os
import re
import abc
import csv
import sys
from .. import zipp
import email
import pathlib
import operator
import textwrap
import warnings
import functools
import itertools
import posixpath
import collections
from . import _adapters, _meta
from ._collections import FreezableDefaultDict, Pair
from ._compat import (
NullFinder,
install,
pypy_partial,
)
from ._functools import method_cache, pass_none
from ._itertools import always_iterable, unique_everseen
from ._meta import PackageMetadata, SimplePath
from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
from typing import List, Mapping, Optional, Union
# Names exported by ``from <this module> import *``; each one is defined
# (or imported) elsewhere in this module.
__all__ = [
'Distribution',
'DistributionFinder',
'PackageMetadata',
'PackageNotFoundError',
'distribution',
'distributions',
'entry_points',
'files',
'metadata',
'packages_distributions',
'requires',
'version',
]
class PackageNotFoundError(ModuleNotFoundError):
    """Raised when metadata for a requested package cannot be found."""

    @property
    def name(self):
        """The package name: the exception's single positional argument."""
        # Unpacking (rather than indexing) insists on exactly one argument.
        (only_arg,) = self.args
        return only_arg

    def __str__(self):
        return f"No package metadata was found for {self.name}"
class Sectioned:
    """Minimal parser for INI-like text split into ``[section]`` blocks.

    ``read`` yields one ``Pair(name, value)`` per non-header line, where
    ``name`` is the most recent section header (``None`` before the first
    header).  ``section_pairs`` additionally drops comment/blank lines and
    header-less lines, and parses each value with ``Pair.parse``.
    """

    # Example input in the format understood by this class.
    _sample = textwrap.dedent(
        """
        [sec1]
        # comments ignored
        a = 1
        b = 2
        [sec2]
        a = 2
        """
    ).lstrip()

    @classmethod
    def section_pairs(cls, text):
        """Yield each sectioned item of *text* with its value parsed.

        Items that appear before any section header are skipped.
        """
        for section in cls.read(text, filter_=cls.valid):
            if section.name is None:
                continue
            yield section._replace(value=Pair.parse(section.value))

    @staticmethod
    def read(text, filter_=None):
        """Yield ``Pair(section_name, line)`` for each kept line of *text*.

        Lines are stripped first; *filter_* is handed to the builtin
        ``filter`` to decide which stripped lines are kept (``None`` keeps
        every non-empty line).
        """
        stripped = (raw.strip() for raw in text.splitlines())
        current = None
        for line in filter(filter_, stripped):
            if line.startswith('[') and line.endswith(']'):
                # A header line only switches the current section.
                current = line.strip('[]')
            else:
                yield Pair(current, line)

    @staticmethod
    def valid(line):
        """Return a truthy value for lines that are neither empty nor comments."""
        if not line:
            # Preserve the falsy input itself as the result.
            return line
        return not line.startswith('#')
class DeprecatedTuple:
    """Mixin that keeps tuple-style indexing working, with a warning.

    Subclasses must provide ``_key()``; indexing the object indexes the
    value returned by ``_key()`` after emitting a ``DeprecationWarning``.
    """

    # Pre-bound warning call shared by all instances.
    _warn = functools.partial(
        warnings.warn,
        "EntryPoint tuple interface is deprecated. Access members by name.",
        DeprecationWarning,
        stacklevel=pypy_partial(2),
    )

    def __getitem__(self, item):
        self._warn()
        key = self._key()
        return key[item]
class EntryPoint(DeprecatedTuple):
    """An entry point described by a ``name``, a ``value`` and a ``group``.

    ``value`` is matched against ``pattern`` and is expected to look like
    ``module``, ``module:attr`` or ``module:attr [extra1, extra2]``.
    Instances are immutable (see ``__setattr__``) and hashable, and they
    order and compare by ``(name, value, group)``.
    """

    pattern = re.compile(
        r'(?P<module>[\w.]+)\s*'
        r'(:\s*(?P<attr>[\w.]+)\s*)?'
        r'((?P<extras>\[.*\])\s*)?$'
    )

    # Distribution this entry point was loaded from, if any (set by ``_for``).
    dist: Optional['Distribution'] = None

    def __init__(self, name, value, group):
        # Go through the instance dict because ``__setattr__`` is disabled.
        vars(self).update(name=name, value=value, group=group)

    def load(self):
        """Import the module named in ``value`` and return the target object.

        With no ``attr`` part the module itself is returned; otherwise the
        dotted attribute path is resolved on the module.
        """
        match = self.pattern.match(self.value)
        module = import_module(match.group('module'))
        attrs = filter(None, (match.group('attr') or '').split('.'))
        return functools.reduce(getattr, attrs, module)

    @property
    def module(self):
        """The module part of ``value``."""
        match = self.pattern.match(self.value)
        return match.group('module')

    @property
    def attr(self):
        """The attribute part of ``value``, or ``None`` when absent."""
        match = self.pattern.match(self.value)
        return match.group('attr')

    @property
    def extras(self):
        """The extras named in ``value`` as a list of strings.

        Returns an empty list when ``value`` has no ``[...]`` part.
        """
        match = self.pattern.match(self.value)
        # ``findall`` yields the matched strings; ``finditer`` would hand
        # back ``re.Match`` objects, which callers cannot use as names.
        return re.findall(r'\w+', match.group('extras') or '')

    def _for(self, dist):
        """Record *dist* as the owning distribution and return ``self``."""
        vars(self).update(dist=dist)
        return self

    def __iter__(self):
        """Yield ``(name, self)`` so ``dict(entry_points)`` still works.

        Deprecated: emits a ``DeprecationWarning``.
        """
        msg = (
            "Construction of dict of EntryPoints is deprecated in "
            "favor of EntryPoints."
        )
        warnings.warn(msg, DeprecationWarning)
        return iter((self.name, self))

    def matches(self, **params):
        """Return True if every keyword equals the same-named attribute.

        With no keywords the result is True.
        """
        attrs = (getattr(self, param) for param in params)
        return all(map(operator.eq, params.values(), attrs))

    def _key(self):
        """The tuple used for ordering, equality, hashing and indexing."""
        return self.name, self.value, self.group

    def __lt__(self, other):
        return self._key() < other._key()

    def __eq__(self, other):
        return self._key() == other._key()

    def __setattr__(self, name, value):
        raise AttributeError("EntryPoint objects are immutable.")

    def __repr__(self):
        return (
            f'EntryPoint(name={self.name!r}, value={self.value!r}, '
            f'group={self.group!r})'
        )

    def __hash__(self):
        return hash(self._key())
class DeprecatedList(list):
    """A ``list`` subclass whose list-specific behaviour is deprecated.

    The mutating methods named below, plus ``+`` and ``==`` against a
    non-tuple operand, still work but emit a ``DeprecationWarning`` first.
    """

    __slots__ = ()

    # Pre-bound warning call shared by all instances.
    _warn = functools.partial(
        warnings.warn,
        "EntryPoints list interface is deprecated. Cast to list if needed.",
        DeprecationWarning,
        stacklevel=pypy_partial(2),
    )

    def _wrap_deprecated_method(method_name: str):  # type: ignore
        # Class-body helper (not a method): builds a ``(name, function)``
        # pair whose function warns and then defers to the inherited
        # method of the same name.
        def wrapped(self, *args, **kwargs):
            self._warn()
            return getattr(super(), method_name)(*args, **kwargs)

        return method_name, wrapped

    # Inject one warning wrapper per listed method into the class namespace.
    locals().update(
        map(
            _wrap_deprecated_method,
            '__setitem__ __delitem__ append reverse extend pop remove '
            '__iadd__ insert sort'.split(),
        )
    )

    def __add__(self, other):
        """Concatenate, returning a new instance of this class."""
        if isinstance(other, tuple):
            rhs = other
        else:
            # Only non-tuple operands count as use of the list interface.
            self._warn()
            rhs = tuple(other)
        return self.__class__(tuple(self) + rhs)

    def __eq__(self, other):
        """Compare as tuples; non-tuple operands trigger the warning."""
        if isinstance(other, tuple):
            rhs = other
        else:
            self._warn()
            rhs = tuple(other)
        return tuple(self).__eq__(rhs)
class EntryPoints(DeprecatedList):
    """A collection of ``EntryPoint`` objects, retrievable by name."""

    __slots__ = ()

    def __getitem__(self, name):
        """Return the first entry point called *name*.

        Integer indexes are still honoured (with a ``DeprecationWarning``).

        :raises KeyError: if no entry point has that name.
        """
        if isinstance(name, int):
            warnings.warn(
                "Accessing entry points by index is deprecated. "
                "Cast to tuple if needed.",
                DeprecationWarning,
                stacklevel=2,
            )
            return super().__getitem__(name)
        candidates = iter(self.select(name=name))
        try:
            return next(candidates)
        except StopIteration:
            raise KeyError(name)

    def select(self, **params):
        """Return a new ``EntryPoints`` holding the items matching *params*.

        Matching is delegated to ``EntryPoint.matches``.
        """
        chosen = (ep for ep in self if ep.matches(**params))
        return EntryPoints(chosen)

    @property
    def names(self):
        """The set of names of all contained entry points."""
        return set(ep.name for ep in self)

    @property
    def groups(self):
        """The set of groups of all contained entry points."""
        return set(ep.group for ep in self)

    @classmethod
    def _from_text_for(cls, text, dist):
        """Parse *text* and bind every resulting entry point to *dist*."""
        parsed = cls._from_text(text)
        return cls(ep._for(dist) for ep in parsed)

    @staticmethod
    def _from_text(text):
        """Return a generator of ``EntryPoint`` objects parsed from *text*.

        A falsy *text* (e.g. ``None``) is treated as empty input.
        """
        return (
            EntryPoint(name=pair.value.name, value=pair.value.value, group=pair.name)
            for pair in Sectioned.section_pairs(text or '')
        )
class Deprecated:
    """Mixin that makes common ``dict`` accessors emit a deprecation warning.

    Each method warns and then delegates to the next class in the MRO, so
    this is meant to be listed before a mapping base class.
    """

    # Pre-bound warning call shared by all instances.
    _warn = functools.partial(
        warnings.warn,
        "SelectableGroups dict interface is deprecated. Use select.",
        DeprecationWarning,
        stacklevel=pypy_partial(2),
    )

    def __getitem__(self, name):
        self._warn()
        inherited = super().__getitem__
        return inherited(name)

    def get(self, name, default=None):
        self._warn()
        inherited = super().get
        return inherited(name, default)

    def __iter__(self):
        self._warn()
        inherited = super().__iter__
        return inherited()

    def __contains__(self, *args):
        self._warn()
        inherited = super().__contains__
        return inherited(*args)

    def keys(self):
        self._warn()
        inherited = super().keys
        return inherited()

    def values(self):
        self._warn()
        inherited = super().values
        return inherited()
class SelectableGroups(Deprecated, dict):
    """A dict of group name -> ``EntryPoints`` that also supports ``select``.

    The plain ``dict`` accessors inherited through ``Deprecated`` warn on
    use.
    """

    @classmethod
    def load(cls, eps):
        """Build an instance by grouping *eps* on their ``group`` attribute."""
        group_of = operator.attrgetter('group')
        # ``groupby`` only merges adjacent items, so sort by the same key.
        clustered = itertools.groupby(sorted(eps, key=group_of), group_of)
        return cls((label, EntryPoints(members)) for label, members in clustered)

    @property
    def _all(self):
        """All entry points of every group, flattened into one ``EntryPoints``."""
        # ``super(Deprecated, self)`` skips the warning wrapper on ``values``.
        per_group = super(Deprecated, self).values()
        return EntryPoints(itertools.chain.from_iterable(per_group))

    @property
    def groups(self):
        """The set of groups across all entry points."""
        return self._all.groups

    @property
    def names(self):
        """The set of names across all entry points."""
        return self._all.names

    def select(self, **params):
        """Return ``self`` when no params are given, else the matching items.

        With params the result is the ``EntryPoints`` selection over all
        groups.
        """
        return self._all.select(**params) if params else self
class PackagePath(pathlib.PurePosixPath):
    """A POSIX-style path to a file that belongs to a distribution.

    ``locate`` relies on a ``dist`` attribute being assigned to the
    instance after construction.
    """

    def read_text(self, encoding='utf-8'):
        """Return the located file's content decoded with *encoding*."""
        target = self.locate()
        with target.open(encoding=encoding) as handle:
            return handle.read()

    def read_binary(self):
        """Return the located file's content as bytes."""
        target = self.locate()
        with target.open('rb') as handle:
            return handle.read()

    def locate(self):
        """Return whatever ``self.dist.locate_file`` gives for this path."""
        owner = self.dist
        return owner.locate_file(self)
class FileHash:
    """A file hash given as ``mode=value``, split at the first ``=``."""

    def __init__(self, spec):
        # Without an ``=`` the whole spec becomes ``mode`` and ``value``
        # is the empty string.
        mode, _sep, value = spec.partition('=')
        self.mode = mode
        self.value = value

    def __repr__(self):
        return f'<FileHash mode: {self.mode} value: {self.value}>'
class Distribution:
    """A Python distribution package.

    Subclasses supply ``read_text`` and ``locate_file``; everything else
    (metadata, version, entry points, files, requirements) is derived from
    the text files that ``read_text`` returns.
    """

    @abc.abstractmethod
    def read_text(self, filename):
        """Attempt to load the metadata file called *filename*.

        :param filename: name of a file in the distribution's metadata.
        :return: the file's text, or a falsy value (callers in this class
            treat ``None``/empty as "not available").
        """

    @abc.abstractmethod
    def locate_file(self, path):
        """Return a path to the file in this distribution given by *path*."""

    @classmethod
    def from_name(cls, name):
        """Return the first distribution found for *name*.

        Every resolver from ``_discover_resolvers`` is asked in turn.

        :param name: the name of the distribution package to search for.
        :raises PackageNotFoundError: when no resolver yields a distribution.
        """
        for resolver in cls._discover_resolvers():
            dists = resolver(DistributionFinder.Context(name=name))
            dist = next(iter(dists), None)
            if dist is not None:
                return dist
        raise PackageNotFoundError(name)

    @classmethod
    def discover(cls, **kwargs):
        """Return an iterable over the distributions from every resolver.

        Pass either a ready-made ``context`` keyword or keyword arguments
        from which a ``DistributionFinder.Context`` is built, not both.

        :raises ValueError: if ``context`` is combined with other keywords.
        """
        context = kwargs.pop('context', None)
        if context and kwargs:
            raise ValueError("cannot accept context and kwargs")
        context = context or DistributionFinder.Context(**kwargs)
        return itertools.chain.from_iterable(
            resolver(context) for resolver in cls._discover_resolvers()
        )

    @staticmethod
    def at(path):
        """Return a ``PathDistribution`` for the metadata directory *path*."""
        return PathDistribution(pathlib.Path(path))

    @staticmethod
    def _discover_resolvers():
        """Return the ``find_distributions`` callables found on ``sys.meta_path``.

        Finders without that attribute are skipped.
        """
        declared = (
            getattr(finder, 'find_distributions', None) for finder in sys.meta_path
        )
        return filter(None, declared)

    @property
    def metadata(self) -> _meta.PackageMetadata:
        """The distribution metadata, parsed as an email-style message.

        The text comes from the first non-empty of ``METADATA``,
        ``PKG-INFO`` and the empty filename.
        """
        text = (
            self.read_text('METADATA')
            or self.read_text('PKG-INFO')
            # NOTE(review): the empty filename presumably covers layouts
            # where the metadata path is itself the metadata file — confirm.
            or self.read_text('')
        )
        return _adapters.Message(email.message_from_string(text))

    @property
    def name(self):
        """The ``Name`` metadata field."""
        return self.metadata['Name']

    @property
    def _normalized_name(self):
        """The distribution name passed through ``Prepared.normalize``."""
        return Prepared.normalize(self.name)

    @property
    def version(self):
        """The ``Version`` metadata field."""
        return self.metadata['Version']

    @property
    def entry_points(self):
        """The ``EntryPoints`` parsed from ``entry_points.txt``, bound to self."""
        return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self)

    @property
    def files(self):
        """The files of this distribution as a list of ``PackagePath``.

        Rows come from ``RECORD`` or, failing that, ``SOURCES.txt``.  Each
        path gets ``hash``, ``size`` and ``dist`` attributes.  When neither
        file yields anything the result is whatever ``pass_none`` produces
        for a ``None`` argument (presumably ``None`` — confirm in
        ``_functools``).
        """

        def make_file(name, hash=None, size_str=None):
            # One CSV row: path, optional ``mode=value`` hash, optional size.
            result = PackagePath(name)
            result.hash = FileHash(hash) if hash else None
            result.size = int(size_str) if size_str else None
            result.dist = self
            return result

        @pass_none
        def make_files(lines):
            return list(starmap(make_file, csv.reader(lines)))

        return make_files(self._read_files_distinfo() or self._read_files_egginfo())

    def _read_files_distinfo(self):
        """Return the lines of ``RECORD``, or a falsy value if unavailable."""
        text = self.read_text('RECORD')
        return text and text.splitlines()

    def _read_files_egginfo(self):
        """Return the lines of ``SOURCES.txt``, each wrapped in double quotes.

        The quoting makes ``csv.reader`` treat every line as a single
        field.  A falsy value is returned if the file is unavailable.
        """
        text = self.read_text('SOURCES.txt')
        return text and map('"{}"'.format, text.splitlines())

    @property
    def requires(self):
        """The requirements of this distribution as a list.

        Taken from the ``Requires-Dist`` metadata or, failing that, from
        ``requires.txt``; a falsy value is passed through unchanged.
        """
        reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
        return reqs and list(reqs)

    def _read_dist_info_reqs(self):
        """Return all ``Requires-Dist`` metadata values."""
        return self.metadata.get_all('Requires-Dist')

    def _read_egg_info_reqs(self):
        """Return requirements derived from ``requires.txt``, if present."""
        source = self.read_text('requires.txt')
        return pass_none(self._deps_from_requires_text)(source)

    @classmethod
    def _deps_from_requires_text(cls, source):
        """Convert the text of a ``requires.txt`` into requirement strings."""
        return cls._convert_egg_info_reqs_to_simple_reqs(Sectioned.read(source))

    @staticmethod
    def _convert_egg_info_reqs_to_simple_reqs(sections):
        """Yield one requirement string per item of *sections*.

        Each item has a ``value`` (the requirement) and a ``name`` (the
        section header, of the form ``extra``, ``extra:markers``,
        ``:markers`` or ``None``).  The header is turned into a
        ``; <conditions>`` suffix on the requirement.
        """

        def make_condition(name):
            return name and f'extra == "{name}"'

        def quoted_marker(section):
            section = section or ''
            extra, sep, markers = section.partition(':')
            if extra and markers:
                # Parenthesize so the markers bind before the ``and``.
                markers = f'({markers})'
            conditions = list(filter(None, [markers, make_condition(extra)]))
            return '; ' + ' and '.join(conditions) if conditions else ''

        def url_req_space(req):
            # A single space when the requirement contains ``@`` (a URL
            # requirement), so the URL is separated from the ``;`` marker.
            return ' ' * ('@' in req)

        for section in sections:
            space = url_req_space(section.value)
            yield section.value + space + quoted_marker(section.name)
class DistributionFinder(MetaPathFinder):
    """A ``MetaPathFinder`` that can also find distributions."""

    class Context:
        """Selection criteria handed to ``find_distributions``.

        Every keyword given to the constructor becomes available on the
        instance.  ``name`` defaults to ``None``; ``path`` defaults to
        ``sys.path`` (looked up at access time).
        """

        # ``None`` unless a ``name`` keyword is supplied.
        name = None

        def __init__(self, **kwargs):
            # Store via the instance dict; ``path`` is a read-only property.
            vars(self).update(kwargs)

        @property
        def path(self):
            """The ``path`` keyword if one was given, else ``sys.path``."""
            return vars(self).get('path', sys.path)

    @abc.abstractmethod
    def find_distributions(self, context=Context()):
        """Find distributions.

        Return an iterable of the distributions matching *context*, an
        instance of ``DistributionFinder.Context``.
        """
class FastPath:
    """A directory or zip file on the search path, with cached lookups."""

    @functools.lru_cache()
    def __new__(cls, root):
        # Cached on ``(cls, root)``: the same root yields the same object.
        return super().__new__(cls)

    def __init__(self, root):
        self.root = str(root)

    def joinpath(self, child):
        """Return ``root / child`` as a ``pathlib.Path``."""
        return pathlib.Path(self.root, child)

    def children(self):
        """Return the names directly under ``root``.

        Tries a directory listing first, then a zip listing, and falls
        back to an empty list; errors from either attempt are swallowed.
        """
        try:
            return os.listdir(self.root or '.')
        except Exception:
            pass
        try:
            return self.zip_children()
        except Exception:
            pass
        return []

    def zip_children(self):
        """Return the top-level names inside the zip file at ``root``.

        On success this also rebinds ``self.joinpath`` to the zip path's
        ``joinpath``.
        """
        archive = zipp.Path(self.root)
        members = archive.root.namelist()
        self.joinpath = archive.joinpath

        # ``dict.fromkeys`` de-duplicates while keeping first-seen order.
        return dict.fromkeys(
            member.split(posixpath.sep, 1)[0] for member in members
        )

    def search(self, name):
        """Search the ``Lookup`` for the current ``mtime`` for *name*."""
        return self.lookup(self.mtime).search(name)

    @property
    def mtime(self):
        """Modification time of ``root``.

        When ``os.stat`` fails, the ``lookup`` cache is cleared and
        ``None`` is returned.
        """
        try:
            return os.stat(self.root).st_mtime
        except OSError:
            pass
        self.lookup.cache_clear()

    @method_cache
    def lookup(self, mtime):
        """Return a ``Lookup`` for this path (cached via ``method_cache``)."""
        return Lookup(self)
class Lookup:
    """Index of the metadata directories found directly under a ``FastPath``.

    ``infos`` maps normalized names to ``*.dist-info``/``*.egg-info``
    paths; ``eggs`` maps legacy-normalized names to ``EGG-INFO`` paths
    found when the root itself is a ``*.egg``.
    """

    def __init__(self, path: FastPath):
        root_name = os.path.basename(path.root).lower()
        root_is_egg = root_name.endswith(".egg")
        self.infos = FreezableDefaultDict(list)
        self.eggs = FreezableDefaultDict(list)

        for entry in path.children():
            lowered = entry.lower()
            if lowered.endswith((".dist-info", ".egg-info")):
                # Strip the extension, then keep what precedes the first '-'.
                stem = lowered.rpartition(".")[0].partition("-")[0]
                key = Prepared.normalize(stem)
                self.infos[key].append(path.joinpath(entry))
            elif root_is_egg and lowered == "egg-info":
                # Here the name comes from the egg's own file name.
                stem = root_name.rpartition(".")[0].partition("-")[0]
                legacy_key = Prepared.legacy_normalize(stem)
                self.eggs[legacy_key].append(path.joinpath(entry))

        self.infos.freeze()
        self.eggs.freeze()

    def search(self, prepared):
        """Return an iterable of metadata paths for *prepared*.

        A falsy *prepared* (no name) selects every indexed path.
        """
        if prepared:
            infos = self.infos[prepared.normalized]
            eggs = self.eggs[prepared.legacy_normalized]
        else:
            infos = itertools.chain.from_iterable(self.infos.values())
            eggs = itertools.chain.from_iterable(self.eggs.values())
        return itertools.chain(infos, eggs)
class Prepared:
    """A package name together with its two normalized spellings.

    For a ``None`` name both normalized forms stay ``None``.  Instances
    are falsy when the name is ``None`` or empty.
    """

    normalized = None
    legacy_normalized = None

    def __init__(self, name):
        self.name = name
        if name is not None:
            self.normalized = self.normalize(name)
            self.legacy_normalized = self.legacy_normalize(name)

    @staticmethod
    def normalize(name):
        """Lower-case *name*, turning each run of ``-``, ``_``, ``.`` into ``_``."""
        collapsed = re.sub(r"[-_.]+", "-", name)
        return collapsed.lower().replace('-', '_')

    @staticmethod
    def legacy_normalize(name):
        """Lower-case *name* and replace every ``-`` with ``_``."""
        lowered = name.lower()
        return lowered.replace('-', '_')

    def __bool__(self):
        return bool(self.name)
@install
class MetadataPathFinder(NullFinder, DistributionFinder):
    """A distribution finder that searches file-system (and zip) paths.

    NOTE(review): ``install`` and ``NullFinder`` come from ``._compat``;
    ``install`` presumably registers the finder on ``sys.meta_path`` —
    confirm there.
    """

    def find_distributions(self, context=DistributionFinder.Context()):
        """Return an iterable of ``PathDistribution`` matching *context*.

        ``context.name`` selects the distribution (``None`` matches all)
        and ``context.path`` lists the locations to search.
        """
        found = self._search_paths(context.name, context.path)
        return map(PathDistribution, found)

    @classmethod
    def _search_paths(cls, name, paths):
        """Return the metadata paths matching *name* across all *paths*."""
        prepared = Prepared(name)
        return itertools.chain.from_iterable(
            path.search(prepared) for path in map(FastPath, paths)
        )

    @classmethod
    def invalidate_caches(cls):
        """Drop the cached ``FastPath`` instances.

        Declared as a classmethod so it works both on the class and on an
        instance.
        """
        FastPath.__new__.cache_clear()
class PathDistribution(Distribution):
    """A ``Distribution`` backed by a metadata directory (or file) path."""

    def __init__(self, path: SimplePath):
        """Construct a distribution.

        :param path: an object providing ``joinpath`` and ``parent``
            (see ``SimplePath``) that points at the metadata location.
        """
        self._path = path

    def read_text(self, filename):
        # Missing or unreadable files are reported as ``None`` rather
        # than as exceptions; callers rely on the falsy result.
        with suppress(
            FileNotFoundError,
            IsADirectoryError,
            KeyError,
            NotADirectoryError,
            PermissionError,
        ):
            return self._path.joinpath(filename).read_text(encoding='utf-8')

    # Reuse the documentation of the abstract method.
    read_text.__doc__ = Distribution.read_text.__doc__

    def locate_file(self, path):
        """Return *path* resolved against the parent of the metadata path."""
        return self._path.parent / path

    @property
    def _normalized_name(self):
        """The normalized name, taken from the file name when possible.

        Reading the name off the path avoids loading the metadata; if the
        file name does not provide one, the metadata-based value is used.
        The stem-derived name is normalized so that it agrees with
        ``Distribution._normalized_name``.
        """
        stem = os.path.basename(str(self._path))
        name = self._name_from_stem(stem)
        if name:
            return Prepared.normalize(name)
        return super()._normalized_name

    def _name_from_stem(self, stem):
        """Return the project name encoded in a metadata directory name.

        ``'foo-3.0.egg-info'`` and ``'foo.egg-info'`` both give ``'foo'``.
        ``None`` is returned when *stem* does not end in ``.dist-info``
        or ``.egg-info``.
        """
        filename, ext = os.path.splitext(stem)
        if ext not in ('.dist-info', '.egg-info'):
            return
        # Partition the extension-less name; partitioning the full stem
        # would leak part of the extension for unversioned names.
        name, sep, rest = filename.partition('-')
        return name
def distribution(distribution_name):
    """Return the ``Distribution`` for the named package.

    :param distribution_name: the name of the distribution package.
    :return: the result of ``Distribution.from_name``.
    """
    found = Distribution.from_name(distribution_name)
    return found
def distributions(**kwargs):
    """Return an iterable of all discovered ``Distribution`` instances.

    Keyword arguments are forwarded unchanged to ``Distribution.discover``.
    """
    discovered = Distribution.discover(**kwargs)
    return discovered
def metadata(distribution_name) -> _meta.PackageMetadata:
    """Return the metadata of the named package.

    :param distribution_name: the name of the distribution package.
    :return: the ``metadata`` of the matching ``Distribution``.
    """
    dist = Distribution.from_name(distribution_name)
    return dist.metadata
def version(distribution_name):
    """Return the version of the named package.

    :param distribution_name: the name of the distribution package.
    :return: the ``version`` of the matching distribution.
    """
    dist = distribution(distribution_name)
    return dist.version
def entry_points(**params) -> Union[EntryPoints, SelectableGroups]:
    """Return entry points from all discovered distributions.

    Distributions sharing a ``_normalized_name`` are de-duplicated
    (via ``unique_everseen``) before their entry points are gathered.

    With no *params* the result is the ``SelectableGroups`` itself;
    otherwise it is the ``EntryPoints`` selection matching *params*.
    """
    deduped = unique_everseen(
        distributions(), key=operator.attrgetter('_normalized_name')
    )
    gathered = itertools.chain.from_iterable(
        dist.entry_points for dist in deduped
    )
    groups = SelectableGroups.load(gathered)
    return groups.select(**params)
def files(distribution_name):
    """Return the files of the named package.

    :param distribution_name: the name of the distribution package.
    :return: the ``files`` of the matching distribution.
    """
    dist = distribution(distribution_name)
    return dist.files
def requires(distribution_name):
    """Return the requirements of the named package.

    :param distribution_name: the name of the distribution package.
    :return: the ``requires`` of the matching distribution.
    """
    dist = distribution(distribution_name)
    return dist.requires
def packages_distributions() -> Mapping[str, List[str]]:
    """Return a mapping of top-level package name to distribution names.

    For each distribution the top-level names come from ``top_level.txt``
    when it lists any, otherwise they are inferred from its files.
    """
    mapping = {}
    for dist in distributions():
        top_levels = _top_level_declared(dist) or _top_level_inferred(dist)
        for pkg in top_levels:
            mapping.setdefault(pkg, []).append(dist.metadata['Name'])
    return mapping
def _top_level_declared(dist):
    """Return the whitespace-separated names listed in ``top_level.txt``.

    An absent or empty file gives an empty list.
    """
    declared = dist.read_text('top_level.txt')
    return declared.split() if declared else []
def _top_level_inferred(dist):
    """Infer top-level names from the ``.py`` files of *dist*.

    A file inside a directory contributes its first path component; a
    file at the top level contributes its name without the suffix.
    """
    inferred = set()
    for entry in always_iterable(dist.files):
        if entry.suffix != ".py":
            continue
        if len(entry.parts) > 1:
            inferred.add(entry.parts[0])
        else:
            inferred.add(entry.with_suffix('').name)
    return inferred