import argparse
import collections
import logging
import re
import subprocess
import sys
from typing import Generator, List, NamedTuple, Iterable
# This script uses Python 3.6+ syntax (e.g. the f-strings further down), so
# state the minimum supported interpreter version explicitly.
assert sys.version_info >= (3, 6), 'Only Python 3.6+ is supported.'
def _try_parse_reverts_from_commit_message(commit_message: str) -> List[str]:
  """Extracts the SHAs that `commit_message` claims to revert.

  Two forms are recognized:
    - any `This reverts commit <40 hex chars>` anywhere in the message, and
    - a first line starting with `Revert <6+ hex chars> "`.

  Args:
    commit_message: The full text of a commit message; may be empty.

  Returns:
    The SHAs found, in that order (message-wide matches first, then the one
    from the first line, if any). SHAs from the first line may be abbreviated.
    Duplicates are not removed.
  """
  if not commit_message:
    return []

  body_shas = re.findall(r'This reverts commit ([a-f0-9]{40})\b',
                         commit_message)

  subject = commit_message.splitlines()[0]
  subject_match = re.match(r'Revert ([a-f0-9]{6,}) "', subject)
  if subject_match is None:
    return body_shas
  return body_shas + [subject_match.group(1)]
def _stream_stdout(command: List[str]) -> Generator[str, None, None]:
  """Runs `command` and lazily yields its stdout, one line at a time.

  Output is decoded as UTF-8, with undecodable bytes replaced rather than
  raising. Yielded lines keep their trailing newline. The process is started
  when the generator is first advanced, and the `with` block cleans it up once
  the generator finishes or is closed.

  Note that the command's exit status is not checked here.

  Args:
    command: The argv of the process to run.

  Yields:
    Each line the process writes to stdout.
  """
  with subprocess.Popen(
      command, stdout=subprocess.PIPE, encoding='utf-8', errors='replace') as p:
    # stdout is always set when stdout=PIPE is passed; this assert only
    # narrows the Optional type for static checkers.
    assert p.stdout is not None
    yield from p.stdout
def _resolve_sha(git_dir: str, sha: str) -> str:
  """Expands `sha` into whatever `git rev-parse` reports for it.

  A `sha` that is already 40 characters long is returned as-is, without
  invoking git (and without checking that it names a real object).

  Args:
    git_dir: Directory passed to `git -C`.
    sha: A possibly-abbreviated SHA.

  Returns:
    `sha` itself if it has 40 characters; otherwise the stripped output of
    `git rev-parse`.

  Raises:
    subprocess.CalledProcessError: if `git rev-parse` exits unsuccessfully.
  """
  if len(sha) != 40:
    rev_parse_cmd = ['git', '-C', git_dir, 'rev-parse', sha]
    # git's stderr is discarded; failures still surface as an exception.
    resolved = subprocess.check_output(
        rev_parse_cmd,
        encoding='utf-8',
        stderr=subprocess.DEVNULL,
    )
    return resolved.strip()
  return sha
class _LogEntry(NamedTuple):
  """A single commit: its SHA paired with its commit message."""
  sha: str
  commit_message: str
def _log_stream(git_dir: str, root_sha: str,
                end_at_sha: str) -> Iterable[_LogEntry]:
  """Lazily yields the commits `git log` reports for `^end_at_sha root_sha`.

  `git log` is asked to print, for every commit, a separator line, then the
  commit hash (`%H`) on its own line, then the raw commit message (`%B`).
  This function splits that stream back into one entry per commit.

  Args:
    git_dir: Directory passed to `git -C`.
    root_sha: The revision to start logging from.
    end_at_sha: The revision to exclude (passed to git as `^end_at_sha`).

  Yields:
    One `_LogEntry` per commit, in the order `git log` printed them. Trailing
    whitespace is stripped from every message line and from the message as a
    whole.
  """
  # A line that is very unlikely to occur inside a real commit message.
  sep = 50 * '<>'
  log_command = [
      'git',
      '-C',
      git_dir,
      'log',
      '^' + end_at_sha,
      root_sha,
      '--format=' + sep + '%n%H%n%B%n',
  ]

  stdout_stream = iter(_stream_stdout(log_command))

  # Skip ahead to the first separator. If git printed nothing (no commits in
  # the range), there is none and the loop below never runs.
  found_commit_header = False
  for line in stdout_stream:
    if line.rstrip() == sep:
      found_commit_header = True
      break

  while found_commit_header:
    # The line right after a separator is the commit's hash.
    sha = next(stdout_stream, None)
    assert sha is not None, 'git died?'
    sha = sha.rstrip()

    # Everything up to the next separator (or end of output) is the message.
    commit_message = []
    found_commit_header = False
    for line in stdout_stream:
      line = line.rstrip()
      if line == sep:
        found_commit_header = True
        break
      commit_message.append(line)

    yield _LogEntry(sha, '\n'.join(commit_message).rstrip())
