• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Decorator and context manager for saving and restoring flag values.
16
17There are many ways to save and restore.  Always use the most convenient method
18for a given use case.
19
20Here are examples of each method.  They all call ``do_stuff()`` while
21``FLAGS.someflag`` is temporarily set to ``'foo'``::
22
23    from absl.testing import flagsaver
24
25    # Use a decorator which can optionally override flags via arguments.
26    @flagsaver.flagsaver(someflag='foo')
27    def some_func():
28      do_stuff()
29
30    # Use a decorator which can optionally override flags with flagholders.
31    @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
32    def some_func():
33      do_stuff()
34
35    # Use a decorator which does not override flags itself.
36    @flagsaver.flagsaver
37    def some_func():
38      FLAGS.someflag = 'foo'
39      do_stuff()
40
41    # Use a context manager which can optionally override flags via arguments.
42    with flagsaver.flagsaver(someflag='foo'):
43      do_stuff()
44
45    # Save and restore the flag values yourself.
46    saved_flag_values = flagsaver.save_flag_values()
47    try:
48      FLAGS.someflag = 'foo'
49      do_stuff()
50    finally:
51      flagsaver.restore_flag_values(saved_flag_values)
52
53We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
54This preserves all attributes of the flag, such as whether or not it was
55overridden from its default value.
56
WARNING: Currently a flag that is saved and then deleted cannot be restored;
``restore_flag_values()`` only visits flags that still exist, so the deleted
flag simply stays deleted.  However if you *add* a flag after saving flag
values, and then restore flag values, the added flag will be deleted with no
errors.
60"""
61
62import functools
63import inspect
64
65from absl import flags
66
# The FlagValues instance exported as `flags.FLAGS`; it is the default
# `flag_values` argument for saving/restoring and the instance that
# _FlagOverrider saves, overrides and restores.
FLAGS = flags.FLAGS
68
69
def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Supports three forms:

  * ``@flagsaver`` -- a bare decorator; ``args`` is just the function.
  * ``@flagsaver(name=value, ...)`` / ``with flagsaver(name=value, ...):``.
  * ``@flagsaver((holder, value), ...)`` -- ``(FlagHolder, value)`` pairs,
    optionally mixed with keyword overrides.

  Args:
    *args: Either a single callable to decorate, or any number of
        ``(FlagHolder, value)`` pairs naming flags to override.
    **kwargs: Flag names mapped to override values. Must not be combined with
        a callable positional argument.

  Returns:
    The wrapped function for the bare-decorator form; otherwise an object
    usable both as a decorator and as a context manager.

  Raises:
    ValueError: if a callable is combined with keyword overrides, if a
        positional argument is not a ``(FlagHolder, value)`` pair, or if the
        same flag is set more than once.
    TypeError: if the bare-decorator form is applied to a class.
  """
  # Bare-decorator form (`@flagsaver`, not `@flagsaver(...)`): the single
  # positional argument is the function being decorated.
  if len(args) == 1 and callable(args[0]):
    (func,) = args
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    if inspect.isclass(func):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(func, {})

  # Otherwise each positional argument is a (FlagHolder, value) pair that is
  # merged into the keyword overrides. With no positional arguments the loop
  # does nothing and only the keyword overrides are used.
  for pair in args:
    if not (isinstance(pair, tuple) and len(pair) == 2):
      raise ValueError(
          'Expected (FlagHolder, value) pair, found %r' % (pair,))
    holder, value = pair
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError(
          'Expected (FlagHolder, value) pair, found %r' % (pair,))
    flag_name = holder.name
    if flag_name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % flag_name)
    kwargs[flag_name] = value
  return _FlagOverrider(**kwargs)
95
96
def save_flag_values(flag_values=FLAGS):
  """Returns a snapshot of every flag in ``flag_values`` as a dict.

  Args:
    flag_values: FlagValues, the FlagValues instance with which the flag will
        be saved. This should almost never need to be overridden.

  Returns:
    Dictionary mapping flag names to copies of the corresponding flag
    objects' ``__dict__`` members, e.g. ``{'key': value_dict, ...}``. Pass it
    to ``restore_flag_values`` to roll the flags back.
  """
  snapshot = {}
  for flag_name in flag_values:
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
108
109
def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Flags present in ``flag_values`` but missing from ``saved_flag_values``
  (i.e. added after the snapshot was taken) are deleted. Flags deleted after
  the snapshot was taken are not recreated.

  Each flag receives a *copy* of its saved ``__dict__`` (including its own
  copy of the ``validators`` list). The snapshot is therefore never aliased
  to live flag state, and the same snapshot can be restored more than once
  (e.g. saved once and restored in every ``tearDown``).

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
        ``save_flag_values``.
    flag_values: FlagValues, the FlagValues instance from which the flag will
        be restored. This should almost never need to be overridden.
  """
  # Materialize the names first because flags may be deleted inside the loop.
  for name in list(flag_values):
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    # Install a copy rather than `saved` itself. Otherwise later writes to the
    # flag (and validators registered on it) would mutate the snapshot, and a
    # second restore from it would silently restore nothing.
    restored = dict(saved)
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
128
129
130def _wrap(func, overrides):
131  """Creates a wrapper function that saves/restores flag values.
132
133  Args:
134    func: function object - This will be called between saving flags and
135        restoring flags.
136    overrides: {str: object} - Flag names mapped to their values.  These flags
137        will be set after saving the original flag state.
138
139  Returns:
140    return value from func()
141  """
142  @functools.wraps(func)
143  def _flagsaver_wrapper(*args, **kwargs):
144    """Wrapper function that saves and restores flags."""
145    with _FlagOverrider(**overrides):
146      return func(*args, **kwargs)
147  return _flagsaver_wrapper
148
149
150class _FlagOverrider(object):
151  """Overrides flags for the duration of the decorated function call.
152
153  It also restores all original values of flags after decorated method
154  completes.
155  """
156
157  def __init__(self, **overrides):
158    self._overrides = overrides
159    self._saved_flag_values = None
160
161  def __call__(self, func):
162    if inspect.isclass(func):
163      raise TypeError('flagsaver cannot be applied to a class.')
164    return _wrap(func, self._overrides)
165
166  def __enter__(self):
167    self._saved_flag_values = save_flag_values(FLAGS)
168    try:
169      FLAGS._set_attributes(**self._overrides)
170    except:
171      # It may fail because of flag validators.
172      restore_flag_values(self._saved_flag_values, FLAGS)
173      raise
174
175  def __exit__(self, exc_type, exc_value, traceback):
176    restore_flag_values(self._saved_flag_values, FLAGS)
177
178
179def _copy_flag_dict(flag):
180  """Returns a copy of the flag object's ``__dict__``.
181
182  It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
183  copy of the validator list.
184
185  Args:
186    flag: flags.Flag, the flag to copy.
187
188  Returns:
189    A copy of the flag object's ``__dict__``.
190  """
191  copy = flag.__dict__.copy()
192  copy['_value'] = flag.value  # Ensure correct restore for C++ flags.
193  copy['validators'] = list(flag.validators)
194  return copy
195