1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for tf.data options.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from absl import logging 24 25 26def _internal_attr_name(name): 27 return "_" + name 28 29 30class OptionsBase(object): 31 """Base class for representing a set of tf.data options. 32 33 Attributes: 34 _options: Stores the option values. 35 """ 36 37 def __init__(self): 38 # NOTE: Cannot use `self._options` here as we override `__setattr__` 39 object.__setattr__(self, "_options", {}) 40 41 def __eq__(self, other): 42 if not isinstance(other, self.__class__): 43 return NotImplemented 44 for name in set(self._options) | set(other._options): # pylint: disable=protected-access 45 if getattr(self, name) != getattr(other, name): 46 return False 47 return True 48 49 def __ne__(self, other): 50 if isinstance(other, self.__class__): 51 return not self.__eq__(other) 52 else: 53 return NotImplemented 54 55 def __setattr__(self, name, value): 56 if hasattr(self, name): 57 object.__setattr__(self, name, value) 58 else: 59 raise AttributeError( 60 "Cannot set the property %s on %s." % (name, type(self).__name__)) 61 62 def _to_proto(self): 63 """Convert options to protocol buffer.""" 64 raise NotImplementedError("%s._to_proto()" % type(self).__name__) 65 66 def _from_proto(self, pb): 67 """Convert protocol buffer to options.""" 68 raise NotImplementedError("%s._from_proto()" % type(self).__name__) 69 70 71# Creates a namedtuple with three keys for optimization graph rewrites settings. 72def graph_rewrites(): 73 return collections.namedtuple("GraphRewrites", 74 ["enabled", "disabled", "default"]) 75 76 77def create_option(name, ty, docstring, default_factory=lambda: None): 78 """Creates a type-checked property. 79 80 Args: 81 name: The name to use. 82 ty: The type to use. The type of the property will be validated when it 83 is set. 84 docstring: The docstring to use. 85 default_factory: A callable that takes no arguments and returns a default 86 value to use if not set. 87 88 Returns: 89 A type-checked property. 90 """ 91 92 def get_fn(option): 93 # pylint: disable=protected-access 94 if name not in option._options: 95 option._options[name] = default_factory() 96 return option._options.get(name) 97 98 def set_fn(option, value): 99 if not isinstance(value, ty): 100 raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" % 101 (name, ty, value, type(value))) 102 option._options[name] = value # pylint: disable=protected-access 103 104 return property(get_fn, set_fn, None, docstring) 105 106 107def merge_options(*options_list): 108 """Merges the given options, returning the result as a new options object. 109 110 The input arguments are expected to have a matching type that derives from 111 `tf.data.OptionsBase` (and thus each represent a set of options). The method 112 outputs an object of the same type created by merging the sets of options 113 represented by the input arguments. 114 115 If an option is set to different values by different options objects, the 116 result will match the setting of the options object that appears in the input 117 list last. 118 119 If an option is an instance of `tf.data.OptionsBase` itself, then this method 120 is applied recursively to the set of options represented by this option. 121 122 Args: 123 *options_list: options to merge 124 125 Raises: 126 TypeError: if the input arguments are incompatible or not derived from 127 `tf.data.OptionsBase` 128 129 Returns: 130 A new options object which is the result of merging the given options. 131 """ 132 if len(options_list) < 1: 133 raise ValueError("At least one options should be provided") 134 result_type = type(options_list[0]) 135 136 for options in options_list: 137 if not isinstance(options, result_type): 138 raise TypeError("Incompatible options type: %r vs %r" % (type(options), 139 result_type)) 140 141 if not isinstance(options_list[0], OptionsBase): 142 raise TypeError("The inputs should inherit from `OptionsBase`") 143 144 default_options = result_type() 145 result = result_type() 146 for options in options_list: 147 # Iterate over all set options and merge them into the result. 148 for name in options._options: # pylint: disable=protected-access 149 this = getattr(result, name) 150 that = getattr(options, name) 151 default = getattr(default_options, name) 152 if that == default: 153 continue 154 elif this == default: 155 setattr(result, name, that) 156 elif isinstance(this, OptionsBase): 157 setattr(result, name, merge_options(this, that)) 158 elif this != that: 159 logging.warning("Changing the value of option %s from %r to %r.", name, 160 this, that) 161 setattr(result, name, that) 162 return result 163