def _shas_between(git_dir: str, base_ref: str, head_ref: str) -> Iterable[str]:
  """Lazily lists `git rev-list --first-parent base_ref..head_ref`.

  Args:
    git_dir: Directory passed to `git -C`.
    base_ref: Left-hand side of the `..` range given to git.
    head_ref: Right-hand side of the `..` range given to git.

  Returns:
    A lazy iterable over the lines git prints, each stripped of surrounding
    whitespace.
  """
  commit_range = f'{base_ref}..{head_ref}'
  rev_list_cmd = ['git', '-C', git_dir, 'rev-list', '--first-parent']
  rev_list_cmd.append(commit_range)
  return (line.strip() for line in _stream_stdout(rev_list_cmd))
def _rev_parse(git_dir: str, ref: str) -> str:
  """Returns the stripped output of `git rev-parse ref` run in `git_dir`.

  Raises:
    subprocess.CalledProcessError: if git exits unsuccessfully.
  """
  rev_parse_cmd = ['git', '-C', git_dir, 'rev-parse', ref]
  raw_output = subprocess.check_output(rev_parse_cmd, encoding='utf-8')
  return raw_output.strip()
class Revert(NamedTuple):
  """A revert claim: commit `sha` says it reverts commit `reverted_sha`."""
  sha: str
  reverted_sha: str
def _find_common_parent_commit(git_dir: str, ref_a: str, ref_b: str) -> str:
  """Returns the stripped output of `git merge-base ref_a ref_b`.

  Args:
    git_dir: Directory passed to `git -C`.
    ref_a: First revision handed to `git merge-base`.
    ref_b: Second revision handed to `git merge-base`.

  Raises:
    subprocess.CalledProcessError: if git exits unsuccessfully.
  """
  merge_base_cmd = ['git', '-C', git_dir, 'merge-base', ref_a, ref_b]
  raw_output = subprocess.check_output(merge_base_cmd, encoding='utf-8')
  return raw_output.strip()
