• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for tf.data options."""
16
17import collections
18
19from absl import logging
20
21
22def _internal_attr_name(name):
23  return "_" + name
24
25
class OptionsBase:
  """Base class for representing a set of tf.data options.

  Attribute assignment is routed through `__setattr__`, which rejects every
  write while the object is marked immutable and rejects names that do not
  already exist on the object.

  Attributes:
    _options: Dictionary that stores the option values.
    _mutable: Whether attributes may currently be assigned. Starts as `True`
      and is changed through `_set_mutable`.
  """

  def __init__(self):
    # NOTE: Cannot use `self._options` here as we override `__setattr__`
    object.__setattr__(self, "_options", {})
    object.__setattr__(self, "_mutable", True)

  def __eq__(self, other):
    """Returns whether `other` holds the same option values as this object.

    Args:
      other: The object to compare against.

    Returns:
      `NotImplemented` if `other` is not an instance of this object's class,
      otherwise `True` iff every option name stored on either side resolves to
      an equal value on both sides.
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    # Compare over the union of names so that an option stored on only one
    # side is still looked up (via `getattr`) on the other side.
    for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
      if getattr(self, name) != getattr(other, name):
        return False
    return True

  def __ne__(self, other):
    """Returns the negation of `__eq__`, or `NotImplemented` for other types."""
    if isinstance(other, self.__class__):
      return not self.__eq__(other)
    else:
      return NotImplemented

  def __setattr__(self, name, value):
    """Assigns `value` to the existing attribute `name`.

    Args:
      name: Name of the attribute to assign.
      value: The value to assign.

    Raises:
      ValueError: If the object has been marked immutable.
      AttributeError: If `name` is not already an attribute of the object.
    """
    if not self._mutable:
      raise ValueError("Mutating `tf.data.Options()` returned by "
                       "`tf.data.Dataset.options()` has no effect. Use "
                       "`tf.data.Dataset.with_options(options)` to set or "
                       "update dataset options.")
    # Only attributes that already exist may be assigned, so a misspelled
    # name raises instead of silently creating a new attribute.
    if hasattr(self, name):
      object.__setattr__(self, name, value)
    else:
      raise AttributeError("Cannot set the property {} on {}.".format(
          name,
          type(self).__name__))

  def _set_mutable(self, mutable):
    """Change the mutability property to `mutable`."""
    # Bypasses the overridden `__setattr__`, which would otherwise refuse the
    # write once the object is immutable.
    object.__setattr__(self, "_mutable", mutable)

  def _to_proto(self):
    """Convert options to protocol buffer."""
    raise NotImplementedError("{}._to_proto()".format(type(self).__name__))

  def _from_proto(self, pb):
    """Convert protocol buffer to options."""
    raise NotImplementedError("{}._from_proto()".format(type(self).__name__))
80
81
def graph_rewrites():
  """Creates a namedtuple type for optimization graph rewrites settings.

  Returns:
    A newly created namedtuple class named `GraphRewrites` with the three
    fields `enabled`, `disabled` and `default`.
  """
  field_names = ["enabled", "disabled", "default"]
  rewrites_type = collections.namedtuple("GraphRewrites", field_names)
  return rewrites_type
86
87
def create_option(name, ty, docstring, default_factory=lambda: None):
  """Builds a property whose assigned values are type-checked.

  The property keeps its value in the owning object's `_options` dictionary
  under the key `name`.

  Args:
    name: The key under which the value is stored.
    ty: The type (or tuple of types) that assigned values must be an instance
      of; checked on every assignment.
    docstring: The docstring attached to the property.
    default_factory: A zero-argument callable producing the default value. It
      is invoked on the first read of an option that has not been set, and its
      result is stored.

  Returns:
    A `property` with a getter and a type-checking setter (no deleter).
  """

  def _read(option):
    store = option._options  # pylint: disable=protected-access
    # Materialize the default lazily, only when the option was never set.
    if name not in store:
      store[name] = default_factory()
    return store.get(name)

  def _write(option, value):
    if isinstance(value, ty):
      option._options[name] = value  # pylint: disable=protected-access
      return
    message = 'Property "{}" must be of type {}, got: {} (type: {})'.format(
        name, ty, value, type(value))
    raise TypeError(message)

  return property(fget=_read, fset=_write, fdel=None, doc=docstring)
117
118
def merge_options(*options_list):
  """Merges the given options into a newly created options object.

  All arguments must be instances of the type of the first argument, and that
  type must derive from `OptionsBase`. The result is a fresh instance of that
  type.

  For every option stored on an input object, a value equal to the default of
  a freshly constructed instance is ignored. Otherwise, when several inputs
  set the same option to different values, the input appearing last wins and a
  warning is logged. When the value already held by the result is itself an
  `OptionsBase`, the two values are merged recursively with this function.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options are provided
    TypeError: if the input arguments are incompatible or not derived from
      `OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")

  first = options_list[0]
  result_type = type(first)

  # Every argument has to be an instance of the first argument's type.
  for candidate in options_list:
    if isinstance(candidate, result_type):
      continue
    raise TypeError(
        "Could not merge incompatible options of type {} and {}.".format(
            type(candidate), result_type))

  if not isinstance(first, OptionsBase):
    raise TypeError(
        "All options to be merged should inherit from `OptionsBase` but found "
        "option of type {} which does not.".format(type(first)))

  # A pristine instance supplies the default value of each option.
  pristine = result_type()
  merged = result_type()
  for source in options_list:
    # Only options actually stored on `source` are considered.
    for name in source._options:  # pylint: disable=protected-access
      current = getattr(merged, name)
      incoming = getattr(source, name)
      baseline = getattr(pristine, name)

      if incoming == baseline:
        # Values equal to the default never override anything.
        continue

      if current == baseline:
        setattr(merged, name, incoming)
        continue

      if isinstance(current, OptionsBase):
        setattr(merged, name, merge_options(current, incoming))
        continue

      if current != incoming:
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        current, incoming)
        setattr(merged, name, incoming)
  return merged
178