# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""

import collections

from absl import logging


def _internal_attr_name(name):
  """Returns `name` prefixed with a single underscore."""
  return "_" + name


class OptionsBase:
  """Base class for representing a set of tf.data options.

  Option values are kept in the `_options` dict rather than as regular
  instance attributes. Options are expected to be declared on subclasses as
  class-level attributes (for example properties built with `create_option`),
  because `__setattr__` only accepts names that already resolve on the object.

  Because `__eq__` is defined without `__hash__`, instances are unhashable.

  Attributes:
    _options: Stores the option values, keyed by option name.
    _mutable: Whether `__setattr__` is allowed to change option values.
  """

  def __init__(self):
    # NOTE: Cannot use `self._options` here as we override `__setattr__`.
    object.__setattr__(self, "_options", {})
    object.__setattr__(self, "_mutable", True)

  def __eq__(self, other):
    """Returns whether `other` has the same class and equal option values.

    Every option name stored on either side is read back with `getattr`, so an
    option stored on only one side is compared against whatever the other
    side's attribute lookup returns (e.g. a property default).
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
      if getattr(self, name) != getattr(other, name):
        return False
    return True

  def __ne__(self, other):
    """Negation of `__eq__`; `NotImplemented` for objects of other classes."""
    if isinstance(other, self.__class__):
      return not self.__eq__(other)
    else:
      return NotImplemented

  def __setattr__(self, name, value):
    """Sets an already-declared option.

    Raises:
      ValueError: If `_mutable` is False (see `_set_mutable`).
      AttributeError: If `name` does not resolve as an attribute of this
        object.
    """
    if not self._mutable:
      raise ValueError("Mutating `tf.data.Options()` returned by "
                       "`tf.data.Dataset.options()` has no effect. Use "
                       "`tf.data.Dataset.with_options(options)` to set or "
                       "update dataset options.")
    # `hasattr` performs a full attribute lookup (running any property
    # getter), so unknown names are rejected instead of creating new
    # attributes.
    if hasattr(self, name):
      # For a property, `object.__setattr__` dispatches to its setter, which
      # lets `create_option` properties apply their type check.
      object.__setattr__(self, name, value)
    else:
      raise AttributeError("Cannot set the property {} on {}.".format(
          name,
          type(self).__name__))

  def _set_mutable(self, mutable):
    """Change the mutability property to `mutable`."""
    object.__setattr__(self, "_mutable", mutable)

  def _to_proto(self):
    """Convert options to protocol buffer."""
    raise NotImplementedError("{}._to_proto()".format(type(self).__name__))

  def _from_proto(self, pb):
    """Convert protocol buffer to options."""
    raise NotImplementedError("{}._from_proto()".format(type(self).__name__))


def graph_rewrites():
  """Returns a `GraphRewrites` namedtuple class for graph rewrite settings.

  The class has three fields: `enabled`, `disabled` and `default`. Each call
  creates a new, distinct class object.
  """
  return collections.namedtuple("GraphRewrites",
                                ["enabled", "disabled", "default"])


def create_option(name, ty, docstring, default_factory=lambda: None):
  """Creates a type-checked property.

  The returned property stores its value in the owning object's `_options`
  dict under the key `name`. Reading an option that has not been stored yet
  calls `default_factory`, caches the result in `_options`, and returns it;
  this default is not checked against `ty`. Assigning a value that is not an
  instance of `ty` raises `TypeError`.

  Args:
    name: The name to use; also the key used in `_options`.
    ty: The type (or tuple of types, as accepted by `isinstance`) to use. The
      type of the property will be validated when it is set.
    docstring: The docstring to use.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set.

  Returns:
    A type-checked property.
  """

  def get_fn(option):
    # pylint: disable=protected-access
    if name not in option._options:
      option._options[name] = default_factory()
    return option._options.get(name)

  def set_fn(option, value):
    if not isinstance(value, ty):
      raise TypeError(
          "Property \"{}\" must be of type {}, got: {} (type: {})".format(
              name, ty, value, type(value)))
    option._options[name] = value  # pylint: disable=protected-access

  return property(get_fn, set_fn, None, docstring)


def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `tf.data.OptionsBase` (and thus each represent a set of options). The method
  outputs an object of the same type created by merging the sets of options
  represented by the input arguments.

  If an option is set to different values by different options objects, the
  result will match the setting of the options object that appears in the input
  list last. An option whose value equals the default of a freshly constructed
  object of the same type is treated as unset, so it never overrides a
  non-default value set by an earlier options object.

  If an option is an instance of `tf.data.OptionsBase` itself, then this method
  is applied recursively to the set of options represented by this option.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options are provided.
    TypeError: if the input arguments are incompatible or not derived from
      `tf.data.OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError(
          "Could not merge incompatible options of type {} and {}.".format(
              type(options), result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError(
        "All options to be merged should inherit from `OptionsBase` but found "
        "option of type {} which does not.".format(type(options_list[0])))

  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all set options and merge them into the result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        # A default-valued option counts as "not set" and never overrides.
        continue
      elif this == default:
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        # Both sides customized a nested options object: merge recursively.
        setattr(result, name, merge_options(this, that))
      elif this != that:
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        this, that)
        setattr(result, name, that)
  return result