def find_reverts(git_dir: str, across_ref: str, root: str) -> List[Revert]:
  """Finds revert claims in `across_ref..root` that reach outside that range.

  Every commit that `git log` reports as reachable from `root` but not from
  `across_ref` is scanned for revert claims. A claim is kept only if the
  reverted SHA is not one of the first-parent commits between `across_ref`
  and `root`, and the reverted SHA names a commit object in the repository.

  Args:
    git_dir: Directory passed to `git -C`.
    across_ref: The ref or SHA to look for reverts across. Its merge-base with
      `root` must be `across_ref` itself.
    root: The ref or SHA to start searching from.

  Returns:
    The `Revert`s found, in the reverse of the order `git log` printed the
    reverting commits.

  Raises:
    ValueError: if the merge-base of `across_ref` and `root` is not
      `across_ref`.
    subprocess.CalledProcessError: if one of the underlying `git rev-parse` /
      `git merge-base` invocations fails (including resolving an abbreviated
      SHA taken from a commit message).
  """
  across_sha = _rev_parse(git_dir, across_ref)
  root_sha = _rev_parse(git_dir, root)

  common_ancestor = _find_common_parent_commit(git_dir, across_sha, root_sha)
  if common_ancestor != across_sha:
    # Both halves must be f-strings, or `{common_ancestor}` is emitted
    # literally instead of being interpolated.
    raise ValueError(f"{across_sha} isn't an ancestor of {root_sha} "
                     f'(common ancestor: {common_ancestor})')

  intermediate_commits = set(_shas_between(git_dir, across_sha, root_sha))
  assert across_sha not in intermediate_commits

  logging.debug('%d commits appear between %s and %s',
                len(intermediate_commits), across_sha, root_sha)

  all_reverts = []
  for sha, commit_message in _log_stream(git_dir, root_sha, across_sha):
    reverts = _try_parse_reverts_from_commit_message(commit_message)
    if not reverts:
      continue

    # De-duplicate (a message may name the same commit in several ways) and
    # sort so the processing order is deterministic.
    resolved_reverts = sorted(set(_resolve_sha(git_dir, x) for x in reverts))
    for reverted_sha in resolved_reverts:
      if reverted_sha in intermediate_commits:
        logging.debug('Commit %s reverts %s, which happened after %s', sha,
                      reverted_sha, across_sha)
        continue

      # The SHA came from free-form text, so check that it names a real
      # object, and that the object is a commit.
      try:
        object_type = subprocess.check_output(
            ['git', '-C', git_dir, 'cat-file', '-t', reverted_sha],
            encoding='utf-8',
            stderr=subprocess.DEVNULL,
        ).strip()
      except subprocess.CalledProcessError:
        logging.warning(
            'Failed to resolve reverted object %s (claimed to be reverted '
            'by sha %s)', reverted_sha, sha)
        continue

      if object_type == 'commit':
        all_reverts.append(Revert(sha, reverted_sha))
        continue

      # NOTE(review): the second placeholder receives the object type and the
      # third the SHA -- confirm this ordering is the intended wording.
      logging.error("%s claims to revert %s -- which isn't a commit -- %s", sha,
                    object_type, reverted_sha)

  # `all_reverts` was filled in `git log` order; flip it once, after the whole
  # log has been processed, so the result is in the opposite order.
  all_reverts.reverse()
  return all_reverts
def _main() -> None:
  """Command-line entry point: prints the reverts found for each given root.

  Parses the command line, configures logging, runs `find_reverts` once per
  `root` argument against the shared `base_ref`, and prints one line per
  distinct revert. With `--review_url`, SHAs are printed as
  `https://reviews.llvm.org/rG<sha>` URLs instead of bare SHAs.
  """
  # The module docstring doubles as the --help description.
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
  parser.add_argument(
      'base_ref', help='Git ref or sha to check for reverts around.')
  parser.add_argument(
      '-C', '--git_dir', default='.', help='Git directory to use.')
  parser.add_argument(
      'root', nargs='+', help='Root(s) to search for commits from.')
  parser.add_argument('--debug', action='store_true')
  parser.add_argument(
      '-u', '--review_url', action='store_true',
      help='Format SHAs as llvm review URLs')
  opts = parser.parse_args()

  logging.basicConfig(
      format='%(asctime)s: %(levelname)s: %(filename)s:%(lineno)d: %(message)s',
      level=logging.DEBUG if opts.debug else logging.INFO,
  )

  # Different roots can report the same revert. `seen_reverts` gives fast
  # membership checks while `all_reverts` preserves first-seen order.
  seen_reverts = set()
  all_reverts = []
  for root in opts.root:
    for revert in find_reverts(opts.git_dir, opts.base_ref, root):
      if revert not in seen_reverts:
        seen_reverts.add(revert)
        all_reverts.append(revert)

  for revert in all_reverts:
    sha_fmt = (f'https://reviews.llvm.org/rG{revert.sha}'
               if opts.review_url else revert.sha)
    reverted_sha_fmt = (f'https://reviews.llvm.org/rG{revert.reverted_sha}'
                        if opts.review_url else revert.reverted_sha)
    print(f'{sha_fmt} claims to revert {reverted_sha_fmt}')
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
  _main()