• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for tf.data options."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from absl import logging
24
25
26def _internal_attr_name(name):
27  return "_" + name
28
29
class OptionsBase(object):
  """Common base for objects that hold a set of tf.data options.

  Attribute assignment is restricted: only attributes that already exist on
  the object may be set, and only while the object is marked mutable.

  Attributes:
    _options: Dictionary holding the option values.
    _mutable: Whether attribute assignment is currently permitted.
  """

  def __init__(self):
    # The overridden `__setattr__` reads `self._mutable` and rejects names that
    # do not exist yet, so the initial state is installed through
    # `object.__setattr__` directly.
    for attr, initial in (("_options", {}), ("_mutable", True)):
      object.__setattr__(self, attr, initial)

  def __eq__(self, other):
    """Compares every option recorded on either operand."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    # pylint: disable=protected-access
    names = set(self._options) | set(other._options)
    # pylint: enable=protected-access
    return not any(getattr(self, key) != getattr(other, key) for key in names)

  def __ne__(self, other):
    """Negation of `__eq__` for operands of a compatible type."""
    if not isinstance(other, self.__class__):
      return NotImplemented
    return not self.__eq__(other)

  def __setattr__(self, name, value):
    """Assigns `value` to an existing attribute of a mutable object.

    Args:
      name: Name of the attribute to assign.
      value: Value to assign.

    Raises:
      ValueError: If the object has been marked immutable.
      AttributeError: If the object has no attribute called `name`.
    """
    if not self._mutable:
      message = ("Mutating `tf.data.Options()` returned by "
                 "`tf.data.Dataset.options()` has no effect. Use "
                 "`tf.data.Dataset.with_options(options)` to set or "
                 "update dataset options.")
      raise ValueError(message)
    if not hasattr(self, name):
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))
    object.__setattr__(self, name, value)

  def _set_mutable(self, mutable):
    """Sets the mutability flag to `mutable`, bypassing `__setattr__`."""
    object.__setattr__(self, "_mutable", mutable)

  def _to_proto(self):
    """Converts the options to a protocol buffer; not implemented here."""
    raise NotImplementedError("%s._to_proto()" % type(self).__name__)

  def _from_proto(self, pb):
    """Populates the options from a protocol buffer; not implemented here."""
    raise NotImplementedError("%s._from_proto()" % type(self).__name__)
79
80
# Builds the namedtuple type used for optimization graph rewrite settings.
def graph_rewrites():
  """Returns a new `GraphRewrites` namedtuple class.

  The class has three fields, in order: `enabled`, `disabled` and `default`.
  A fresh class object is created on every call.
  """
  field_names = ["enabled", "disabled", "default"]
  return collections.namedtuple("GraphRewrites", field_names)
85
86
def create_option(name, ty, docstring, default_factory=lambda: None):
  """Builds a property whose assigned values are type-checked.

  The property keeps its value in the owning object's `_options` dictionary
  under the key `name`.

  Args:
    name: Key under which the value is stored in `_options`.
    ty: Type (or tuple of types) that assigned values must be an instance of.
    docstring: Docstring attached to the returned property.
    default_factory: Zero-argument callable producing the value to use when
      the option has not been set. It is invoked on the first read and its
      result is stored in `_options`.

  Returns:
    A `property` with a getter and a type-checking setter; no deleter.
  """

  def _read(option):
    store = option._options  # pylint: disable=protected-access
    if name not in store:
      # Materialize the default lazily so that unread options stay unset.
      store[name] = default_factory()
    return store.get(name)

  def _write(option, value):
    if isinstance(value, ty):
      option._options[name] = value  # pylint: disable=protected-access
      return
    raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                    (name, ty, value, type(value)))

  return property(fget=_read, fset=_write, doc=docstring)
115
116
def merge_options(*options_list):
  """Combines several options objects into a freshly created one.

  All arguments must be instances of the type of the first argument, and that
  type must derive from `OptionsBase`. The returned object is a new instance of
  that type.

  For every option recorded on an input, a value equal to the default of a
  newly constructed options object is skipped. Otherwise the value replaces
  whatever the result currently holds, so when several inputs carry different
  non-default values the input appearing last wins (a warning is logged when
  a non-default value is overwritten).

  When the value already held by the result is itself an `OptionsBase`, the
  two values are merged recursively instead of being replaced.

  Args:
    *options_list: The options objects to merge, in order of increasing
      precedence.

  Raises:
    ValueError: If no options objects are given.
    TypeError: If an argument is not an instance of the first argument's type,
      or if the arguments do not derive from `OptionsBase`.

  Returns:
    A new options object holding the merged settings.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  first = options_list[0]
  result_type = type(first)

  for candidate in options_list:
    if isinstance(candidate, result_type):
      continue
    raise TypeError("Incompatible options type: %r vs %r" % (type(candidate),
                                                             result_type))

  if not isinstance(first, OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  # A pristine instance provides the default value of every option.
  pristine = result_type()
  merged = result_type()
  for candidate in options_list:
    # Only options recorded on the candidate are considered.
    for name in candidate._options:  # pylint: disable=protected-access
      current = getattr(merged, name)
      incoming = getattr(candidate, name)
      baseline = getattr(pristine, name)
      if incoming == baseline:
        # The candidate carries the default; nothing to merge.
        continue
      if current == baseline:
        # First non-default value seen for this option.
        setattr(merged, name, incoming)
        continue
      if isinstance(current, OptionsBase):
        # Nested options are merged option by option.
        setattr(merged, name, merge_options(current, incoming))
        continue
      if current != incoming:
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        current, incoming)
        setattr(merged, name, incoming)
  return merged
173