# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl import logging


def _internal_attr_name(name):
  """Returns `name` prefixed with a single underscore."""
  return "_" + name


class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  The properties produced by `create_option` store their values in the
  `_options` dict of the object they are accessed on. Assignment is only
  allowed for attributes that already exist on the object, and is rejected
  entirely once the object has been marked immutable via `_set_mutable`.

  Attributes:
    _options: Stores the option values, keyed by option name.
    _mutable: Whether attribute assignment is currently allowed.
  """

  def __init__(self):
    # NOTE: Cannot use `self._options` here as we override `__setattr__`: the
    # override reads `self._mutable` (not set yet) and only permits assigning
    # attributes that already exist.
    object.__setattr__(self, "_options", {})
    object.__setattr__(self, "_mutable", True)

  def __eq__(self, other):
    """Returns True if every option set on either object has an equal value.

    Values are read through the option properties, so an option that is set on
    only one side is compared against the value the other side's property
    returns (its default when built with `create_option`).
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
      if getattr(self, name) != getattr(other, name):
        return False
    return True

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`, so define it explicitly.
    if isinstance(other, self.__class__):
      return not self.__eq__(other)
    else:
      return NotImplemented

  def __setattr__(self, name, value):
    """Assigns an existing attribute (typically an option property).

    Raises:
      ValueError: If this object has been marked immutable.
      AttributeError: If `name` is not an existing attribute of this object.
    """
    if not self._mutable:
      raise ValueError("Mutating `tf.data.Options()` returned by "
                       "`tf.data.Dataset.options()` has no effect. Use "
                       "`tf.data.Dataset.with_options(options)` to set or "
                       "update dataset options.")
    # `hasattr` evaluates the property getter, which may record a default
    # value in `_options` as a side effect.
    if hasattr(self, name):
      object.__setattr__(self, name, value)
    else:
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))

  def _set_mutable(self, mutable):
    """Change the mutability property to `mutable`.

    Only this object is affected; nested options objects keep their own flag.
    """
    object.__setattr__(self, "_mutable", mutable)

  def _to_proto(self):
    """Convert options to protocol buffer. Must be overridden by subclasses."""
    raise NotImplementedError("%s._to_proto()" % type(self).__name__)

  def _from_proto(self, pb):
    """Convert protocol buffer to options. Must be overridden by subclasses."""
    raise NotImplementedError("%s._from_proto()" % type(self).__name__)


def graph_rewrites():
  """Returns a namedtuple class for optimization graph rewrite settings.

  The returned value is the class `GraphRewrites` (not an instance) with the
  fields `enabled`, `disabled` and `default`. A new class object is created on
  every call.
  """
  return collections.namedtuple("GraphRewrites",
                                ["enabled", "disabled", "default"])


def create_option(name, ty, docstring, default_factory=lambda: None):
  """Creates a type-checked property.

  The property stores its value in the owning object's `_options` dict under
  the key `name`. Reading an option that has not been set yet calls
  `default_factory` and records the result in `_options`, so subsequent reads
  return that same default object.

  Args:
    name: The name to use.
    ty: The type to use (anything accepted as the second argument of
      `isinstance`). The type of the property will be validated when it is
      set.
    docstring: The docstring to use.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set.

  Returns:
    A type-checked property.
  """

  def get_fn(option):
    # pylint: disable=protected-access
    if name not in option._options:
      option._options[name] = default_factory()
    return option._options.get(name)

  def set_fn(option, value):
    if not isinstance(value, ty):
      raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                      (name, ty, value, type(value)))
    option._options[name] = value  # pylint: disable=protected-access

  return property(get_fn, set_fn, None, docstring)


def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `tf.data.OptionsBase` (and thus each represent a set of options). The method
  outputs an object of the same type created by merging the sets of options
  represented by the input arguments.

  If an option is set to different values by different options objects, the
  result will match the setting of the options object that appears in the input
  list last. An input whose value for an option equals the default value does
  not override an earlier non-default setting.

  If an option is an instance of `tf.data.OptionsBase` itself, then this method
  is applied recursively to the set of options represented by this option.
  Nested options objects are copied, so the result never shares a nested
  options object with any of the inputs.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options are provided.
    TypeError: if the input arguments are incompatible or not derived from
      `tf.data.OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(options),
                                                               result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all set options and merge them into the result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        # A default value never overrides what has been merged so far.
        continue
      elif this == default:
        if isinstance(that, OptionsBase):
          # Copy nested options (merging a single object yields an equal, new
          # object) so that mutating the result cannot mutate the input.
          that = merge_options(that)
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        # Two inputs disagree on a non-default value; the later one wins.
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        this, that)
        setattr(result, name, that)
  return result