• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for tf.data options."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22def _internal_attr_name(name):
23  return "_" + name
24
25
class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  Concrete option classes expose their options as attributes (for example,
  properties built by `create_option`) whose values are kept in the
  per-instance `_options` dict. Assigning to a name that the object does not
  already expose raises `AttributeError`, so assignments to unknown names are
  rejected instead of creating new attributes.

  Attributes:
    _options: Stores the option values.
  """

  def __init__(self):
    # `__setattr__` below refuses names that do not exist yet, so the backing
    # dict has to be installed by bypassing it.
    object.__setattr__(self, "_options", {})

  def __eq__(self, other):
    if isinstance(other, self.__class__):
      # Compare every option recorded on either side. Values are read through
      # `getattr`, so for `create_option` properties an option missing on one
      # side is compared using its default.
      recorded = set(self._options) | set(other._options)  # pylint: disable=protected-access
      return not any(
          getattr(self, key) != getattr(other, key) for key in recorded)
    return NotImplemented

  def __ne__(self, other):
    # Spelled out explicitly; Python 2 does not derive `!=` from `__eq__`.
    if not isinstance(other, self.__class__):
      return NotImplemented
    return not self.__eq__(other)

  def __setattr__(self, name, value):
    # Only names that already resolve on the object (e.g. declared option
    # properties) may be assigned.
    if not hasattr(self, name):
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))
    object.__setattr__(self, name, value)
57
58
def create_option(name, ty, docstring, default_factory=lambda: None):
  """Builds a `property` that stores a type-checked option value.

  The value lives in the owning object's `_options` dict under the key `name`.

  Args:
    name: The name to use; also the key under which the value is stored in
      `_options`.
    ty: The type to use. Assigned values must be instances of `ty`, otherwise
      `TypeError` is raised.
    docstring: The docstring to use for the property.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set. It is invoked on the first read of an unset
      option, and its result is stored in `_options`.

  Returns:
    A type-checked property.
  """

  def _read(owner):
    stored = owner._options  # pylint: disable=protected-access
    if name in stored:
      return stored[name]
    # Persist the default so later reads (and in-place changes to a mutable
    # default) observe the same object.
    fallback = default_factory()
    stored[name] = fallback
    return fallback

  def _write(owner, value):
    if isinstance(value, ty):
      owner._options[name] = value  # pylint: disable=protected-access
      return
    raise TypeError('Property "%s" must be of type %s, got: %r (type: %r)' %
                    (name, ty, value, type(value)))

  return property(fget=_read, fset=_write, fdel=None, doc=docstring)
87
88
def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `OptionsBase` (and thus each represent a set of options). The method outputs
  an object of the same type created by merging the sets of options represented
  by the input arguments.

  The sets of options can be merged as long as there does not exist an option
  with different non-default values.

  If an option is an instance of `OptionsBase` itself, then this method is
  applied recursively to the set of options represented by this option. Such
  nested options objects are copied into the result rather than shared, so
  mutating a nested options object of the result does not affect the inputs.

  Args:
    *options_list: options to merge

  Raises:
    TypeError: if the input arguments are incompatible or not derived from
      `OptionsBase`
    ValueError: if the given options cannot be merged

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(options),
                                                               result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  # A freshly constructed object serves as the reference for default values:
  # an input value equal to its default never overrides or conflicts.
  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all set options and merge them into the result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        continue
      elif this == default:
        if isinstance(that, OptionsBase):
          # Take a private copy of nested options; assigning `that` directly
          # would make `result` share (and later mutate) the caller's object.
          that = merge_options(that)
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        raise ValueError(
            "Cannot merge incompatible values (%r and %r) of option: %s" %
            (this, that, name))
  return result
145