1#!/usr/bin/env python3 2 3# 4# Copyright (C) 2018 The Android Open Source Project 5# 6# Licensed under the Apache License, Version 2.0 (the "License"); 7# you may not use this file except in compliance with the License. 8# You may obtain a copy of the License at 9# 10# http://www.apache.org/licenses/LICENSE-2.0 11# 12# Unless required by applicable law or agreed to in writing, software 13# distributed under the License is distributed on an "AS IS" BASIS, 14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15# See the License for the specific language governing permissions and 16# limitations under the License. 17# 18 19"""A command line utility to pull multiple change lists from Gerrit.""" 20 21from __future__ import print_function 22 23import argparse 24import collections 25import itertools 26import json 27import multiprocessing 28import os 29import os.path 30import re 31import sys 32import xml.dom.minidom 33 34from gerrit import ( 35 create_url_opener_from_args, find_gerrit_name, query_change_lists, run 36) 37from subprocess import PIPE 38 39try: 40 # pylint: disable=redefined-builtin 41 from __builtin__ import raw_input as input # PY2 42except ImportError: 43 pass 44 45try: 46 from shlex import quote as _sh_quote # PY3.3 47except ImportError: 48 # Shell language simple string pattern. If a string matches this pattern, 49 # it doesn't have to be quoted. 50 _SHELL_SIMPLE_PATTERN = re.compile('^[a-zA-Z90-9_./-]+$') 51 52 def _sh_quote(txt): 53 """Quote a string if it contains special characters.""" 54 return txt if _SHELL_SIMPLE_PATTERN.match(txt) else json.dumps(txt) 55 56 57if bytes is str: 58 def write_bytes(data, file): # PY2 59 """Write bytes to a file.""" 60 # pylint: disable=redefined-builtin 61 file.write(data) 62else: 63 def write_bytes(data, file): # PY3 64 """Write bytes to a file.""" 65 # pylint: disable=redefined-builtin 66 file.buffer.write(data) 67 68 69def _confirm(question, default, file=sys.stderr): 70 """Prompt a yes/no question and convert the answer to a boolean value.""" 71 # pylint: disable=redefined-builtin 72 answers = {'': default, 'y': True, 'yes': True, 'n': False, 'no': False} 73 suffix = '[Y/n] ' if default else ' [y/N] ' 74 while True: 75 file.write(question + suffix) 76 file.flush() 77 ans = answers.get(input().lower()) 78 if ans is not None: 79 return ans 80 81 82class ChangeList(object): 83 """A ChangeList to be checked out.""" 84 # pylint: disable=too-few-public-methods,too-many-instance-attributes 85 86 def __init__(self, project, fetch, commit_sha1, commit, change_list): 87 """Initialize a ChangeList instance.""" 88 # pylint: disable=too-many-arguments 89 90 self.project = project 91 self.number = change_list['_number'] 92 93 self.fetch = fetch 94 95 fetch_git = None 96 for protocol in ('http', 'sso', 'rpc'): 97 fetch_git = fetch.get(protocol) 98 if fetch_git: 99 break 100 101 if not fetch_git: 102 raise ValueError( 103 'unknown fetch protocols: ' + str(list(fetch.keys()))) 104 105 self.fetch_url = fetch_git['url'] 106 self.fetch_ref = fetch_git['ref'] 107 108 self.commit_sha1 = commit_sha1 109 self.commit = commit 110 self.parents = commit['parents'] 111 112 self.change_list = change_list 113 114 115 def is_merge(self): 116 """Check whether this change list a merge commit.""" 117 return len(self.parents) > 1 118 119 120def find_repo_top(curdir): 121 """Find the top directory for this git-repo source tree.""" 122 olddir = None 123 while curdir != olddir: 124 if os.path.exists(os.path.join(curdir, '.repo')): 125 return curdir 126 olddir = curdir 127 curdir = os.path.dirname(curdir) 128 raise ValueError('.repo dir not found') 129 130 131def build_project_name_dir_dict(manifest_name): 132 """Build the mapping from Gerrit project name to source tree project 133 directory path.""" 134 manifest_cmd = ['repo', 'manifest'] 135 if manifest_name: 136 manifest_cmd.extend(['-m', manifest_name]) 137 raw_manifest_xml = run(manifest_cmd, stdout=PIPE, check=True).stdout 138 139 manifest_xml = xml.dom.minidom.parseString(raw_manifest_xml) 140 project_dirs = {} 141 for project in manifest_xml.getElementsByTagName('project'): 142 name = project.getAttribute('name') 143 path = project.getAttribute('path') 144 if path: 145 project_dirs[name] = path 146 else: 147 project_dirs[name] = name 148 149 return project_dirs 150 151 152def group_and_sort_change_lists(change_lists): 153 """Build a dict that maps projects to a list of topologically sorted change 154 lists.""" 155 156 # Build a dict that map projects to dicts that map commits to changes. 157 projects = collections.defaultdict(dict) 158 for change_list in change_lists: 159 commit_sha1 = None 160 for commit_sha1, value in change_list['revisions'].items(): 161 fetch = value['fetch'] 162 commit = value['commit'] 163 164 if not commit_sha1: 165 raise ValueError('bad revision') 166 167 project = change_list['project'] 168 169 project_changes = projects[project] 170 if commit_sha1 in project_changes: 171 raise KeyError('repeated commit sha1 "{}" in project "{}"'.format( 172 commit_sha1, project)) 173 174 project_changes[commit_sha1] = ChangeList( 175 project, fetch, commit_sha1, commit, change_list) 176 177 # Sort all change lists in a project in post ordering. 178 def _sort_project_change_lists(changes): 179 visited_changes = set() 180 sorted_changes = [] 181 182 def _post_order_traverse(change): 183 visited_changes.add(change) 184 for parent in change.parents: 185 parent_change = changes.get(parent['commit']) 186 if parent_change and parent_change not in visited_changes: 187 _post_order_traverse(parent_change) 188 sorted_changes.append(change) 189 190 for change in sorted(changes.values(), key=lambda x: x.number): 191 if change not in visited_changes: 192 _post_order_traverse(change) 193 194 return sorted_changes 195 196 # Sort changes in each projects 197 sorted_changes = [] 198 for project in sorted(projects.keys()): 199 sorted_changes.append(_sort_project_change_lists(projects[project])) 200 201 return sorted_changes 202 203 204def _main_json(args): 205 """Print the change lists in JSON format.""" 206 change_lists = _get_change_lists_from_args(args) 207 json.dump(change_lists, sys.stdout, indent=4, separators=(', ', ': ')) 208 print() # Print the end-of-line 209 210 211# Git commands for merge commits 212_MERGE_COMMANDS = { 213 'merge': ['git', 'merge', '--no-edit'], 214 'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'], 215 'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'], 216 'reset': ['git', 'reset', '--hard'], 217 'checkout': ['git', 'checkout'], 218} 219 220 221# Git commands for non-merge commits 222_PICK_COMMANDS = { 223 'pick': ['git', 'cherry-pick', '--allow-empty'], 224 'merge': ['git', 'merge', '--no-edit'], 225 'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'], 226 'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'], 227 'reset': ['git', 'reset', '--hard'], 228 'checkout': ['git', 'checkout'], 229} 230 231 232def build_pull_commands(change, branch_name, merge_opt, pick_opt): 233 """Build command lines for each change. The command lines will be passed 234 to subprocess.run().""" 235 236 cmds = [] 237 if branch_name is not None: 238 cmds.append(['repo', 'start', branch_name]) 239 cmds.append(['git', 'fetch', change.fetch_url, change.fetch_ref]) 240 if change.is_merge(): 241 cmds.append(_MERGE_COMMANDS[merge_opt] + ['FETCH_HEAD']) 242 else: 243 cmds.append(_PICK_COMMANDS[pick_opt] + ['FETCH_HEAD']) 244 return cmds 245 246 247def _sh_quote_command(cmd): 248 """Convert a command (an argument to subprocess.run()) to a shell command 249 string.""" 250 return ' '.join(_sh_quote(x) for x in cmd) 251 252 253def _sh_quote_commands(cmds): 254 """Convert multiple commands (arguments to subprocess.run()) to shell 255 command strings.""" 256 return ' && '.join(_sh_quote_command(cmd) for cmd in cmds) 257 258 259def _main_bash(args): 260 """Print the bash command to pull the change lists.""" 261 repo_top = find_repo_top(os.getcwd()) 262 project_dirs = build_project_name_dir_dict(args.manifest) 263 branch_name = _get_local_branch_name_from_args(args) 264 265 change_lists = _get_change_lists_from_args(args) 266 change_list_groups = group_and_sort_change_lists(change_lists) 267 268 print(_sh_quote_command(['pushd', repo_top])) 269 for changes in change_list_groups: 270 for change in changes: 271 project_dir = project_dirs.get(change.project, change.project) 272 cmds = [] 273 cmds.append(['pushd', project_dir]) 274 cmds.extend(build_pull_commands( 275 change, branch_name, args.merge, args.pick)) 276 cmds.append(['popd']) 277 print(_sh_quote_commands(cmds)) 278 print(_sh_quote_command(['popd'])) 279 280 281def _do_pull_change_lists_for_project(task): 282 """Pick a list of changes (usually under a project directory).""" 283 changes, task_opts = task 284 285 branch_name = task_opts['branch_name'] 286 merge_opt = task_opts['merge_opt'] 287 pick_opt = task_opts['pick_opt'] 288 project_dirs = task_opts['project_dirs'] 289 repo_top = task_opts['repo_top'] 290 291 for i, change in enumerate(changes): 292 try: 293 cwd = project_dirs[change.project] 294 except KeyError: 295 err_msg = 'error: project "{}" cannot be found in manifest.xml\n' 296 err_msg = err_msg.format(change.project).encode('utf-8') 297 return (change, changes[i + 1:], [], err_msg) 298 299 print(change.commit_sha1[0:10], i + 1, cwd) 300 cmds = build_pull_commands(change, branch_name, merge_opt, pick_opt) 301 for cmd in cmds: 302 proc = run(cmd, cwd=os.path.join(repo_top, cwd), stderr=PIPE) 303 if proc.returncode != 0: 304 return (change, changes[i + 1:], cmd, proc.stderr) 305 return None 306 307 308def _print_pull_failures(failures, file=sys.stderr): 309 """Print pull failures and tracebacks.""" 310 # pylint: disable=redefined-builtin 311 312 separator = '=' * 78 313 separator_sub = '-' * 78 314 315 print(separator, file=file) 316 for failed_change, skipped_changes, cmd, errors in failures: 317 print('PROJECT:', failed_change.project, file=file) 318 print('FAILED COMMIT:', failed_change.commit_sha1, file=file) 319 for change in skipped_changes: 320 print('PENDING COMMIT:', change.commit_sha1, file=file) 321 print(separator_sub, file=sys.stderr) 322 print('FAILED COMMAND:', _sh_quote_command(cmd), file=file) 323 write_bytes(errors, file=sys.stderr) 324 print(separator, file=sys.stderr) 325 326 327def _main_pull(args): 328 """Pull the change lists.""" 329 repo_top = find_repo_top(os.getcwd()) 330 project_dirs = build_project_name_dir_dict(args.manifest) 331 branch_name = _get_local_branch_name_from_args(args) 332 333 # Collect change lists 334 change_lists = _get_change_lists_from_args(args) 335 change_list_groups = group_and_sort_change_lists(change_lists) 336 337 # Build the options list for tasks 338 task_opts = { 339 'branch_name': branch_name, 340 'merge_opt': args.merge, 341 'pick_opt': args.pick, 342 'project_dirs': project_dirs, 343 'repo_top': repo_top, 344 } 345 346 # Run the commands to pull the change lists 347 if args.parallel <= 1: 348 results = [_do_pull_change_lists_for_project((changes, task_opts)) 349 for changes in change_list_groups] 350 else: 351 pool = multiprocessing.Pool(processes=args.parallel) 352 results = pool.map(_do_pull_change_lists_for_project, 353 zip(change_list_groups, itertools.repeat(task_opts))) 354 355 # Print failures and tracebacks 356 failures = [result for result in results if result] 357 if failures: 358 _print_pull_failures(failures) 359 sys.exit(1) 360 361 362def _parse_args(): 363 """Parse command line options.""" 364 parser = argparse.ArgumentParser() 365 366 parser.add_argument('command', choices=['pull', 'bash', 'json'], 367 help='Commands') 368 369 parser.add_argument('query', help='Change list query string') 370 parser.add_argument('-g', '--gerrit', help='Gerrit review URL') 371 372 parser.add_argument('--gitcookies', 373 default=os.path.expanduser('~/.gitcookies'), 374 help='Gerrit cookie file') 375 parser.add_argument('--manifest', help='Manifest') 376 parser.add_argument('--limits', default=1000, 377 help='Max number of change lists') 378 379 parser.add_argument('-m', '--merge', 380 choices=sorted(_MERGE_COMMANDS.keys()), 381 default='merge-ff-only', 382 help='Method to pull merge commits') 383 384 parser.add_argument('-p', '--pick', 385 choices=sorted(_PICK_COMMANDS.keys()), 386 default='pick', 387 help='Method to pull merge commits') 388 389 parser.add_argument('-b', '--branch', 390 help='Local branch name for `repo start`') 391 392 parser.add_argument('-j', '--parallel', default=1, type=int, 393 help='Number of parallel running commands') 394 395 return parser.parse_args() 396 397 398def _get_change_lists_from_args(args): 399 """Query the change lists by args.""" 400 url_opener = create_url_opener_from_args(args) 401 return query_change_lists(url_opener, args.gerrit, args.query, args.limits) 402 403 404def _get_local_branch_name_from_args(args): 405 """Get the local branch name from args.""" 406 if not args.branch and not _confirm( 407 'Do you want to continue without local branch name?', False): 408 print('error: `-b` or `--branch` must be specified', file=sys.stderr) 409 sys.exit(1) 410 return args.branch 411 412 413def main(): 414 """Main function""" 415 args = _parse_args() 416 417 if not args.gerrit: 418 try: 419 args.gerrit = find_gerrit_name() 420 # pylint: disable=bare-except 421 except: 422 print('gerrit instance not found, use [-g GERRIT]') 423 sys.exit(1) 424 425 if args.command == 'json': 426 _main_json(args) 427 elif args.command == 'bash': 428 _main_bash(args) 429 elif args.command == 'pull': 430 _main_pull(args) 431 else: 432 raise KeyError('unknown command') 433 434if __name__ == '__main__': 435 main() 436