• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Git helper functions."""
16
17import os
18import re
19import sys
20
# Prepend the directory two levels above this file (presumably the one that
# contains the "rh" package -- confirm against the checkout layout) to the
# import path, so the "import rh.utils" that follows resolves.  The temporary
# name is removed afterwards so it does not leak into the module namespace.
_path = os.path.realpath(__file__ + '/../..')
if _path != sys.path[0]:
    sys.path[0:0] = [_path]
del _path
25
26# pylint: disable=wrong-import-position
27import rh.utils
28
29
def get_upstream_remote():
    """Returns the remote name configured for the checked-out branch.

    The value is read from the branch.<current branch>.remote git config
    option.
    """
    # Short name of whatever HEAD points at.
    head_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
    branch_name = rh.utils.run(head_cmd, capture_output=True).stdout.strip()

    # Remote associated with that branch in git config.
    remote_cmd = ['git', 'config', 'branch.%s.remote' % branch_name]
    return rh.utils.run(remote_cmd, capture_output=True).stdout.strip()
41
42
def get_upstream_branch():
    """Returns the upstream tracking branch of the current branch.

    The upstream is built from the branch.<name>.merge and
    branch.<name>.remote git config options. A merge ref under refs/heads/ is
    mapped to the matching remote-tracking ref. For example, merge
    refs/heads/main with remote "origin" becomes refs/remotes/origin/main.

    Returns:
      The upstream ref. It is a refs/remotes/<remote>/... ref when the merge
      ref lives under refs/heads/.

    Raises:
      ValueError: if HEAD is not on a branch, or the branch has no merge ref
        or remote configured.
    """
    heads_prefix = 'refs/heads/'

    cmd = ['git', 'symbolic-ref', 'HEAD']
    result = rh.utils.run(cmd, capture_output=True)
    current_branch = result.stdout.strip()
    # Strip only the leading namespace, not every occurrence of it.
    if current_branch.startswith(heads_prefix):
        current_branch = current_branch[len(heads_prefix):]
    if not current_branch:
        raise ValueError('Need to be on a tracking branch')

    # Build option names by concatenation. Branch names may legally contain
    # '%', which would break %-formatting of the option name.
    cfg_option = 'branch.' + current_branch + '.'
    cmd = ['git', 'config', cfg_option + 'merge']
    result = rh.utils.run(cmd, capture_output=True)
    full_upstream = result.stdout.strip()
    cmd = ['git', 'config', cfg_option + 'remote']
    result = rh.utils.run(cmd, capture_output=True)
    remote = result.stdout.strip()
    # Validate before qualifying the merge ref. Qualifying an empty value
    # would produce the non-empty 'refs/heads/' and hide the missing config.
    if not remote or not full_upstream:
        raise ValueError('Need to be on a tracking branch')

    # If the merge ref is not fully qualified, add an implicit namespace.
    if not full_upstream.startswith('refs/'):
        full_upstream = heads_prefix + full_upstream

    # Map only the leading refs/heads/ namespace. Replacing every 'heads'
    # substring would corrupt branch names such as 'fix-heads-up'.
    if full_upstream.startswith(heads_prefix):
        return ('refs/remotes/%s/' % remote) + full_upstream[len(heads_prefix):]
    return full_upstream
69
70
def get_commit_for_ref(ref):
    """Returns the commit that |ref| currently resolves to via git rev-parse."""
    output = rh.utils.run(['git', 'rev-parse', ref], capture_output=True).stdout
    return output.strip()
76
77
def get_remote_revision(ref, remote):
    """Returns |ref| with any leading refs/remotes/<remote>/ namespace removed.

    Refs that do not live under that remote's namespace are returned as-is.
    """
    remote_prefix = 'refs/remotes/%s/' % remote
    return ref[len(remote_prefix):] if ref.startswith(remote_prefix) else ref
84
85
def get_patch(commit):
    """Returns |commit| as patch text, as produced by git format-patch --stdout -1."""
    patch_cmd = ['git', 'format-patch', '--stdout', '-1', commit]
    completed = rh.utils.run(patch_cmd, capture_output=True)
    return completed.stdout
90
91
def get_file_content(commit, path):
    """Returns the content of |path| as recorded in |commit|.

    The working tree cannot be trusted here. A series of changes being
    uploaded may touch the same file several times, so the blob is read from
    the commit itself via `git show <commit>:<path>`.

    Note: for a symlink the "content" is only the link target. Callers that
    need a real file should check for this first. One heuristic is that a
    symlink's content contains no newlines.
    """
    object_spec = '%s:%s' % (commit, path)
    completed = rh.utils.run(['git', 'show', object_spec], capture_output=True)
    return completed.stdout
104
105
class RawDiffEntry(object):
    """One parsed record of raw formatted git diff output.

    Each attribute stores the constructor argument of the same name as-is.
    No conversion or validation happens in this class.
    """

    # pylint: disable=redefined-builtin
    def __init__(self, src_mode=0, dst_mode=0, src_sha=None, dst_sha=None,
                 status=None, score=None, src_file=None, dst_file=None,
                 file=None):
        # File modes of the two sides of the diff.
        self.src_mode, self.dst_mode = src_mode, dst_mode
        # Object ids of the two sides of the diff.
        self.src_sha, self.dst_sha = src_sha, dst_sha
        # Status letter and its optional score.
        self.status, self.score = status, score
        # Paths of the two sides, plus the single effective path.
        self.src_file, self.dst_file, self.file = src_file, dst_file, file
122
123
# Pattern that splits one line of raw formatted git diff output into named
# fields, e.g.
#   :100644 100644 <src sha> <dst sha> R086<TAB>old/path<TAB>new/path
# It is assembled from commented pieces; the joined pattern is unchanged.
DIFF_RE = re.compile(''.join([
    # Leading colon, then the source and destination modes (octal digits).
    r':(?P<src_mode>[0-7]*) (?P<dst_mode>[0-7]*) ',
    # Source and destination object ids, each possibly followed by dots.
    r'(?P<src_sha>[0-9a-f]*)(\.)* (?P<dst_sha>[0-9a-f]*)(\.)* ',
    # One status letter with an optional numeric score, then a tab.
    r'(?P<status>[ACDMRTUX])(?P<score>[0-9]+)?\t',
    # Source path, optionally followed by a tab and a destination path.
    r'(?P<src_file>[^\t]+)\t?(?P<dst_file>[^\t]+)?',
]))
130
131
# Header of one NUL-separated (-z) raw diff record: everything before the path
# field(s), e.g. ":100644 100644 <src sha> <dst sha> R086".
_DIFF_Z_HEADER_RE = re.compile(
    r':(?P<src_mode>[0-7]*) (?P<dst_mode>[0-7]*) '
    r'(?P<src_sha>[0-9a-f]*)(\.)* (?P<dst_sha>[0-9a-f]*)(\.)* '
    r'(?P<status>[ACDMRTUX])(?P<score>[0-9]+)?$')


def raw_diff(path, target):
    """Return the parsed raw format diff of target

    The diff is requested with -z so git emits paths verbatim. In the default
    line format, git C-quotes "unusual" paths (non-ASCII bytes, tabs,
    newlines, double quotes, backslashes). The quote characters and octal
    escapes would then end up in the returned file names.

    Args:
      path: Path to the git repository to diff in.
      target: The target to diff.

    Returns:
      A list of RawDiffEntry's.

    Raises:
      ValueError: if the diff output cannot be parsed.
    """
    entries = []

    cmd = ['git', 'diff', '--no-ext-diff', '-M', '--raw', '-z', target]
    diff = rh.utils.run(cmd, cwd=path, capture_output=True).stdout
    # Each record is "<header>\0<src path>\0", plus "<dst path>\0" for copies
    # and renames. All fields are consumed from one shared iterator.
    fields = iter(diff.split('\0'))
    for header in fields:
        if not header:
            # The final NUL terminator (or empty output) leaves an empty field.
            continue
        match = _DIFF_Z_HEADER_RE.match(header)
        if not match:
            raise ValueError('Failed to parse diff output: %s' % header)
        rawdiff = RawDiffEntry(**match.groupdict())
        num_paths = 2 if rawdiff.status in ('C', 'R') else 1
        paths = [next(fields, '') for _ in range(num_paths)]
        if not all(paths):
            raise ValueError('Failed to parse diff output: %s' % header)
        rawdiff.src_file = paths[0]
        if num_paths == 2:
            rawdiff.dst_file = paths[1]
        # The mode digits are read as a decimal int (e.g. '100644' -> 100644),
        # matching what this function has always produced.
        rawdiff.src_mode = int(rawdiff.src_mode)
        rawdiff.dst_mode = int(rawdiff.dst_mode)
        rawdiff.file = (rawdiff.dst_file
                        if rawdiff.dst_file else rawdiff.src_file)
        entries.append(rawdiff)

    return entries
159
160
def get_affected_files(commit):
    """Returns the raw diff entries for the files touched by |commit|.

    The diff runs in the current working directory over the revision range
    '<commit>^-'. In gitrevisions syntax this is shorthand for
    '<commit>^..<commit>'.

    Returns:
      A list of RawDiffEntry objects for the modified/added (and perhaps
      deleted) files.
    """
    rev_range = '%s^-' % commit
    return raw_diff(os.getcwd(), rev_range)
168
169
def get_commits(ignore_merged_commits=False):
    """Returns the ids of the commits between the upstream branch and HEAD.

    Args:
      ignore_merged_commits: If true, pass --first-parent to git rev-list so
        only the first-parent chain is walked.
    """
    rev_range = '%s..' % get_upstream_branch()
    cmd = ['git', 'rev-list', rev_range]
    if ignore_merged_commits:
        cmd += ['--first-parent']
    output = rh.utils.run(cmd, capture_output=True).stdout
    return output.split()
176
177
def get_commit_desc(commit):
    """Returns the raw message body (%B) of |commit|, without any diff output."""
    desc_cmd = ['git', 'diff-tree', '-s', '--always', '--format=%B', commit]
    completed = rh.utils.run(desc_cmd, capture_output=True)
    return completed.stdout
182
183
def find_repo_root(path=None):
    """Locate the top level of this repo checkout starting at |path|.

    Walks from |path| (default: the current working directory) up through
    its parent directories. Returns the first directory that contains a
    '.repo' entry.

    Args:
      path: Directory to start searching from. Defaults to os.getcwd().

    Returns:
      The absolute path of the directory containing '.repo'.

    Raises:
      ValueError: if no directory up to and including the filesystem root
        contains '.repo'.
    """
    if path is None:
        path = os.getcwd()
    orig_path = path

    path = os.path.abspath(path)
    while not os.path.exists(os.path.join(path, '.repo')):
        parent = os.path.dirname(path)
        # dirname() of a filesystem root is the root itself ('/' on POSIX,
        # e.g. 'C:\\' on Windows). Stop at that fixed point instead of
        # comparing against '/', which would loop forever on Windows.
        if parent == path:
            raise ValueError('Could not locate .repo in %s' % orig_path)
        path = parent

    return path
197
198
def is_git_repository(path):
    """Returns True if '<path>/.git' resolves to a git directory.

    The check asks `git rev-parse --resolve-git-dir`. A non-zero exit is
    reported as False rather than raised.
    """
    git_dir = os.path.join(path, '.git')
    probe = rh.utils.run(['git', 'rev-parse', '--resolve-git-dir', git_dir],
                         capture_output=True, check=False)
    return not probe.returncode
204