# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decorator and context manager for saving and restoring flag values.

There are many ways to save and restore. Always use the most convenient method
for a given use case.

Here are examples of each method. They all call ``do_stuff()`` while
``FLAGS.someflag`` is temporarily set to ``'foo'``::

    from absl.testing import flagsaver

    # Use a decorator which can optionally override flags via arguments.
    @flagsaver.flagsaver(someflag='foo')
    def some_func():
      do_stuff()

    # Use a decorator which can optionally override flags with flagholders.
    @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
    def some_func():
      do_stuff()

    # Use a decorator which does not override flags itself.
    @flagsaver.flagsaver
    def some_func():
      FLAGS.someflag = 'foo'
      do_stuff()

    # Use a context manager which can optionally override flags via arguments.
    with flagsaver.flagsaver(someflag='foo'):
      do_stuff()

    # Save and restore the flag values yourself.
    saved_flag_values = flagsaver.save_flag_values()
    try:
      FLAGS.someflag = 'foo'
      do_stuff()
    finally:
      flagsaver.restore_flag_values(saved_flag_values)

We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value. Restoring never hands the saved dictionaries
themselves to the flags, so the same saved values can be restored repeatedly.

WARNING: Currently a flag that is saved and then deleted cannot be restored. An
exception will be raised. However if you *add* a flag after saving flag values,
and then restore flag values, the added flag will be deleted with no errors.
"""

import functools
import inspect

from absl import flags

FLAGS = flags.FLAGS


def flagsaver(*args, **kwargs):
  """The main flagsaver interface. See module doc for usage.

  Args:
    *args: Either a single callable (bare ``@flagsaver`` usage) or any number
      of ``(FlagHolder, value)`` pairs whose values override the named flags.
    **kwargs: Flag names mapped to override values.

  Returns:
    A wrapped function when applied directly to a callable, otherwise a
    ``_FlagOverrider`` usable as a decorator or context manager.

  Raises:
    ValueError: If a callable is combined with keyword overrides, if a
      positional argument is not a ``(FlagHolder, value)`` pair, or if the
      same flag is specified more than once.
    TypeError: If applied directly to a class.
  """
  if not args:
    return _FlagOverrider(**kwargs)
  # args can be [func] if used as `@flagsaver` instead of `@flagsaver(...)`
  if len(args) == 1 and callable(args[0]):
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    func = args[0]
    if inspect.isclass(func):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(func, {})
  # args can be a list of (FlagHolder, value) pairs.
  # In which case they augment any specified kwargs.
  for arg in args:
    if not isinstance(arg, tuple) or len(arg) != 2:
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    holder, value = arg
    if not isinstance(holder, flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (arg,))
    if holder.name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % holder.name)
    kwargs[holder.name] = value
  return _FlagOverrider(**kwargs)


def save_flag_values(flag_values=FLAGS):
  """Returns copy of flag values as a dict.

  Args:
    flag_values: FlagValues, the FlagValues instance with which the flag will
      be saved. This should almost never need to be overridden.
  Returns:
    Dictionary mapping keys to values. Keys are flag names, values are
    corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
  """
  return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}


def restore_flag_values(saved_flag_values, flag_values=FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Each saved dictionary is copied before it is installed on its flag. As a
  result, ``saved_flag_values`` is not modified by later flag changes. The
  same saved values can therefore be restored any number of times, e.g.
  saved once in ``setUpClass`` and restored in every ``tearDown``.

  Flags present in ``flag_values`` but missing from ``saved_flag_values``
  (i.e. flags added after saving) are deleted.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}, as returned by
      ``save_flag_values``.
    flag_values: FlagValues, the FlagValues instance from which the flag will
      be restored. This should almost never need to be overridden.
  """
  # Snapshot the names first: deleting flags while iterating the live
  # FlagValues would mutate the collection being iterated.
  new_flag_names = list(flag_values)
  for name in new_flag_names:
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    # Install a copy instead of ``saved`` itself. Otherwise the snapshot would
    # become the flag's live ``__dict__``, later flag assignments would
    # rewrite it, and a second restore from it would silently do nothing.
    # The validator list is copied as well, matching _copy_flag_dict.
    restored = saved.copy()
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored


def _wrap(func, overrides):
  """Creates a wrapper function that saves/restores flag values.

  Args:
    func: function object - This will be called between saving flags and
      restoring flags.
    overrides: {str: object} - Flag names mapped to their values. These flags
      will be set after saving the original flag state.

  Returns:
    A wrapper that returns the return value of func().
  """
  @functools.wraps(func)
  def _flagsaver_wrapper(*args, **kwargs):
    """Wrapper function that saves and restores flags."""
    # A fresh _FlagOverrider per call keeps each invocation's saved state
    # independent of any other invocation.
    with _FlagOverrider(**overrides):
      return func(*args, **kwargs)
  return _flagsaver_wrapper


class _FlagOverrider(object):
  """Overrides flags for the duration of the decorated function call.

  It also restores all original values of flags after decorated method
  completes.
  """

  def __init__(self, **overrides):
    self._overrides = overrides
    self._saved_flag_values = None

  def __call__(self, func):
    if inspect.isclass(func):
      raise TypeError('flagsaver cannot be applied to a class.')
    return _wrap(func, self._overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      FLAGS._set_attributes(**self._overrides)
    except BaseException:
      # It may fail because of flag validators. Undo any partially applied
      # overrides before propagating, since __exit__ will not run.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    restore_flag_values(self._saved_flag_values, FLAGS)


def _copy_flag_dict(flag):
  """Returns a copy of the flag object's ``__dict__``.

  It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
  copy of the validator list.

  Args:
    flag: flags.Flag, the flag to copy.

  Returns:
    A copy of the flag object's ``__dict__``.
  """
  copy = flag.__dict__.copy()
  copy['_value'] = flag.value  # Ensure correct restore for C++ flags.
  copy['validators'] = list(flag.validators)
  return copy