1# -*- coding: utf-8 -*-
2# Copyright 2010 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Base class for gsutil commands.
16
17In addition to base class code, this file contains helpers that depend on base
18class state (such as GetAndPrintAcl). In general, functions that depend on
19class state and that are used by multiple commands belong in this file.
20Functions that don't depend on class state belong in util.py, and non-shared
21helpers belong in individual subclasses.
22"""
23
24from __future__ import absolute_import
25
26import codecs
27from collections import namedtuple
28import copy
29import getopt
30import logging
31import multiprocessing
32import os
33import Queue
34import signal
35import sys
36import textwrap
37import threading
38import traceback
39
40import boto
41from boto.storage_uri import StorageUri
42import gslib
43from gslib.cloud_api import AccessDeniedException
44from gslib.cloud_api import ArgumentException
45from gslib.cloud_api import ServiceException
46from gslib.cloud_api_delegator import CloudApiDelegator
47from gslib.cs_api_map import ApiSelector
48from gslib.cs_api_map import GsutilApiMapFactory
49from gslib.exception import CommandException
50from gslib.help_provider import HelpProvider
51from gslib.name_expansion import NameExpansionIterator
52from gslib.name_expansion import NameExpansionResult
53from gslib.parallelism_framework_util import AtomicDict
54from gslib.plurality_checkable_iterator import PluralityCheckableIterator
55from gslib.sig_handling import RegisterSignalHandler
56from gslib.storage_url import StorageUrlFromString
57from gslib.third_party.storage_apitools import storage_v1_messages as apitools_messages
58from gslib.translation_helper import AclTranslation
59from gslib.translation_helper import PRIVATE_DEFAULT_OBJ_ACL
60from gslib.util import CheckMultiprocessingAvailableAndInit
61from gslib.util import GetConfigFilePath
62from gslib.util import GsutilStreamHandler
63from gslib.util import HaveFileUrls
64from gslib.util import HaveProviderUrls
65from gslib.util import IS_WINDOWS
66from gslib.util import NO_MAX
67from gslib.util import UrlsAreForSingleProvider
68from gslib.util import UTF8
69from gslib.wildcard_iterator import CreateWildcardIterator
70
71OFFER_GSUTIL_M_SUGGESTION_THRESHOLD = 5
72
73if IS_WINDOWS:
74  import ctypes  # pylint: disable=g-import-not-at-top
75
76
77def _DefaultExceptionHandler(cls, e):
78  cls.logger.exception(e)
79
80
81def CreateGsutilLogger(command_name):
82  """Creates a logger that resembles 'print' output.
83
84  This logger abides by gsutil -d/-D/-DD/-q options.
85
86  By default (if none of the above options is specified) the logger will display
87  all messages logged with level INFO or above. Log propagation is disabled.
88
89  Args:
90    command_name: Command name to create logger for.
91
92  Returns:
93    A logger object.
94  """
95  log = logging.getLogger(command_name)
96  log.propagate = False
97  log.setLevel(logging.root.level)
98  log_handler = GsutilStreamHandler()
99  log_handler.setFormatter(logging.Formatter('%(message)s'))
100  # Commands that call other commands (like mv) would cause log handlers to be
101  # added more than once, so avoid adding if one is already present.
102  if not log.handlers:
103    log.addHandler(log_handler)
104  return log
105
106
107def _UrlArgChecker(command_instance, url):
108  if not command_instance.exclude_symlinks:
109    return True
110  exp_src_url = url.expanded_storage_url
111  if exp_src_url.IsFileUrl() and os.path.islink(exp_src_url.object_name):
112    command_instance.logger.info('Skipping symbolic link %s...', exp_src_url)
113    return False
114  return True
115
116
117def DummyArgChecker(*unused_args):
118  return True
119
120
121def SetAclFuncWrapper(cls, name_expansion_result, thread_state=None):
122  return cls.SetAclFunc(name_expansion_result, thread_state=thread_state)
123
124
125def SetAclExceptionHandler(cls, e):
126  """Exception handler that maintains state about post-completion status."""
127  cls.logger.error(str(e))
128  cls.everything_set_okay = False
129
130# We will keep this list of all thread- or process-safe queues ever created by
131# the main thread so that we can forcefully kill them upon shutdown. Otherwise,
132# we encounter a Python bug in which empty queues block forever on join (which
133# is called as part of the Python exit function cleanup) under the impression
134# that they are non-empty.
135# However, this also lets us shut down somewhat more cleanly when interrupted.
136queues = []
137
138
139def _NewMultiprocessingQueue():
140  queue = multiprocessing.Queue(MAX_QUEUE_SIZE)
141  queues.append(queue)
142  return queue
143
144
145def _NewThreadsafeQueue():
146  queue = Queue.Queue(MAX_QUEUE_SIZE)
147  queues.append(queue)
148  return queue
149
150# The maximum size of a process- or thread-safe queue. Imposing this limit
151# prevents us from needing to hold an arbitrary amount of data in memory.
152# However, setting this number too high (e.g., >= 32768 on OS X) can cause
153# problems on some operating systems.
154MAX_QUEUE_SIZE = 32500
155
156# The maximum depth of the tree of recursive calls to command.Apply. This is
157# an arbitrary limit put in place to prevent developers from accidentally
158# causing problems with infinite recursion, and it can be increased if needed.
159MAX_RECURSIVE_DEPTH = 5
160
161ZERO_TASKS_TO_DO_ARGUMENT = ('There were no', 'tasks to do')
162
163# Map from deprecated aliases to the current command and subcommands that
164# provide the same behavior.
165# TODO: Remove this map and deprecate old commands on 9/9/14.
166OLD_ALIAS_MAP = {'chacl': ['acl', 'ch'],
167                 'getacl': ['acl', 'get'],
168                 'setacl': ['acl', 'set'],
169                 'getcors': ['cors', 'get'],
170                 'setcors': ['cors', 'set'],
171                 'chdefacl': ['defacl', 'ch'],
172                 'getdefacl': ['defacl', 'get'],
173                 'setdefacl': ['defacl', 'set'],
174                 'disablelogging': ['logging', 'set', 'off'],
175                 'enablelogging': ['logging', 'set', 'on'],
176                 'getlogging': ['logging', 'get'],
177                 'getversioning': ['versioning', 'get'],
178                 'setversioning': ['versioning', 'set'],
179                 'getwebcfg': ['web', 'get'],
180                 'setwebcfg': ['web', 'set']}
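# For example, given the map above, a deprecated invocation like
#   gsutil getacl gs://bucket/obj
# is dispatched to the "acl" command with args ['get', 'gs://bucket/obj'],
# because _TranslateDeprecatedAliases (below) prepends the new sub-command
# arguments to the original args.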
181
182
183# Declare all of the module level variables - see
184# InitializeMultiprocessingVariables for an explanation of why this is
185# necessary.
186# pylint: disable=global-at-module-level
187global manager, consumer_pools, task_queues, caller_id_lock, caller_id_counter
188global total_tasks, call_completed_map, global_return_values_map
189global need_pool_or_done_cond, caller_id_finished_count, new_pool_needed
190global current_max_recursive_level, shared_vars_map, shared_vars_list_map
191global class_map, worker_checking_level_lock, failure_count
192
193
194def InitializeMultiprocessingVariables():
195  """Initializes module-level variables that will be inherited by subprocesses.
196
197  On Windows, a multiprocessing.Manager object should only
198  be created within an "if __name__ == '__main__':" block. This function
200  must be called; otherwise, every command that calls Command.Apply will fail.
200  """
201  # This list of global variables must exactly match the above list of
202  # declarations.
203  # pylint: disable=global-variable-undefined
204  global manager, consumer_pools, task_queues, caller_id_lock, caller_id_counter
205  global total_tasks, call_completed_map, global_return_values_map
206  global need_pool_or_done_cond, caller_id_finished_count, new_pool_needed
207  global current_max_recursive_level, shared_vars_map, shared_vars_list_map
208  global class_map, worker_checking_level_lock, failure_count
209
210  manager = multiprocessing.Manager()
211
212  consumer_pools = []
213
214  # List of all existing task queues - used by all pools to find the queue
215  # that's appropriate for the given recursive_apply_level.
216  task_queues = []
217
218  # Used to assign a globally unique caller ID to each Apply call.
219  caller_id_lock = manager.Lock()
220  caller_id_counter = multiprocessing.Value('i', 0)
221
222  # Map from caller_id to total number of tasks to be completed for that ID.
223  total_tasks = AtomicDict(manager=manager)
224
225  # Map from caller_id to a boolean which is True iff all its tasks are
226  # finished.
227  call_completed_map = AtomicDict(manager=manager)
228
229  # Used to keep track of the set of return values for each caller ID.
230  global_return_values_map = AtomicDict(manager=manager)
231
232  # Condition used to notify any waiting threads that a task has finished or
233  # that a call to Apply needs a new set of consumer processes.
234  need_pool_or_done_cond = manager.Condition()
235
236  # Lock used to prevent multiple worker processes from asking the main thread
237  # to create a new consumer pool for the same level.
238  worker_checking_level_lock = manager.Lock()
239
240  # Map from caller_id to the current number of completed tasks for that ID.
241  caller_id_finished_count = AtomicDict(manager=manager)
242
243  # Used as a way for the main thread to distinguish between being woken up
244  # by another call finishing and being woken up by a call that needs a new set
245  # of consumer processes.
246  new_pool_needed = multiprocessing.Value('i', 0)
247
248  current_max_recursive_level = multiprocessing.Value('i', 0)
249
250  # Map from (caller_id, name) to the value of that shared variable.
251  shared_vars_map = AtomicDict(manager=manager)
252  shared_vars_list_map = AtomicDict(manager=manager)
253
254  # Map from caller_id to calling class.
255  class_map = manager.dict()
256
257  # Number of tasks that resulted in an exception in calls to Apply().
258  failure_count = multiprocessing.Value('i', 0)
259
260
261def InitializeThreadingVariables():
262  """Initializes module-level variables used when running multi-threaded.
263
264  When multiprocessing is not available (or on Windows where only 1 process
265  is used), thread-safe analogs to the multiprocessing global variables
266  must be initialized. This function is the thread-safe analog to
267  InitializeMultiprocessingVariables.
268  """
269  # pylint: disable=global-variable-undefined
270  global global_return_values_map, shared_vars_map, failure_count
271  global caller_id_finished_count, shared_vars_list_map, total_tasks
272  global need_pool_or_done_cond, call_completed_map, class_map
273  global task_queues, caller_id_lock, caller_id_counter
274  caller_id_counter = 0
275  caller_id_finished_count = AtomicDict()
276  caller_id_lock = threading.Lock()
277  call_completed_map = AtomicDict()
278  class_map = AtomicDict()
279  failure_count = 0
280  global_return_values_map = AtomicDict()
281  need_pool_or_done_cond = threading.Condition()
282  shared_vars_list_map = AtomicDict()
283  shared_vars_map = AtomicDict()
284  task_queues = []
285  total_tasks = AtomicDict()
286
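# A minimal sketch of the intended initialization pattern (an assumption about
# the caller, not code from this module): exactly one of the two functions
# above is expected to run before Command.Apply is used, e.g.:
#
#   if CheckMultiprocessingAvailableAndInit().is_available:
#     InitializeMultiprocessingVariables()
#   else:
#     InitializeThreadingVariables()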
287
288# Each subclass of Command must define a property named 'command_spec' that is
289# an instance of the following class.
290CommandSpec = namedtuple('CommandSpec', [
291    # Name of command.
292    'command_name',
293    # Usage synopsis.
294    'usage_synopsis',
295    # List of command name aliases.
296    'command_name_aliases',
297    # Min number of args required by this command.
298    'min_args',
299    # Max number of args required by this command, or NO_MAX.
300    'max_args',
301    # Getopt-style string specifying acceptable sub args.
302    'supported_sub_args',
303    # True if file URLs are acceptable for this command.
304    'file_url_ok',
305    # True if provider-only URLs are acceptable for this command.
306    'provider_url_ok',
307    # Index in args of first URL arg.
308    'urls_start_arg',
309    # List of supported APIs
310    'gs_api_support',
311    # Default API to use for this command
312    'gs_default_api',
313    # Private arguments (for internal testing)
314    'supported_private_args',
315    'argparse_arguments',
316])
317
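# A minimal sketch of how a subclass typically satisfies this requirement
# (FooCommand and its argument values are hypothetical):
#
#   class FooCommand(Command):
#     command_spec = Command.CreateCommandSpec(
#         'foo', usage_synopsis='gsutil foo url...',
#         min_args=1, max_args=NO_MAX, supported_sub_args='r',
#         file_url_ok=False, provider_url_ok=False, urls_start_arg=0)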
318
319class Command(HelpProvider):
320  """Base class for all gsutil commands."""
321
322  # Each subclass must override this with an instance of CommandSpec.
323  command_spec = None
324
325  _commands_with_subcommands_and_subopts = ['acl', 'defacl', 'logging', 'web',
326                                            'notification']
327
328  # This keeps track of the recursive depth of the current call to Apply.
329  recursive_apply_level = 0
330
331  # If the multiprocessing module isn't available, we'll use this to keep track
332  # of the caller_id.
333  sequential_caller_id = -1
334
335  @staticmethod
336  def CreateCommandSpec(command_name, usage_synopsis=None,
337                        command_name_aliases=None, min_args=0,
338                        max_args=NO_MAX, supported_sub_args='',
339                        file_url_ok=False, provider_url_ok=False,
340                        urls_start_arg=0, gs_api_support=None,
341                        gs_default_api=None, supported_private_args=None,
342                        argparse_arguments=None):
343    """Creates an instance of CommandSpec, with defaults."""
344    return CommandSpec(
345        command_name=command_name,
346        usage_synopsis=usage_synopsis,
347        command_name_aliases=command_name_aliases or [],
348        min_args=min_args,
349        max_args=max_args,
350        supported_sub_args=supported_sub_args,
351        file_url_ok=file_url_ok,
352        provider_url_ok=provider_url_ok,
353        urls_start_arg=urls_start_arg,
354        gs_api_support=gs_api_support or [ApiSelector.XML],
355        gs_default_api=gs_default_api or ApiSelector.XML,
356        supported_private_args=supported_private_args,
357        argparse_arguments=argparse_arguments or [])
358
359  # Define a convenience property for command name, since it's used in many places.
360  def _GetDefaultCommandName(self):
361    return self.command_spec.command_name
362  command_name = property(_GetDefaultCommandName)
363
364  def _CalculateUrlsStartArg(self):
365    """Calculate the index in args of the first URL arg.
366
367    Returns:
368      Index of the first URL arg (according to the command spec).
369    """
370    return self.command_spec.urls_start_arg
371
372  def _TranslateDeprecatedAliases(self, args):
373    """Map deprecated aliases to the corresponding new command, and warn."""
374    new_command_args = OLD_ALIAS_MAP.get(self.command_alias_used, None)
375    if new_command_args:
376      # Prepend any subcommands for the new command. The command name itself
377      # is not part of the args, so leave it out.
378      args = new_command_args[1:] + args
379      self.logger.warn('\n'.join(textwrap.wrap(
380          ('You are using a deprecated alias, "%(used_alias)s", for the '
381           '"%(command_name)s" command. This will stop working on 9/9/2014. '
382           'Please use "%(command_name)s" with the appropriate sub-command in '
383           'the future. See "gsutil help %(command_name)s" for details.') %
384          {'used_alias': self.command_alias_used,
385           'command_name': self.command_name})))
386    return args
387
388  def __init__(self, command_runner, args, headers, debug, trace_token,
389               parallel_operations, bucket_storage_uri_class,
390               gsutil_api_class_map_factory, logging_filters=None,
391               command_alias_used=None):
392    """Instantiates a Command.
393
394    Args:
395      command_runner: CommandRunner (for commands built atop other commands).
396      args: Command-line args (arg0 = actual arg, not command name ala bash).
397      headers: Dictionary containing optional HTTP headers to pass to boto.
398      debug: Debug level to pass in to boto connection (range 0..3).
399      trace_token: Trace token to pass to the API implementation.
400      parallel_operations: Should command operations be executed in parallel?
401      bucket_storage_uri_class: Class to instantiate for cloud StorageUris.
402                                Settable for testing/mocking.
403      gsutil_api_class_map_factory: Creates map of cloud storage interfaces.
404                                    Settable for testing/mocking.
405      logging_filters: Optional list of logging.Filters to apply to this
406                       command's logger.
407      command_alias_used: The alias that was actually used when running this
408                          command (as opposed to the "official" command name,
409                          which will always correspond to the file name).
410
411    Implementation note: subclasses shouldn't need to define an __init__
412    method, and instead depend on the shared initialization that happens
413    here. If you do define an __init__ method in a subclass you'll need to
414    explicitly call super().__init__(). But you're encouraged not to do this,
415    because it will make changing the __init__ interface more painful.
416    """
417    # Save class values from constructor params.
418    self.command_runner = command_runner
419    self.unparsed_args = args
420    self.headers = headers
421    self.debug = debug
422    self.trace_token = trace_token
423    self.parallel_operations = parallel_operations
424    self.bucket_storage_uri_class = bucket_storage_uri_class
425    self.gsutil_api_class_map_factory = gsutil_api_class_map_factory
426    self.exclude_symlinks = False
427    self.recursion_requested = False
428    self.all_versions = False
429    self.command_alias_used = command_alias_used
430
431    # Global instance of a threaded logger object.
432    self.logger = CreateGsutilLogger(self.command_name)
433    if logging_filters:
434      for log_filter in logging_filters:
435        self.logger.addFilter(log_filter)
436
437    if self.command_spec is None:
438      raise CommandException('"%s" command implementation is missing a '
439                             'command_spec definition.' % self.command_name)
440
441    # Parse and validate args.
442    self.args = self._TranslateDeprecatedAliases(args)
443    self.ParseSubOpts()
444
445    # Named tuple public functions start with _
446    # pylint: disable=protected-access
447    self.command_spec = self.command_spec._replace(
448        urls_start_arg=self._CalculateUrlsStartArg())
449
450    if (len(self.args) < self.command_spec.min_args
451        or len(self.args) > self.command_spec.max_args):
452      self.RaiseWrongNumberOfArgumentsException()
453
454    if self.command_name not in self._commands_with_subcommands_and_subopts:
455      self.CheckArguments()
456
457    # Build the support and default maps from the command spec.
458    support_map = {
459        'gs': self.command_spec.gs_api_support,
460        's3': [ApiSelector.XML]
461    }
462    default_map = {
463        'gs': self.command_spec.gs_default_api,
464        's3': ApiSelector.XML
465    }
466    self.gsutil_api_map = GsutilApiMapFactory.GetApiMap(
467        self.gsutil_api_class_map_factory, support_map, default_map)
468
469    self.project_id = None
470    self.gsutil_api = CloudApiDelegator(
471        bucket_storage_uri_class, self.gsutil_api_map,
472        self.logger, debug=self.debug, trace_token=self.trace_token)
473
474    # Cross-platform path to run gsutil binary.
475    self.gsutil_cmd = ''
476    # If running on Windows, invoke python interpreter explicitly.
477    if gslib.util.IS_WINDOWS:
478      self.gsutil_cmd += 'python '
479    # Add full path to gsutil to make sure we test the correct version.
480    self.gsutil_path = gslib.GSUTIL_PATH
481    self.gsutil_cmd += self.gsutil_path
482
483    # We're treating recursion_requested like it's used by all commands, but
484    # only some of the commands accept the -R option.
485    if self.sub_opts:
486      for o, unused_a in self.sub_opts:
487        if o == '-r' or o == '-R':
488          self.recursion_requested = True
489          break
490
491    self.multiprocessing_is_available = (
492        CheckMultiprocessingAvailableAndInit().is_available)
493
494  def RaiseWrongNumberOfArgumentsException(self):
495    """Raises exception for wrong number of arguments supplied to command."""
496    if len(self.args) < self.command_spec.min_args:
497      tail_str = 's' if self.command_spec.min_args > 1 else ''
498      message = ('The %s command requires at least %d argument%s.' %
499                 (self.command_name, self.command_spec.min_args, tail_str))
500    else:
501      message = ('The %s command accepts at most %d arguments.' %
502                 (self.command_name, self.command_spec.max_args))
503    message += ' Usage:\n%s\nFor additional help run:\n  gsutil help %s' % (
504        self.command_spec.usage_synopsis, self.command_name)
505    raise CommandException(message)
506
507  def RaiseInvalidArgumentException(self):
508    """Raises exception for specifying an invalid argument to command."""
509    message = ('Incorrect option(s) specified. Usage:\n%s\n'
510               'For additional help run:\n  gsutil help %s' % (
511                   self.command_spec.usage_synopsis, self.command_name))
512    raise CommandException(message)
513
514  def ParseSubOpts(self, check_args=False):
515    """Parses sub-opt args.
516
517    Args:
518      check_args: True to have CheckArguments() called after parsing.
519
520    Populates:
521      (self.sub_opts, self.args) from parsing.
522
523    Raises: RaiseInvalidArgumentException if invalid args specified.
524    """
525    try:
526      self.sub_opts, self.args = getopt.getopt(
527          self.args, self.command_spec.supported_sub_args,
528          self.command_spec.supported_private_args or [])
529    except getopt.GetoptError:
530      self.RaiseInvalidArgumentException()
531    if check_args:
532      self.CheckArguments()
533
534  def CheckArguments(self):
535    """Checks that command line arguments match the command_spec.
536
537    Any commands in self._commands_with_subcommands_and_subopts are responsible
538    for calling this method after handling initial parsing of their arguments.
539    This prevents commands with sub-commands as well as options from breaking
540    the parsing of getopt.
541
542    TODO: Provide a function to parse commands and sub-commands more
543    intelligently once we stop allowing the deprecated command versions.
544
545    Raises:
546      CommandException if the arguments don't match.
547    """
548
549    if (not self.command_spec.file_url_ok
550        and HaveFileUrls(self.args[self.command_spec.urls_start_arg:])):
551      raise CommandException('"%s" command does not support "file://" URLs. '
552                             'Did you mean to use a gs:// URL?' %
553                             self.command_name)
554    if (not self.command_spec.provider_url_ok
555        and HaveProviderUrls(self.args[self.command_spec.urls_start_arg:])):
556      raise CommandException('"%s" command does not support provider-only '
557                             'URLs.' % self.command_name)
558
559  def WildcardIterator(self, url_string, all_versions=False):
560    """Helper to instantiate gslib.WildcardIterator.
561
562    Args are same as gslib.WildcardIterator interface, but this method fills in
563    most of the values from instance state.
564
565    Args:
566      url_string: URL string naming wildcard objects to iterate.
567      all_versions: If true, the iterator yields all versions of objects
568                    matching the wildcard.  If false, yields just the live
569                    object version.
570
571    Returns:
572      WildcardIterator for use by caller.
573    """
574    return CreateWildcardIterator(
575        url_string, self.gsutil_api, all_versions=all_versions,
576        debug=self.debug, project_id=self.project_id)
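  # A typical use of this helper elsewhere in this class looks like the
  # following (the URL is illustrative):
  #
  #   for blr in self.WildcardIterator('gs://bucket/**').IterObjects(
  #       bucket_listing_fields=['acl']):
  #     ...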
577
578  def RunCommand(self):
579    """Abstract function in base class. Subclasses must implement this.
580
581    The return value of this function will be used as the exit status of the
582    process, so subclass commands should return an integer exit code (0 for
583    success, a value in [1,255] for failure).
584    """
585    raise CommandException('Command %s is missing its RunCommand() '
586                           'implementation' % self.command_name)
587
588  ############################################################
589  # Shared helper functions that depend on base class state. #
590  ############################################################
591
592  def ApplyAclFunc(self, acl_func, acl_excep_handler, url_strs):
593    """Sets the standard or default object ACL depending on self.command_name.
594
595    Args:
596      acl_func: ACL function to be passed to Apply.
597      acl_excep_handler: ACL exception handler to be passed to Apply.
598      url_strs: URL strings on which to set ACL.
599
600    Raises:
601      CommandException if an ACL could not be set.
602    """
603    multi_threaded_url_args = []
604    # Handle bucket ACL setting operations single-threaded, because
605    # our threading machinery currently assumes it's working with objects
606    # (name_expansion_iterator), and normally we wouldn't expect users to need
607    # to set ACLs on huge numbers of buckets at once anyway.
608    for url_str in url_strs:
609      url = StorageUrlFromString(url_str)
610      if url.IsCloudUrl() and url.IsBucket():
611        if self.recursion_requested:
612          # If user specified -R option, convert any bucket args to bucket
613          # wildcards (e.g., gs://bucket/*), to prevent the operation from
614          # being applied to the buckets themselves.
615          url.object_name = '*'
616          multi_threaded_url_args.append(url.url_string)
617        else:
618          # Convert to a NameExpansionResult so we can re-use the threaded
619          # function for the single-threaded implementation.  RefType is unused.
620          for blr in self.WildcardIterator(url.url_string).IterBuckets(
621              bucket_fields=['id']):
622            name_expansion_for_url = NameExpansionResult(
623                url, False, False, blr.storage_url)
624            acl_func(self, name_expansion_for_url)
625      else:
626        multi_threaded_url_args.append(url_str)
627
628    if len(multi_threaded_url_args) >= 1:
629      name_expansion_iterator = NameExpansionIterator(
630          self.command_name, self.debug,
631          self.logger, self.gsutil_api,
632          multi_threaded_url_args, self.recursion_requested,
633          all_versions=self.all_versions,
634          continue_on_error=self.continue_on_error or self.parallel_operations)
635
636      # Perform requests in parallel (-m) mode, if requested, using
637      # configured number of parallel processes and threads. Otherwise,
638      # perform requests with sequential function calls in current process.
639      self.Apply(acl_func, name_expansion_iterator, acl_excep_handler,
640                 fail_on_error=not self.continue_on_error)
641
642    if not self.everything_set_okay and not self.continue_on_error:
643      raise CommandException('ACLs for some objects could not be set.')
644
645  def SetAclFunc(self, name_expansion_result, thread_state=None):
646    """Sets the object ACL for the name_expansion_result provided.
647
648    Args:
649      name_expansion_result: NameExpansionResult describing the target object.
650      thread_state: If present, use this gsutil Cloud API instance for the set.
651    """
652    if thread_state:
653      assert not self.def_acl
654      gsutil_api = thread_state
655    else:
656      gsutil_api = self.gsutil_api
657    op_string = 'default object ACL' if self.def_acl else 'ACL'
658    url = name_expansion_result.expanded_storage_url
659    self.logger.info('Setting %s on %s...', op_string, url)
660    if (gsutil_api.GetApiSelector(url.scheme) == ApiSelector.XML
661        and url.scheme != 'gs'):
662      # If we are called with a non-google ACL model, we need to use the XML
663      # passthrough. acl_arg should either be a canned ACL or an XML ACL.
664      self._SetAclXmlPassthrough(url, gsutil_api)
665    else:
666      # Normal Cloud API path. acl_arg is a JSON ACL or a canned ACL.
667      self._SetAclGsutilApi(url, gsutil_api)
668
669  def _SetAclXmlPassthrough(self, url, gsutil_api):
670    """Sets the ACL for the URL provided using the XML passthrough functions.
671
672    This function assumes that self.def_acl, self.canned,
673    and self.continue_on_error are initialized, and that self.acl_arg is
674    either an XML string or a canned ACL string.
675
676    Args:
677      url: CloudURL to set the ACL on.
678      gsutil_api: gsutil Cloud API to use for the ACL set. Must support XML
679          passthrough functions.
680    """
681    try:
682      orig_prefer_api = gsutil_api.prefer_api
683      gsutil_api.prefer_api = ApiSelector.XML
684      gsutil_api.XmlPassThroughSetAcl(
685          self.acl_arg, url, canned=self.canned,
686          def_obj_acl=self.def_acl, provider=url.scheme)
687    except ServiceException as e:
688      if self.continue_on_error:
689        self.everything_set_okay = False
690        self.logger.error(e)
691      else:
692        raise
693    finally:
694      gsutil_api.prefer_api = orig_prefer_api
695
696  def _SetAclGsutilApi(self, url, gsutil_api):
697    """Sets the ACL for the URL provided using the gsutil Cloud API.
698
699    This function assumes that self.def_acl, self.canned,
700    and self.continue_on_error are initialized, and that self.acl_arg is
701    either a JSON string or a canned ACL string.
702
703    Args:
704      url: CloudURL to set the ACL on.
705      gsutil_api: gsutil Cloud API to use for the ACL set.
706    """
707    try:
708      if url.IsBucket():
709        if self.def_acl:
710          if self.canned:
711            gsutil_api.PatchBucket(
712                url.bucket_name, apitools_messages.Bucket(),
713                canned_def_acl=self.acl_arg, provider=url.scheme, fields=['id'])
714          else:
715            def_obj_acl = AclTranslation.JsonToMessage(
716                self.acl_arg, apitools_messages.ObjectAccessControl)
717            if not def_obj_acl:
718              # Use a sentinel value to indicate a private (no entries) default
719              # object ACL.
720              def_obj_acl.append(PRIVATE_DEFAULT_OBJ_ACL)
721            bucket_metadata = apitools_messages.Bucket(
722                defaultObjectAcl=def_obj_acl)
723            gsutil_api.PatchBucket(url.bucket_name, bucket_metadata,
724                                   provider=url.scheme, fields=['id'])
725        else:
726          if self.canned:
727            gsutil_api.PatchBucket(
728                url.bucket_name, apitools_messages.Bucket(),
729                canned_acl=self.acl_arg, provider=url.scheme, fields=['id'])
730          else:
731            bucket_acl = AclTranslation.JsonToMessage(
732                self.acl_arg, apitools_messages.BucketAccessControl)
733            bucket_metadata = apitools_messages.Bucket(acl=bucket_acl)
734            gsutil_api.PatchBucket(url.bucket_name, bucket_metadata,
735                                   provider=url.scheme, fields=['id'])
736      else:  # url.IsObject()
737        if self.canned:
738          gsutil_api.PatchObjectMetadata(
739              url.bucket_name, url.object_name, apitools_messages.Object(),
740              provider=url.scheme, generation=url.generation,
741              canned_acl=self.acl_arg)
742        else:
743          object_acl = AclTranslation.JsonToMessage(
744              self.acl_arg, apitools_messages.ObjectAccessControl)
745          object_metadata = apitools_messages.Object(acl=object_acl)
746          gsutil_api.PatchObjectMetadata(url.bucket_name, url.object_name,
747                                         object_metadata, provider=url.scheme,
748                                         generation=url.generation)
749    except ArgumentException, e:
750      raise
751    except ServiceException, e:
752      if self.continue_on_error:
753        self.everything_set_okay = False
754        self.logger.error(e)
755      else:
756        raise
757
758  def SetAclCommandHelper(self, acl_func, acl_excep_handler):
759    """Sets ACLs on the self.args using the passed-in acl function.
760
761    Args:
762      acl_func: ACL function to be passed to Apply.
763      acl_excep_handler: ACL exception handler to be passed to Apply.
764    """
765    acl_arg = self.args[0]
766    url_args = self.args[1:]
767    # Disallow multi-provider setacl requests, because there are differences in
768    # the ACL models.
769    if not UrlsAreForSingleProvider(url_args):
770      raise CommandException('"%s" command spanning providers not allowed.' %
771                             self.command_name)
772
773    # Determine whether acl_arg names a file containing XML ACL text vs. the
774    # string name of a canned ACL.
775    if os.path.isfile(acl_arg):
776      with codecs.open(acl_arg, 'r', UTF8) as f:
777        acl_arg = f.read()
778      self.canned = False
779    else:
780      # No file exists, so expect a canned ACL string.
781      # validate=False because we allow wildcard urls.
782      storage_uri = boto.storage_uri(
783          url_args[0], debug=self.debug, validate=False,
784          bucket_storage_uri_class=self.bucket_storage_uri_class)
785
786      canned_acls = storage_uri.canned_acls()
787      if acl_arg not in canned_acls:
788        raise CommandException('Invalid canned ACL "%s".' % acl_arg)
789      self.canned = True
790
791    # Used to track if any ACLs failed to be set.
792    self.everything_set_okay = True
793    self.acl_arg = acl_arg
794
795    self.ApplyAclFunc(acl_func, acl_excep_handler, url_args)
796    if not self.everything_set_okay and not self.continue_on_error:
797      raise CommandException('ACLs for some objects could not be set.')
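  # Illustrative behavior of the acl_arg handling above (the command lines are
  # approximate examples, not taken from this file):
  #   gsutil acl set private gs://bucket/obj     -> canned ACL name path
  #   gsutil acl set my_acl.txt gs://bucket/obj  -> ACL text read from a file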
798
799  def _WarnServiceAccounts(self):
800    """Warns service account users who have received an AccessDenied error.
801
802    When one of the metadata-related commands fails due to AccessDenied, the
803    user must ensure that they are listed as an Owner in the API console.
804    """
805    # Import this here so that the value will be set first in
806    # gcs_oauth2_boto_plugin.
807    # pylint: disable=g-import-not-at-top
808    from gcs_oauth2_boto_plugin.oauth2_plugin import IS_SERVICE_ACCOUNT
809
810    if IS_SERVICE_ACCOUNT:
811      # This method is only called when canned ACLs are used, so the warning
812      # definitely applies.
813      self.logger.warning('\n'.join(textwrap.wrap(
814          'It appears that your service account has been denied access while '
815          'attempting to perform a metadata operation. If you believe that you '
816          'should have access to this metadata (i.e., if it is associated with '
817          'your account), please make sure that your service account\'s email '
818          'address is listed as an Owner in the Team tab of the API console. '
819          'See "gsutil help creds" for further information.\n')))
820
821  def GetAndPrintAcl(self, url_str):
822    """Prints the standard or default object ACL depending on self.command_name.
823
824    Args:
825      url_str: URL string to get ACL for.
826    """
827    blr = self.GetAclCommandBucketListingReference(url_str)
828    url = StorageUrlFromString(url_str)
829    if (self.gsutil_api.GetApiSelector(url.scheme) == ApiSelector.XML
830        and url.scheme != 'gs'):
831      # Need to use XML passthrough.
832      try:
833        acl = self.gsutil_api.XmlPassThroughGetAcl(
834            url, def_obj_acl=self.def_acl, provider=url.scheme)
835        print acl.to_xml()
836      except AccessDeniedException, _:
837        self._WarnServiceAccounts()
838        raise
839    else:
840      if self.command_name == 'defacl':
841        acl = blr.root_object.defaultObjectAcl
842        if not acl:
843          self.logger.warn(
844              'No default object ACL present for %s. This could occur if '
845              'the default object ACL is private, in which case objects '
846              'created in this bucket will be readable only by their '
847              'creators. It could also mean you do not have OWNER permission '
848              'on %s and therefore do not have permission to read the '
849              'default object ACL.', url_str, url_str)
850      else:
851        acl = blr.root_object.acl
852        if not acl:
853          self._WarnServiceAccounts()
854          raise AccessDeniedException('Access denied. Please ensure you have '
855                                      'OWNER permission on %s.' % url_str)
856      print AclTranslation.JsonFromMessage(acl)
857
858  def GetAclCommandBucketListingReference(self, url_str):
859    """Gets a single bucket listing reference for an acl get command.
860
861    Args:
862      url_str: URL string to get the bucket listing reference for.
863
864    Returns:
865      BucketListingReference for the URL string.
866
867    Raises:
868      CommandException if string did not result in exactly one reference.
869    """
870    # We're guaranteed by caller that we have the appropriate type of url
871    # string for the call (e.g., we will never be called with an object string
872    # by getdefacl).
873    wildcard_url = StorageUrlFromString(url_str)
874    if wildcard_url.IsObject():
875      plurality_iter = PluralityCheckableIterator(
876          self.WildcardIterator(url_str).IterObjects(
877              bucket_listing_fields=['acl']))
878    else:
879      # Bucket or provider.  We call IterBuckets explicitly here to ensure that
880      # the root object is populated with the acl.
881      if self.command_name == 'defacl':
882        bucket_fields = ['defaultObjectAcl']
883      else:
884        bucket_fields = ['acl']
885      plurality_iter = PluralityCheckableIterator(
886          self.WildcardIterator(url_str).IterBuckets(
887              bucket_fields=bucket_fields))
888    if plurality_iter.IsEmpty():
889      raise CommandException('No URLs matched')
890    if plurality_iter.HasPlurality():
891      raise CommandException(
892          '%s matched more than one URL, which is not allowed by the %s '
893          'command' % (url_str, self.command_name))
894    return list(plurality_iter)[0]
895
896  def _HandleMultiProcessingSigs(self, signal_num, unused_cur_stack_frame):
897    """Handles signals INT AND TERM during a multi-process/multi-thread request.
898
899    Kills subprocesses.
900
901    Args:
902      signal_num: Signal number that was caught (SIGINT or SIGTERM).
903      unused_cur_stack_frame: Current stack frame.
904    """
905    # Note: This only works under Linux/MacOS. See
906    # https://github.com/GoogleCloudPlatform/gsutil/issues/99 for details
907    # about why making it work correctly across OS's is harder and still open.
908    ShutDownGsutil()
909    if signal_num == signal.SIGINT:
910      sys.stderr.write('Caught ^C - exiting\n')
911    # Simply calling sys.exit(1) doesn't work - see above bug for details.
912    KillProcess(os.getpid())
913
914  def GetSingleBucketUrlFromArg(self, arg, bucket_fields=None):
915    """Gets a single bucket URL based on the command arguments.
916
917    Args:
918      arg: String argument to get bucket URL for.
919      bucket_fields: Fields to populate for the bucket.
920
921    Returns:
922      (StorageUrl referring to a single bucket, Bucket metadata).
923
924    Raises:
925      CommandException if args did not match exactly one bucket.
926    """
927    plurality_checkable_iterator = self.GetBucketUrlIterFromArg(
928        arg, bucket_fields=bucket_fields)
929    if plurality_checkable_iterator.HasPlurality():
930      raise CommandException(
931          '%s matched more than one URL, which is not\n'
932          'allowed by the %s command' % (arg, self.command_name))
933    blr = list(plurality_checkable_iterator)[0]
934    return StorageUrlFromString(blr.url_string), blr.root_object
935
936  def GetBucketUrlIterFromArg(self, arg, bucket_fields=None):
937    """Gets a single bucket URL based on the command arguments.
938
939    Args:
940      arg: String argument to iterate over.
941      bucket_fields: Fields to populate for the bucket.
942
943    Returns:
944      PluralityCheckableIterator over buckets.
945
946    Raises:
947      CommandException if iterator matched no buckets.
948    """
949    arg_url = StorageUrlFromString(arg)
950    if not arg_url.IsCloudUrl() or arg_url.IsObject():
951      raise CommandException('"%s" command must specify a bucket' %
952                             self.command_name)
953
954    plurality_checkable_iterator = PluralityCheckableIterator(
955        self.WildcardIterator(arg).IterBuckets(
956            bucket_fields=bucket_fields))
957    if plurality_checkable_iterator.IsEmpty():
958      raise CommandException('No URLs matched')
959    return plurality_checkable_iterator
960
961  ######################
962  # Private functions. #
963  ######################
964
965  def _ResetConnectionPool(self):
966    # Each OS process needs to establish its own set of connections to
967    # the server to avoid writes from different OS processes interleaving
968    # onto the same socket (and garbling the underlying SSL session).
969    # We ensure each process gets its own set of connections here by
970    # closing all connections in the storage provider connection pool.
971    connection_pool = StorageUri.provider_pool
972    if connection_pool:
973      for i in connection_pool:
974        connection_pool[i].connection.close()
975
976  def _GetProcessAndThreadCount(self, process_count, thread_count,
977                                parallel_operations_override):
978    """Determines the values of process_count and thread_count.
979
980    These values are used for parallel operations.
981    If we're not performing operations in parallel, then ignore
982    existing values and use process_count = thread_count = 1.
983
984    Args:
985      process_count: A positive integer or None. In the latter case, we read
986                     the value from the .boto config file.
987      thread_count: A positive integer or None. In the latter case, we read
988                    the value from the .boto config file.
989      parallel_operations_override: Used to override self.parallel_operations.
990                                    This allows the caller to safely override
991                                    the top-level flag for a single call.
992
993    Returns:
994      (process_count, thread_count): The number of processes and threads to use,
995                                     respectively.
996    """
997    # Set OS process and python thread count as a function of options
998    # and config.
999    if self.parallel_operations or parallel_operations_override:
1000      if not process_count:
1001        process_count = boto.config.getint(
1002            'GSUtil', 'parallel_process_count',
1003            gslib.commands.config.DEFAULT_PARALLEL_PROCESS_COUNT)
1004      if process_count < 1:
1005        raise CommandException('Invalid parallel_process_count "%d".' %
1006                               process_count)
1007      if not thread_count:
1008        thread_count = boto.config.getint(
1009            'GSUtil', 'parallel_thread_count',
1010            gslib.commands.config.DEFAULT_PARALLEL_THREAD_COUNT)
1011      if thread_count < 1:
1012        raise CommandException('Invalid parallel_thread_count "%d".' %
1013                               thread_count)
1014    else:
1015      # If -m not specified, then assume 1 OS process and 1 Python thread.
1016      process_count = 1
1017      thread_count = 1
1018
1019    if IS_WINDOWS and process_count > 1:
1020      raise CommandException('\n'.join(textwrap.wrap(
1021          ('It is not possible to set process_count > 1 on Windows. Please '
1022           'update your config file (located at %s) and set '
1023           '"parallel_process_count = 1".') %
1024          GetConfigFilePath())))
1025    self.logger.debug('process count: %d', process_count)
1026    self.logger.debug('thread count: %d', thread_count)
1027
1028    return (process_count, thread_count)
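  # The two boto config values read above come from the [GSUtil] section of the
  # user's .boto file; for example (values are illustrative):
  #
  #   [GSUtil]
  #   parallel_process_count = 12
  #   parallel_thread_count = 10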
1029
1030  def _SetUpPerCallerState(self):
1031    """Set up the state for a caller id, corresponding to one Apply call."""
1032    # pylint: disable=global-variable-undefined,global-variable-not-assigned
1033    # These variables are initialized in InitializeMultiprocessingVariables or
1034    # InitializeThreadingVariables
1035    global global_return_values_map, shared_vars_map, failure_count
1036    global caller_id_finished_count, shared_vars_list_map, total_tasks
1037    global need_pool_or_done_cond, call_completed_map, class_map
1038    global task_queues, caller_id_lock, caller_id_counter
1039    # Get a new caller ID.
1040    with caller_id_lock:
1041      if isinstance(caller_id_counter, int):
1042        caller_id_counter += 1
1043        caller_id = caller_id_counter
1044      else:
1045        caller_id_counter.value += 1
1046        caller_id = caller_id_counter.value
1047
1048    # Create a copy of self with an incremented recursive level. This allows
1049    # the class to report its level correctly if the function called from it
1050    # also needs to call Apply.
1051    cls = copy.copy(self)
1052    cls.recursive_apply_level += 1
1053
1054    # Thread-safe loggers can't be pickled, so we will remove it here and
1055    # recreate it later in the WorkerThread. This is not a problem since any
1056    # logger with the same name will be treated as a singleton.
1057    cls.logger = None
1058
1059    # Likewise, the default API connection can't be pickled, but it is unused
1060    # anyway as each thread gets its own API delegator.
1061    cls.gsutil_api = None
1062
1063    class_map[caller_id] = cls
1064    total_tasks[caller_id] = -1  # -1 => the producer hasn't finished yet.
1065    call_completed_map[caller_id] = False
1066    caller_id_finished_count[caller_id] = 0
1067    global_return_values_map[caller_id] = []
1068    return caller_id
1069
1070  def _CreateNewConsumerPool(self, num_processes, num_threads):
1071    """Create a new pool of processes that call _ApplyThreads."""
1072    processes = []
1073    task_queue = _NewMultiprocessingQueue()
1074    task_queues.append(task_queue)
1075
1076    current_max_recursive_level.value += 1
1077    if current_max_recursive_level.value > MAX_RECURSIVE_DEPTH:
1078      raise CommandException('Recursion depth of Apply calls is too great.')
1079    for _ in range(num_processes):
1080      recursive_apply_level = len(consumer_pools)
1081      p = multiprocessing.Process(
1082          target=self._ApplyThreads,
1083          args=(num_threads, num_processes, recursive_apply_level))
1084      p.daemon = True
1085      processes.append(p)
1086      p.start()
1087    consumer_pool = _ConsumerPool(processes, task_queue)
1088    consumer_pools.append(consumer_pool)
1089
1090  def Apply(self, func, args_iterator, exception_handler,
1091            shared_attrs=None, arg_checker=_UrlArgChecker,
1092            parallel_operations_override=False, process_count=None,
1093            thread_count=None, should_return_results=False,
1094            fail_on_error=False):
1095    """Calls _Parallel/SequentialApply based on multiprocessing availability.
1096
1097    Args:
1098      func: Function to call to process each argument.
1099      args_iterator: Iterable collection of arguments to be put into the
1100                     work queue.
1101      exception_handler: Exception handler for WorkerThread class.
1102      shared_attrs: List of attributes to manage across sub-processes.
1103      arg_checker: Used to determine whether we should process the current
1104                   argument or simply skip it. Also handles any logging that
1105                   is specific to a particular type of argument.
1106      parallel_operations_override: Used to override self.parallel_operations.
1107                                    This allows the caller to safely override
1108                                    the top-level flag for a single call.
1109      process_count: The number of processes to use. If not specified, then
1110                     the configured default will be used.
1111      thread_count: The number of threads per process. If not specified, then
1112                    the configured default will be used.
1113      should_return_results: If true, then return the results of all successful
1114                             calls to func in a list.
1115      fail_on_error: If true, then raise any exceptions encountered when
1116                     executing func. This is only applicable in the case of
1117                     process_count == thread_count == 1.
1118
1119    Returns:
1120      Results from spawned threads.
1121    """
1122    if shared_attrs:
1123      original_shared_vars_values = {}  # We'll add these back in at the end.
1124      for name in shared_attrs:
1125        original_shared_vars_values[name] = getattr(self, name)
1126        # By setting this to 0, we simplify the logic for computing deltas.
1127        # We'll add it back after all of the tasks have been performed.
1128        setattr(self, name, 0)
1129
1130    (process_count, thread_count) = self._GetProcessAndThreadCount(
1131        process_count, thread_count, parallel_operations_override)
1132
1133    is_main_thread = (self.recursive_apply_level == 0
1134                      and self.sequential_caller_id == -1)
1135
1136    # We don't honor the fail_on_error flag in the case of multiple threads
1137    # or processes.
1138    fail_on_error = fail_on_error and (process_count * thread_count == 1)
1139
1140    # Only check this from the first call in the main thread. Apart from the
1141    # fact that it's wasteful to try this multiple times in general, it also
1142    # will never work when called from a subprocess since we use daemon
1143    # processes, and daemons can't create other processes.
1144    if (is_main_thread and not self.multiprocessing_is_available and
1145        process_count > 1):
1146      # Run the check again and log the appropriate warnings. This was run
1147      # before, when the Command object was created, in order to calculate
1148      # self.multiprocessing_is_available, but we don't want to print the
1149      # warning until we're sure the user actually tried to use multiple
1150      # threads or processes.
1151      CheckMultiprocessingAvailableAndInit(logger=self.logger)
1152
1153    caller_id = self._SetUpPerCallerState()
1154
1155    # If any shared attributes passed by caller, create a dictionary of
1156    # shared memory variables for every element in the list of shared
1157    # attributes.
1158    if shared_attrs:
1159      shared_vars_list_map[caller_id] = shared_attrs
1160      for name in shared_attrs:
1161        shared_vars_map[(caller_id, name)] = 0
1162
1163    # Make all of the requested function calls.
1164    usable_processes_count = (process_count if self.multiprocessing_is_available
1165                              else 1)
1166    if thread_count * usable_processes_count > 1:
1167      self._ParallelApply(func, args_iterator, exception_handler, caller_id,
1168                          arg_checker, usable_processes_count, thread_count,
1169                          should_return_results, fail_on_error)
1170    else:
1171      self._SequentialApply(func, args_iterator, exception_handler, caller_id,
1172                            arg_checker, should_return_results, fail_on_error)
1173
1174    if shared_attrs:
1175      for name in shared_attrs:
1176        # This allows us to retain the original value of the shared variable,
1177        # and simply apply the delta after what was done during the call to
1178        # apply.
1179        final_value = (original_shared_vars_values[name] +
1180                       shared_vars_map.get((caller_id, name)))
1181        setattr(self, name, final_value)
1182
1183    if should_return_results:
1184      return global_return_values_map.get(caller_id)
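  # A minimal sketch of a typical call to Apply from a subclass (the function,
  # handler, and attribute names here are hypothetical):
  #
  #   self.Apply(_ProcessOneUrl, name_expansion_iterator, _HandleError,
  #              shared_attrs=['op_failure_count'], fail_on_error=True)
  #
  # After the call, each attribute listed in shared_attrs has the deltas
  # accumulated by the workers added back onto self.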
1185
1186  def _MaybeSuggestGsutilDashM(self):
1187    """Outputs a sugestion to the user to use gsutil -m."""
1188    if not (boto.config.getint('GSUtil', 'parallel_process_count', 0) == 1 and
1189            boto.config.getint('GSUtil', 'parallel_thread_count', 0) == 1):
1190      self.logger.info('\n' + textwrap.fill(
1191          '==> NOTE: You are performing a sequence of gsutil operations that '
1192          'may run significantly faster if you instead use gsutil -m %s ...\n'
1193          'Please see the -m section under "gsutil help options" for further '
1194          'information about when gsutil -m can be advantageous.'
1195          % sys.argv[1]) + '\n')
1196
1197  # pylint: disable=g-doc-args
1198  def _SequentialApply(self, func, args_iterator, exception_handler, caller_id,
1199                       arg_checker, should_return_results, fail_on_error):
1200    """Performs all function calls sequentially in the current thread.
1201
1202    No other threads or processes will be spawned. This degraded functionality
1203    is used when the multiprocessing module is not available or the user
1204    requests only one thread and one process.
1205    """
1206    # Create a WorkerThread to handle all of the logic needed to actually call
1207    # the function. Note that this thread will never be started, and all work
1208    # is done in the current thread.
1209    worker_thread = WorkerThread(None, False)
1210    args_iterator = iter(args_iterator)
1211    # Count of sequential calls that have been made. Used for producing
1212    # suggestion to use gsutil -m.
1213    sequential_call_count = 0
1214    while True:
1215
1216      # Try to get the next argument, handling any exceptions that arise.
1217      try:
1218        args = args_iterator.next()
1219      except StopIteration, e:
1220        break
1221      except Exception, e:  # pylint: disable=broad-except
1222        _IncrementFailureCount()
1223        if fail_on_error:
1224          raise
1225        else:
1226          try:
1227            exception_handler(self, e)
1228          except Exception, _:  # pylint: disable=broad-except
1229            self.logger.debug(
1230                'Caught exception while handling exception for %s:\n%s',
1231                func, traceback.format_exc())
1232          continue
1233
1234      sequential_call_count += 1
1235      if sequential_call_count == OFFER_GSUTIL_M_SUGGESTION_THRESHOLD:
1236        # Output suggestion near beginning of run, so the user sees it early and
1237        # can ^C and try gsutil -m.
1238        self._MaybeSuggestGsutilDashM()
1239      if arg_checker(self, args):
1240        # Now that we actually have the next argument, perform the task.
1241        task = Task(func, args, caller_id, exception_handler,
1242                    should_return_results, arg_checker, fail_on_error)
1243        worker_thread.PerformTask(task, self)
1244    if sequential_call_count >= gslib.util.GetTermLines():
1245      # Output suggestion at end of long run, in case user missed it at the
1246      # start and it scrolled off-screen.
1247      self._MaybeSuggestGsutilDashM()
1248
1249    # If the final iterated argument results in an exception, and that
1250    # exception modifies shared_attrs, we need to publish the results.
1251    worker_thread.shared_vars_updater.Update(caller_id, self)
1252
1253  # pylint: disable=g-doc-args
1254  def _ParallelApply(self, func, args_iterator, exception_handler, caller_id,
1255                     arg_checker, process_count, thread_count,
1256                     should_return_results, fail_on_error):
1257    """Dispatches input arguments across a thread/process pool.
1258
1259    Pools are composed of parallel OS processes and/or Python threads,
1260    based on options (-m or not) and settings in the user's config file.
1261
1262    If only one OS process is requested/available, dispatch requests across
1263    threads in the current OS process.
1264
1265    In the multi-process case, we will create one pool of worker processes for
1266    each level of the tree of recursive calls to Apply. E.g., if A calls
1267    Apply(B), and B ultimately calls Apply(C) followed by Apply(D), then we
1268    will only create two sets of worker processes - B will execute in the first,
1269    and C and D will execute in the second. If C is then changed to call
1270    Apply(E) and D is changed to call Apply(F), then we will automatically
1271    create a third set of processes (lazily, when needed) that will be used to
1272    execute calls to E and F. This might look something like:
1273
1274    Pool1 Executes:                B
1275                                  / \
1276    Pool2 Executes:              C   D
1277                                /     \
1278    Pool3 Executes:            E       F
1279
1280    Apply's parallelism is generally broken up into 4 cases:
1281    - If process_count == thread_count == 1, then all tasks will be executed
1282      by _SequentialApply.
1283    - If process_count > 1 and thread_count == 1, then the main thread will
1284      create a new pool of processes (if they don't already exist) and each of
1285      those processes will execute the tasks in a single thread.
1286    - If process_count == 1 and thread_count > 1, then this process will create
1287      a new pool of threads to execute the tasks.
1288    - If process_count > 1 and thread_count > 1, then the main thread will
1289      create a new pool of processes (if they don't already exist) and each of
1290      those processes will, upon creation, create a pool of threads to
1291      execute the tasks.
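
    As an illustrative example (option names from the boto config, values
    hypothetical): running with -m and parallel_process_count=8,
    parallel_thread_count=5 falls into the last case, giving up to 8 worker
    processes that each run 5 WorkerThreads.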
1292
1293    Args:
1294      caller_id: The caller ID unique to this call to command.Apply.
1295      See command.Apply for description of other arguments.
1296    """
1297    is_main_thread = self.recursive_apply_level == 0
1298
1299    # Catch SIGINT and SIGTERM under Linux/MacOS so we can do cleanup before
1300    # exiting.
1301    if not IS_WINDOWS and is_main_thread:
1302      # Register as a final signal handler because this handler kills the
1303      # main gsutil process (so it must run last).
1304      RegisterSignalHandler(signal.SIGINT, self._HandleMultiProcessingSigs,
1305                            is_final_handler=True)
1306      RegisterSignalHandler(signal.SIGTERM, self._HandleMultiProcessingSigs,
1307                            is_final_handler=True)
1308
1309    if not task_queues:
1310      # The process we create will need to access the next recursive level
1311      # of task queues if it makes a call to Apply, so we always keep around
1312      # one more queue than we know we need. OTOH, if we don't create a new
1313      # process, the existing process still needs a task queue to use.
1314      if process_count > 1:
1315        task_queues.append(_NewMultiprocessingQueue())
1316      else:
1317        task_queues.append(_NewThreadsafeQueue())
1318
1319    if process_count > 1:  # Handle process pool creation.
1320      # Check whether this call will need a new set of workers.
1321
1322      # Each worker must acquire a shared lock before notifying the main thread
1323      # that it needs a new worker pool, so that at most one worker asks for
1324      # a new worker pool at once.
1325      try:
1326        if not is_main_thread:
1327          worker_checking_level_lock.acquire()
1328        if self.recursive_apply_level >= current_max_recursive_level.value:
1329          with need_pool_or_done_cond:
1330            # Only the main thread is allowed to create new processes -
1331            # otherwise, we will run into some Python bugs.
1332            if is_main_thread:
1333              self._CreateNewConsumerPool(process_count, thread_count)
1334            else:
1335              # Notify the main thread that we need a new consumer pool.
1336              new_pool_needed.value = 1
1337              need_pool_or_done_cond.notify_all()
1338              # The main thread will notify us when it finishes.
1339              need_pool_or_done_cond.wait()
1340      finally:
1341        if not is_main_thread:
1342          worker_checking_level_lock.release()
1343
1344    # If we're running in this process, create a separate task queue. Otherwise,
1345    # if Apply has already been called with process_count > 1, then there will
1346    # be consumer pools trying to use our processes.
1347    if process_count > 1:
1348      task_queue = task_queues[self.recursive_apply_level]
1349    elif self.multiprocessing_is_available:
1350      task_queue = _NewMultiprocessingQueue()
1351    else:
1352      task_queue = _NewThreadsafeQueue()
1353
1354    # Kick off a producer thread to throw tasks in the global task queue. We
1355    # do this asynchronously so that the main thread can be free to create new
1356    # consumer pools when needed (otherwise, any thread with a task that needs
1357    # a new consumer pool must block until we're completely done producing; in
1358    # the worst case, every worker blocks on such a call and the producer fills
1359    # up the task queue before it finishes, so we block forever).
1360    producer_thread = ProducerThread(copy.copy(self), args_iterator, caller_id,
1361                                     func, task_queue, should_return_results,
1362                                     exception_handler, arg_checker,
1363                                     fail_on_error)
1364
1365    if process_count > 1:
1366      # Wait here until either:
1367      #   1. We're the main thread and someone needs a new consumer pool - in
1368      #      which case we create one and continue waiting.
1369      #   2. Someone notifies us that all of the work we requested is done, in
1370      #      which case we retrieve the results (if applicable) and stop
1371      #      waiting.
1372      while True:
1373        with need_pool_or_done_cond:
1374          # Either our call is done, someone needs a new level of consumer
1375          # pools, or the wakeup call was meant for someone else. The first
1376          # two conditions can't both be true, since the main thread is
1377          # blocked on any other ongoing calls to Apply, and a thread would not
1378          # ask for a new consumer pool unless it had more work to do.
1379          if call_completed_map[caller_id]:
1380            break
1381          elif is_main_thread and new_pool_needed.value:
1382            new_pool_needed.value = 0
1383            self._CreateNewConsumerPool(process_count, thread_count)
1384            need_pool_or_done_cond.notify_all()
1385
1386          # Note that we must check the above conditions before the wait() call;
1387          # otherwise, the notification can happen before we start waiting, in
1388          # which case we'll block forever.
1389          need_pool_or_done_cond.wait()
1390    else:  # Using a single process.
1391      self._ApplyThreads(thread_count, process_count,
1392                         self.recursive_apply_level,
1393                         is_blocking_call=True, task_queue=task_queue)
1394
1395    # We encountered an exception from the producer thread before any arguments
1396    # were enqueued, but it wouldn't have been propagated, so we'll now
1397    # explicitly raise it here.
1398    if producer_thread.unknown_exception:
1399      # pylint: disable=raising-bad-type
1400      raise producer_thread.unknown_exception
1401
1402    # We encountered an exception from the producer thread while iterating over
1403    # the arguments, so raise it here if we're meant to fail on error.
1404    if producer_thread.iterator_exception and fail_on_error:
1405      # pylint: disable=raising-bad-type
1406      raise producer_thread.iterator_exception
1407
1408  def _ApplyThreads(self, thread_count, process_count, recursive_apply_level,
1409                    is_blocking_call=False, task_queue=None):
1410    """Assigns the work from the multi-process global task queue.
1411
1412    Work is assigned to an individual process for later consumption either by
1413    the WorkerThreads or (if thread_count == 1) this thread.
1414
1415    Args:
1416      thread_count: The number of threads used to perform the work. If 1, then
1417                    perform all work in this thread.
1418      process_count: The number of processes used to perform the work.
1419      recursive_apply_level: The depth in the tree of recursive calls to Apply
1420                             of this thread.
1421      is_blocking_call: True iff the call to Apply is blocked on this call
1422                        (which is true iff process_count == 1), implying that
1423                        _ApplyThreads must behave as a blocking call.
1424    """
1425    self._ResetConnectionPool()
1426    self.recursive_apply_level = recursive_apply_level
1427
1428    task_queue = task_queue or task_queues[recursive_apply_level]
1429
1430    # Ensure fairness across processes by filling our WorkerPool only with
1431    # as many tasks as it has WorkerThreads. This semaphore is acquired each
1432    # time that a task is retrieved from the queue and released each time
1433    # a task is completed by a WorkerThread.
1434    worker_semaphore = threading.BoundedSemaphore(thread_count)
1435
1436    assert thread_count * process_count > 1, (
1437        'Invalid state, calling command._ApplyThreads with only one thread '
1438        'and process.')
1439    # TODO: Presently, this pool gets recreated with each call to Apply. We
1440    # should be able to do it just once, at process creation time.
1441    worker_pool = WorkerPool(
1442        thread_count, self.logger, worker_semaphore,
1443        bucket_storage_uri_class=self.bucket_storage_uri_class,
1444        gsutil_api_map=self.gsutil_api_map, debug=self.debug)
1445
1446    num_enqueued = 0
1447    while True:
1448      worker_semaphore.acquire()
1449      task = task_queue.get()
1450      if task.args != ZERO_TASKS_TO_DO_ARGUMENT:
1451        # If we have no tasks to do and we're performing a blocking call, we
1452        # need a special signal to tell us to stop - otherwise, we block on
1453        # the call to task_queue.get() forever.
1454        worker_pool.AddTask(task)
1455        num_enqueued += 1
1456      else:
1457        # No tasks remain; don't block the semaphore on WorkerThread completion.
1458        worker_semaphore.release()
1459
1460      if is_blocking_call:
1461        num_to_do = total_tasks[task.caller_id]
1462        # The producer thread won't enqueue the last task until after it has
1463        # updated total_tasks[caller_id], so we know that num_to_do < 0 implies
1464        # we will do this check again.
1465        if num_to_do >= 0 and num_enqueued == num_to_do:
1466          if thread_count == 1:
1467            return
1468          else:
1469            while True:
1470              with need_pool_or_done_cond:
1471                if call_completed_map[task.caller_id]:
1472                  # We need to check this first, in case the condition was
1473                  # notified before we grabbed the lock.
1474                  return
1475                need_pool_or_done_cond.wait()
1476
1477
1478# Below here lie classes and functions related to controlling the flow of tasks
1479# between various threads and processes.
1480
1481
1482class _ConsumerPool(object):
1483
1484  def __init__(self, processes, task_queue):
1485    self.processes = processes
1486    self.task_queue = task_queue
1487
1488  def ShutDown(self):
1489    for process in self.processes:
1490      KillProcess(process.pid)
1491
1492
1493def KillProcess(pid):
1494  """Make best effort to kill the given process.
1495
1496  We ignore all exceptions so a caller looping through a list of processes will
1497  continue attempting to kill each, even if one encounters a problem.
1498
1499  Args:
1500    pid: The process ID.
1501  """
1502  try:
1503    # os.kill isn't available on Windows before Python 2.7 or 3.2.
1504    if IS_WINDOWS and ((2, 6) <= sys.version_info[:3] < (2, 7) or
1505                       (3, 0) <= sys.version_info[:3] < (3, 2)):
1506      kernel32 = ctypes.windll.kernel32
1507      handle = kernel32.OpenProcess(1, 0, pid)
1508      kernel32.TerminateProcess(handle, 0)
1509    else:
1510      os.kill(pid, signal.SIGKILL)
1511  except:  # pylint: disable=bare-except
1512    pass
1513
1514
1515class Task(namedtuple('Task', (
1516    'func args caller_id exception_handler should_return_results arg_checker '
1517    'fail_on_error'))):
1518  """Task class representing work to be completed.
1519
1520  Args:
1521    func: The function to be executed.
1522    args: The arguments to func.
1523    caller_id: The globally-unique caller ID corresponding to the Apply call.
1524    exception_handler: The exception handler to use if the call to func fails.
1525    should_return_results: True iff the results of this function should be
1526                           returned from the Apply call.
1527    arg_checker: Used to determine whether we should process the current
1528                 argument or simply skip it. Also handles any logging that
1529                 is specific to a particular type of argument.
1530    fail_on_error: If true, then raise any exceptions encountered when
1531                   executing func. This is only applicable in the case of
1532                   process_count == thread_count == 1.
1533  """
1534  pass
1535
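# A sketch of how tasks are built (argument values are hypothetical; see
# _SequentialApply and ProducerThread.run for the real call sites):
#
#   Task(func, args, caller_id, exception_handler, should_return_results,
#        arg_checker, fail_on_error)
#
# When the argument iterator yields nothing, ProducerThread enqueues a sentinel
# Task whose args is ZERO_TASKS_TO_DO_ARGUMENT so that blocking consumers can
# stop waiting on the queue.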
1536
1537class ProducerThread(threading.Thread):
1538  """Thread used to enqueue work for other processes and threads."""
1539
1540  def __init__(self, cls, args_iterator, caller_id, func, task_queue,
1541               should_return_results, exception_handler, arg_checker,
1542               fail_on_error):
1543    """Initializes the producer thread.
1544
1545    Args:
1546      cls: Instance of Command for which this ProducerThread was created.
1547      args_iterator: Iterable collection of arguments to be put into the
1548                     work queue.
1549      caller_id: Globally-unique caller ID corresponding to this call to Apply.
1550      func: The function to be called on each element of args_iterator.
1551      task_queue: The queue into which tasks will be put, to later be consumed
1552                  by Command._ApplyThreads.
1553      should_return_results: True iff the results for this call to command.Apply
1554                             were requested.
1555      exception_handler: The exception handler to use when errors are
1556                         encountered during calls to func.
1557      arg_checker: Used to determine whether we should process the current
1558                   argument or simply skip it. Also handles any logging that
1559                   is specific to a particular type of argument.
1560      fail_on_error: If true, then raise any exceptions encountered when
1561                     executing func. This is only applicable in the case of
1562                     process_count == thread_count == 1.
1563    """
1564    super(ProducerThread, self).__init__()
1565    self.func = func
1566    self.cls = cls
1567    self.args_iterator = args_iterator
1568    self.caller_id = caller_id
1569    self.task_queue = task_queue
1570    self.arg_checker = arg_checker
1571    self.exception_handler = exception_handler
1572    self.should_return_results = should_return_results
1573    self.fail_on_error = fail_on_error
1574    self.shared_variables_updater = _SharedVariablesUpdater()
1575    self.daemon = True
1576    self.unknown_exception = None
1577    self.iterator_exception = None
1578    self.start()
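    # The producer starts itself here; _ParallelApply never joins this thread
    # directly and instead waits on need_pool_or_done_cond and the per-caller
    # task counters for completion.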
1579
1580  def run(self):
1581    num_tasks = 0
1582    cur_task = None
1583    last_task = None
1584    try:
1585      args_iterator = iter(self.args_iterator)
1586      while True:
1587        try:
1588          args = args_iterator.next()
1589        except StopIteration, e:
1590          break
1591        except Exception, e:  # pylint: disable=broad-except
1592          _IncrementFailureCount()
1593          if self.fail_on_error:
1594            self.iterator_exception = e
1595            raise
1596          else:
1597            try:
1598              self.exception_handler(self.cls, e)
1599            except Exception, _:  # pylint: disable=broad-except
1600              self.cls.logger.debug(
1601                  'Caught exception while handling exception for %s:\n%s',
1602                  self.func, traceback.format_exc())
1603            self.shared_variables_updater.Update(self.caller_id, self.cls)
1604            continue
1605
1606        if self.arg_checker(self.cls, args):
1607          num_tasks += 1
1608          last_task = cur_task
1609          cur_task = Task(self.func, args, self.caller_id,
1610                          self.exception_handler, self.should_return_results,
1611                          self.arg_checker, self.fail_on_error)
1612          if last_task:
1613            self.task_queue.put(last_task)
1614    except Exception, e:  # pylint: disable=broad-except
1615      # This will also catch any exception raised due to an error in the
1616      # iterator when fail_on_error is set, so check that we failed for some
1617      # other reason before claiming that we had an unknown exception.
1618      if not self.iterator_exception:
1619        self.unknown_exception = e
1620    finally:
1621      # We need to make sure to update total_tasks[caller_id] before we enqueue
1622      # the last task. Otherwise, a worker could retrieve and complete the last
1623      # task, check total_tasks before we've updated it, and conclude that we
1624      # aren't done producing - then no one would ever notify the waiting Apply
1625      # call. Holding back the last task until total_tasks is updated avoids this.
1626      total_tasks[self.caller_id] = num_tasks
1627      if not cur_task:
1628        # This happens if there were zero arguments to be put in the queue.
1629        cur_task = Task(None, ZERO_TASKS_TO_DO_ARGUMENT, self.caller_id,
1630                        None, None, None, None)
1631      self.task_queue.put(cur_task)
1632
1633      # It's possible that the workers finished before we updated total_tasks,
1634      # so we need to check here as well.
1635      _NotifyIfDone(self.caller_id,
1636                    caller_id_finished_count.get(self.caller_id))
1637
1638
1639class WorkerPool(object):
1640  """Pool of worker threads to which tasks can be added."""
1641
1642  def __init__(self, thread_count, logger, worker_semaphore,
1643               bucket_storage_uri_class=None, gsutil_api_map=None, debug=0):
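    # Spawns thread_count WorkerThreads that all consume from one thread-safe
    # queue; worker_semaphore (supplied by _ApplyThreads) limits how many tasks
    # are outstanding in this pool at any one time.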
1644    self.task_queue = _NewThreadsafeQueue()
1645    self.threads = []
1646    for _ in range(thread_count):
1647      worker_thread = WorkerThread(
1648          self.task_queue, logger, worker_semaphore=worker_semaphore,
1649          bucket_storage_uri_class=bucket_storage_uri_class,
1650          gsutil_api_map=gsutil_api_map, debug=debug)
1651      self.threads.append(worker_thread)
1652      worker_thread.start()
1653
1654  def AddTask(self, task):
1655    self.task_queue.put(task)
1656
1657
1658class WorkerThread(threading.Thread):
1659  """Thread where all the work will be performed.
1660
1661  This makes the function calls for Apply and takes care of all error handling,
1662  return value propagation, and shared_vars.
1663
1664  Note that this thread is NOT started upon instantiation because the function-
1665  calling logic is also used in the single-threaded case.
1666  """
1667
1668  def __init__(self, task_queue, logger, worker_semaphore=None,
1669               bucket_storage_uri_class=None, gsutil_api_map=None, debug=0):
1670    """Initializes the worker thread.
1671
1672    Args:
1673      task_queue: The thread-safe queue from which this thread should obtain
1674                  its work.
1675      logger: Logger to use for this thread.
1676      worker_semaphore: threading.BoundedSemaphore to be released each time a
1677          task is completed, or None for single-threaded execution.
1678      bucket_storage_uri_class: Class to instantiate for cloud StorageUris.
1679                                Settable for testing/mocking.
1680      gsutil_api_map: Map of providers and API selector tuples to api classes
1681                      which can be used to communicate with those providers.
1682                      Used for instantiating the CloudApiDelegator class.
1683      debug: debug level for the CloudApiDelegator class.
1684    """
1685    super(WorkerThread, self).__init__()
1686    self.task_queue = task_queue
1687    self.worker_semaphore = worker_semaphore
1688    self.daemon = True
1689    self.cached_classes = {}
1690    self.shared_vars_updater = _SharedVariablesUpdater()
1691
1692    self.thread_gsutil_api = None
1693    if bucket_storage_uri_class and gsutil_api_map:
1694      self.thread_gsutil_api = CloudApiDelegator(
1695          bucket_storage_uri_class, gsutil_api_map, logger, debug=debug)
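    # In the sequential case (_SequentialApply's WorkerThread(None, False)), no
    # API map is supplied, so thread_gsutil_api stays None and task functions
    # are invoked with thread_state=None.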
1696
1697  def PerformTask(self, task, cls):
1698    """Makes the function call for a task.
1699
1700    Args:
1701      task: The Task to perform.
1702      cls: The instance of a class which gives context to the functions called
1703           by the Task's function. E.g., see SetAclFuncWrapper.
1704    """
1705    caller_id = task.caller_id
1706    try:
1707      results = task.func(cls, task.args, thread_state=self.thread_gsutil_api)
1708      if task.should_return_results:
1709        global_return_values_map.Increment(caller_id, [results],
1710                                           default_value=[])
1711    except Exception, e:  # pylint: disable=broad-except
1712      _IncrementFailureCount()
1713      if task.fail_on_error:
1714        raise  # Only happens for single thread and process case.
1715      else:
1716        try:
1717          task.exception_handler(cls, e)
1718        except Exception, _:  # pylint: disable=broad-except
1719          # Don't allow callers to raise exceptions here and kill the worker
1720          # threads.
1721          cls.logger.debug(
1722              'Caught exception while handling exception for %s:\n%s',
1723              task, traceback.format_exc())
1724    finally:
1725      if self.worker_semaphore:
1726        self.worker_semaphore.release()
1727      self.shared_vars_updater.Update(caller_id, cls)
1728
1729      # Even if we encounter an exception, we still need to claim that the
1730      # function finished executing. Otherwise, we won't know when to
1731      # stop waiting and return results.
1732      num_done = caller_id_finished_count.Increment(caller_id, 1)
1733      _NotifyIfDone(caller_id, num_done)
1734
1735  def run(self):
1736    while True:
1737      task = self.task_queue.get()
1738      caller_id = task.caller_id
1739
1740      # Get the instance of the command with the appropriate context.
1741      cls = self.cached_classes.get(caller_id, None)
1742      if not cls:
1743        cls = copy.copy(class_map[caller_id])
1744        cls.logger = CreateGsutilLogger(cls.command_name)
1745        self.cached_classes[caller_id] = cls
1746
1747      self.PerformTask(task, cls)
1748
1749
1750class _SharedVariablesUpdater(object):
1751  """Used to update shared variable for a class in the global map.
1752
1753     Note that each thread will have its own instance of the calling class for
1754     context, and it will also have its own instance of a
1755     _SharedVariablesUpdater.  This is used in the following way:
1756
1757     1. Before any tasks are performed, each thread will get a copy of the
1758        calling class, and the globally-consistent value of this shared variable
1759        will be initialized to whatever it was before the call to Apply began.
1760
1761     2. After each time a thread performs a task, it will look at the current
1762        values of the shared variables in its instance of the calling class.
1763
1764        2.A. For each such variable, it computes the delta of this variable
1765             between the last known value for this class (which is stored in
1766             a dict local to this class) and the current value of the variable
1767             in the class.
1768
1769        2.B. Using this delta, we update the last known value locally as well
1770             as the globally-consistent value shared across all classes (the
1771             globally consistent value is simply increased by the computed
1772             delta).
1773  """
1774
1775  def __init__(self):
1776    self.last_shared_var_values = {}
1777
1778  def Update(self, caller_id, cls):
1779    """Update any shared variables with their deltas."""
1780    shared_vars = shared_vars_list_map.get(caller_id, None)
1781    if shared_vars:
1782      for name in shared_vars:
1783        key = (caller_id, name)
1784        last_value = self.last_shared_var_values.get(key, 0)
1785        # Compute the change made since the last time we updated here. This is
1786        # calculated by simply subtracting the last known value from the current
1787        # value in the class instance.
1788        delta = getattr(cls, name) - last_value
1789        self.last_shared_var_values[key] = delta + last_value
1790
1791        # Update the globally-consistent value by simply increasing it by the
1792        # computed delta.
1793        shared_vars_map.Increment(key, delta)
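    # Worked example (hypothetical shared variable name): if three threads each
    # advance cls.op_failure_count by 2 between Update calls, each computes a
    # delta of 2 against its own last-known value, so the globally-consistent
    # value in shared_vars_map grows by 6 overall.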
1794
1795
1796def _NotifyIfDone(caller_id, num_done):
1797  """Notify any threads waiting for results that something has finished.
1798
1799  Each waiting thread will then need to check the call_completed_map to see if
1800  its work is done.
1801
1802  Note that num_done could be calculated here, but it is passed in as an
1803  optimization so that we have one less call to a globally-locked data
1804  structure.
1805
1806  Args:
1807    caller_id: The caller_id of the function whose progress we're checking.
1808    num_done: The number of tasks currently completed for that caller_id.
1809  """
1810  num_to_do = total_tasks[caller_id]
1811  if num_to_do == num_done and num_to_do >= 0:
1812    # Notify the Apply call that's sleeping that it's ready to return.
1813    with need_pool_or_done_cond:
1814      call_completed_map[caller_id] = True
1815      need_pool_or_done_cond.notify_all()
1816
1817
1818def ShutDownGsutil():
1819  """Shut down all processes in consumer pools in preparation for exiting."""
1820  for q in queues:
1821    try:
1822      q.cancel_join_thread()
1823    except:  # pylint: disable=bare-except
1824      pass
1825  for consumer_pool in consumer_pools:
1826    consumer_pool.ShutDown()
1827
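# Presumably invoked as part of gsutil's shutdown/signal handling (e.g. via the
# handlers registered in _ParallelApply) so that consumer-pool processes don't
# outlive the main process; this describes intended usage rather than behavior
# defined in this file.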
1828
1829# pylint: disable=global-variable-undefined
1830def _IncrementFailureCount():
1831  global failure_count
1832  if isinstance(failure_count, int):
1833    failure_count += 1
1834  else:  # Otherwise it's a multiprocessing.Value() of type 'i'.
1835    failure_count.value += 1
1836
1837
1838# pylint: disable=global-variable-undefined
1839def GetFailureCount():
1840  """Returns the number of failures processed during calls to Apply()."""
1841  try:
1842    if isinstance(failure_count, int):
1843      return failure_count
1844    else:  # It's a multiprocessing.Value() of type 'i'.
1845      return failure_count.value
1846  except NameError:  # If it wasn't initialized, Apply() wasn't called.
1847    return 0
1848
1849
1850def ResetFailureCount():
1851  """Resets the failure_count variable to 0 - useful if error is expected."""
1852  try:
1853    global failure_count
1854    if isinstance(failure_count, int):
1855      failure_count = 0
1856    else:  # It's a multiprocessing.Value() of type 'i'.
1857      failure_count = multiprocessing.Value('i', 0)
1858  except NameError:  # If it wasn't initialized, Apply() wasn't called.
1859    pass
1860
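# Usage sketch (hypothetical caller; not part of this module): after an
# Apply-based run, a command can turn the tally into an error, e.g.
#
#   if GetFailureCount() > 0:
#     raise CommandException('%d file(s)/object(s) could not be processed.'
#                            % GetFailureCount())
#
# and code that expects failures can call ResetFailureCount() beforehand.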