• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from absl import logging
24
25
26def _internal_attr_name(name):
27  return "_" + name
28
29
class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  Option values are kept in the per-instance `_options` dict. Assigning to a
  name that does not already resolve as an attribute raises `AttributeError`,
  so only options declared on the class can be set.

  Attributes:
    _options: Stores the option values.
  """

  def __init__(self):
    # `__setattr__` below rejects names that do not resolve yet, so the
    # backing dict has to be installed through `object.__setattr__` directly.
    object.__setattr__(self, "_options", {})

  def __eq__(self, other):
    """Compares every option recorded on either side.

    Returns `NotImplemented` when `other` is not an instance of this class.
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    # pylint: disable=protected-access
    option_names = set(self._options) | set(other._options)
    # pylint: enable=protected-access
    return not any(
        getattr(self, option_name) != getattr(other, option_name)
        for option_name in option_names)

  def __ne__(self, other):
    # Spelled out explicitly because Python 2 does not derive `!=` from `==`.
    if not isinstance(other, self.__class__):
      return NotImplemented
    return not self.__eq__(other)

  def __setattr__(self, name, value):
    # Only names that already resolve on the instance (e.g. declared option
    # properties) may be assigned; anything else is an unknown option.
    if not hasattr(self, name):
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))
    object.__setattr__(self, name, value)

  def _to_proto(self):
    """Convert options to protocol buffer. Subclasses must override this."""
    raise NotImplementedError("%s._to_proto()" % type(self).__name__)

  def _from_proto(self, pb):
    """Convert protocol buffer to options. Subclasses must override this."""
    raise NotImplementedError("%s._from_proto()" % type(self).__name__)
69
70
# Namedtuple with three keys for optimization graph rewrites settings. It is
# built once at import time so that every call to `graph_rewrites()` returns
# the same class; building it per call would give each caller a distinct type.
_GraphRewrites = collections.namedtuple("GraphRewrites",
                                        ["enabled", "disabled", "default"])


def graph_rewrites():
  """Returns the `GraphRewrites` namedtuple class.

  The class has the fields `enabled`, `disabled` and `default`. The same class
  object is returned on every call, so `isinstance` checks and type
  comparisons against instances built from earlier calls stay consistent.

  Returns:
    A `collections.namedtuple` class named "GraphRewrites".
  """
  return _GraphRewrites
75
76
def create_option(name, ty, docstring, default_factory=lambda: None):
  """Builds a property whose assigned values are type-checked.

  Values live in the owning object's `_options` dict under the key `name`.
  Reading an option that has not been assigned yet stores the result of
  `default_factory()` in `_options` and returns it.

  Args:
    name: The name to use; also the key used in `_options`.
    ty: The type (or tuple of types, as accepted by `isinstance`) that
      assigned values must be instances of.
    docstring: The docstring to use.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set.

  Returns:
    A `property` whose setter raises `TypeError` for values that are not
    instances of `ty`.
  """

  def _read(owner):
    store = owner._options  # pylint: disable=protected-access
    if name not in store:
      # First access: materialize the default and remember it.
      store[name] = default_factory()
    return store[name]

  def _write(owner, value):
    if isinstance(value, ty):
      owner._options[name] = value  # pylint: disable=protected-access
      return
    raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                    (name, ty, value, type(value)))

  return property(fget=_read, fset=_write, doc=docstring)
105
106
def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `tf.data.OptionsBase` (and thus each represent a set of options). The method
  outputs an object of the same type created by merging the sets of options
  represented by the input arguments.

  If an option is set to different values by different options objects, the
  result will match the setting of the options object that appears in the input
  list last. An option whose value equals its default is treated as unset: it
  never overrides a non-default value merged from an earlier options object.

  If an option is an instance of `tf.data.OptionsBase` itself, then this method
  is applied recursively to the set of options represented by this option.
  Nested `OptionsBase` values are always copied (via a recursive merge) rather
  than shared with the inputs, so mutating the result's nested options does
  not mutate any input.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options are provided.
    TypeError: if the input arguments are incompatible or not derived from
      `tf.data.OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(options),
                                                               result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  # A fresh instance is the reference for each option's default value.
  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all options recorded on `options` and merge them into the
    # result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        # Default values never override anything merged so far.
        continue
      if this == default:
        if isinstance(that, OptionsBase):
          # Copy nested options instead of storing the input's own object;
          # otherwise later changes to `result` would leak into `options`.
          that = merge_options(that)
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        this, that)
        setattr(result, name, that)
  return result
163