# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to hold a library of OpDefs and use it to create Brain operations."""

from google.protobuf import text_format
from tensorflow.core.config import flags
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_library_pybind
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib

36def _Attr(op_def, name):
37  for attr in op_def.attr:
38    if attr.name == name:
39      return attr
40  raise TypeError(f"Inconsistent OpDef for '{op_def.name}', missing attr "
41                  f"'{name}'")
42
43
44def _AttrValue(attr_protos, name, op_type_name):
45  if name in attr_protos:
46    return attr_protos[name]
47  raise TypeError(f"Inconsistent OpDef for '{op_type_name}', missing attr "
48                  f"'{name}' from '{attr_protos}'.")
49
50
def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises TypeError if `dtype` violates `attr_def`'s allowed_values list.

  Args:
    dtype: A DataType enum value to validate.
    attr_def: An OpDef.AttrDef proto; only checked when it has
      `allowed_values`.
    param_name: Name of the parameter being validated, for error messages.

  Raises:
    TypeError: If `dtype` is not among the attr's allowed values.
  """
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      # Build the human-readable list only on the error path; the original
      # joined it on every call even when the constraint was satisfied.
      allowed_values = ", ".join(dtypes.as_dtype(x).name for x in allowed_list)
      raise TypeError(
          f"Value passed to parameter '{param_name}' has DataType "
          f"{dtypes.as_dtype(dtype).name} not in list of allowed values: "
          f"{allowed_values}")
60
61
62def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
63  if attr_def.has_minimum and length < attr_def.minimum:
64    raise ValueError(f"Attr '{param_name}' of '{op_type_name}' Op passed list "
65                     f"of length {length} less than minimum "
66                     f"{attr_def.minimum}.")
67
68
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError when `value` is not one of the attr's allowed strings."""
  allowed = attr_def.allowed_values.list.s
  if value in allowed:
    return
  allowed_values = '", "'.join(map(compat.as_text, allowed))
  raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed string "
                   f"'{compat.as_text(value)}' not in: \"{allowed_values}\".")
75
76
77def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
78  if value < attr_def.minimum:
79    raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed {value} "
80                     f"less than minimum {attr_def.minimum}.")
81
82
83def _IsListParameter(arg):
84  if arg.number_attr:
85    return True
86  elif arg.type_list_attr:
87    return True
88  return False
89
90
def _NumTypeFields(arg):
  """Counts how many type-determining fields are set on an OpDef arg."""
  return sum((arg.type != types_pb2.DT_INVALID,
              bool(arg.type_attr),
              bool(arg.type_list_attr)))
97
98
99def _IsListValue(v):
100  return isinstance(v, (list, tuple))
101
102
def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
  flat = []
  for item in l:
    # One level deep only: nested lists/tuples are spliced in, scalars kept.
    if isinstance(item, (list, tuple)):
      flat.extend(item)
    else:
      flat.append(item)
  return flat
109
110
111def _Restructure(l, structure):
112  """Returns the elements of list l structured according to the given structure.
113
114  A structure is represented by a list whose elements are either
115  `None` or a non-negative integer. `None` corresponds to a single
116  element in the output list, and an integer N corresponds to a nested
117  list of length N.
118
119  The function returns a data structure whose shape is given by
120  `structure`, and whose elements are taken from `l`. If `structure`
121  is a singleton, the function returns the single data structure
122  implied by the 0th element of `structure`. For example:
123
124      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
125        -> ["foo", ["bar", "baz"], "qux"]
126
127      _Restructure(["foo"], [None]) -> "foo"
128
129      _Restructure(["foo"], [1]) -> ["foo"]
130
131      _Restructure([], [0]) -> []
132
133  Args:
134    l: A list.
135    structure: A list whose elements are either `None` or a non-negative
136      integer.
137
138  Returns:
139    The elements of `l`, restructured according to `structure`. If
140    `structure` is a list of length 1, this function returns the
141    single data structure implied by `structure[0]`.
142
143  """
144  result = []
145  current_index = 0
146  for element in structure:
147    if element is None:
148      result.append(l[current_index])
149      current_index += 1
150    else:
151      result.append(l[current_index:current_index+element])
152      current_index += element
153
154  if len(result) == 1:
155    return result[0]
156  else:
157    return tuple(result)
158
159
def _MakeFloat(v, arg_name):
  """Coerces `v` to float, raising TypeError for non-real values."""
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError(f"Expected float for argument '{arg_name}' not {repr(v)}.")
164
165
166def _MakeInt(v, arg_name):
167  if isinstance(v, str):
168    raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
169  try:
170    return int(v)
171  except (ValueError, TypeError):
172    raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
173
174
def _MakeStr(v, arg_name):
  """Coerces `v` to bytes, raising TypeError for non-string values."""
  if isinstance(v, compat.bytes_or_text_types):
    # Convert unicode strings to bytes.
    return compat.as_bytes(v)
  raise TypeError(f"Expected string for argument '{arg_name}' not {repr(v)}.")
179
180
181def _MakeBool(v, arg_name):
182  if not isinstance(v, bool):
183    raise TypeError(f"Expected bool for argument '{arg_name}' not {repr(v)}.")
184  return v
185
186
def _MakeType(v, arg_name):
  """Converts `v` to its base dtype's DataType enum, or raises TypeError."""
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError(f"Expected DataType for argument '{arg_name}' not "
                    f"{repr(v)}.")
  return base.as_datatype_enum
194
195
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `v` cannot be converted to a TensorShape.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except (TypeError, ValueError) as e:
    # Both failure modes are surfaced as TypeError so callers only need to
    # catch one exception type (the two original handlers were identical).
    raise TypeError(f"Error converting {repr(v)} (arg name = {arg_name}) to a "
                    f"TensorShape: {e}")
219
220
def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        f"Don't know how to convert {repr(v)} to a TensorProto for argument "
        f"'{arg_name}'")
  return v
228
229
def _MakeFunc(v, arg_name):
  """Ensure v is a func, returning a NameAttrList for it."""
  # Already a NameAttrList: pass through unchanged.
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  # A (byte)string names the function directly.
  if isinstance(v, compat.bytes_or_text_types):
    return attr_value_pb2.NameAttrList(name=v)
  # Anything with add_to_graph looks like a defined/concrete function:
  # register it on the default graph, then reference it by name.
  if hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    if hasattr(v, "_as_name_attr_list"):
      return v._as_name_attr_list  # pylint: disable=protected-access
    return attr_value_pb2.NameAttrList(name=v.name)
  raise TypeError(f"Don't know how to convert {repr(v)} to a func for "
                  f"argument {arg_name}")
246
247
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if inputs:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list,
    # recursing on the remainder.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
  else:
    yield
# pylint: enable=g-doc-return-or-yield
267
268
def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, and optional
      parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    return op
  res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
  if is_stateful and isinstance(res, list) and not res:
    # A stateful op with an empty output list: hand back the Operation.
    return op
  return res
311
312
# This is temporary Python/C++ code duplication until all of it can be ported
# over to C++.
# LINT.IfChange
def _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos):
  """Extracts `attr_protos`. For use in _apply_op_helper."""
  for attr_def in op_def.attr:
    key = attr_def.name
    attr_input = attrs[key]

    # An unset attr that has an OpDef default: copy the default verbatim.
    if attr_def.HasField("default_value") and attr_input is None:
      default_copy = attr_value_pb2.AttrValue()
      default_copy.CopyFrom(attr_def.default_value)
      attr_protos[key] = default_copy
      continue

    attr_value = value_to_attr_value(attr_input, attr_def.type, key)
    # Validate list length, allowed strings, int minimum and type
    # constraints declared on the attr.
    if attr_def.type.startswith("list("):
      _SatisfiesLengthConstraint(len(attr_input), attr_def, key, op_type_name)
    if attr_def.HasField("allowed_values"):
      if attr_def.type == "string":
        _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                           op_type_name)
      elif attr_def.type == "list(string)":
        for str_value in attr_value.list.s:
          _SatisfiesAllowedStringsConstraint(str_value, attr_def, key,
                                             op_type_name)
    if attr_def.has_minimum and attr_def.type == "int":
      _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, op_type_name)
    if attr_def.type == "type":
      _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
    if attr_def.type == "list(type)":
      for type_value in attr_value.list.type:
        _SatisfiesTypeConstraint(type_value, attr_def, key)

    attr_protos[key] = attr_value
347
348
def _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                            output_structure):
  """Extracts `output_structure`. For use in _apply_op_helper.

  Appends one entry per output arg: an int N for a list output whose
  length is determined by an attr, or None for a single-tensor output.

  Args:
    op_type_name: string, used in error messages.
    op_def: The OpDef proto.
    attr_protos: dict mapping attr name to its resolved AttrValue proto.
    output_structure: list, mutated in place.

  Raises:
    TypeError: If a referenced attr is missing from `attr_protos`.
  """
  for arg in op_def.output_arg:
    if arg.number_attr:
      n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
      output_structure.append(n)
    elif arg.type_attr:
      # Validate that the type attr was resolved, but a single tensor
      # output is represented by None regardless of dtype (the original
      # bound the result to an unused local).
      _AttrValue(attr_protos, arg.type_attr, op_type_name)
      output_structure.append(None)
    elif arg.type_list_attr:
      t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
      output_structure.append(len(t.list.type))
    else:
      output_structure.append(None)
364
365
def _CanExtractAttrsFastPath(op_def, keywords):
  """Check if the fast path for _apply_op_helper is applicable."""
  # Every input must already be a tf.Tensor ...
  all_inputs_are_tensors = all(
      isinstance(keywords.get(input_arg.name, None), ops.Tensor)
      for input_arg in op_def.input_arg)
  if not all_inputs_are_tensors:
    return False
  # ... and no attr may be function-valued.
  return all(attr_def.type not in ("func", "list(func)")
             for attr_def in op_def.attr)
380
381
382def _CheckOpDeprecation(op_type_name, op_def, producer):
383  """Checks if the op is deprecated."""
384  deprecation_version = op_def.deprecation.version
385  if deprecation_version and producer >= deprecation_version:
386    raise NotImplementedError(
387        f"Op {op_type_name} is not available in GraphDef version {producer}. "
388        f"It has been removed in version {deprecation_version}. "
389        f"{op_def.deprecation.explanation}.")
390
391
def _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map):
  """Extracts the `default_type_attr_map` and `allowed_list_attr_map`."""
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  for attr_def in op_def.attr:
    # Only "type"-kind attrs contribute defaults/allowed lists here.
    if attr_def.type != "type":
      continue
    attr_name = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[attr_name] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[attr_name] = attr_def.allowed_values.list.type
408
409
def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  Walks `op_def.input_arg` in order, popping each declared input out of
  `keywords`, converting it to Tensor(s), and inferring attr values
  (element dtypes and list lengths) from the converted inputs. Statement
  order matters: attrs inferred from earlier inputs constrain later ones.

  Args:
    op_type_name: string, the registered Op name (for error messages).
    op_def: The OpDef proto for the op being instantiated.
    allowed_list_attr_map: dict mapping type-attr name to its allowed dtype
      list, used to sanity-check conversions driven by default dtypes.
    keywords: dict of user-supplied arguments; matched inputs are popped
      from it in place.
    default_type_attr_map: dict mapping type-attr name to its default dtype.
    attrs: dict, mutated in place with inferred attr values.
    inputs: list, mutated in place with the converted input tensors.
    input_types: list, mutated in place with the dtypes of `inputs`
      (ref dtypes are preserved for ref inputs, base dtypes otherwise).

  Raises:
    TypeError: if an input is missing, is not a list where one is expected,
      cannot be converted, or conflicts with an inferred/declared type.
    ValueError: if a list input's length conflicts with an inferred length
      attr or violates that attr's minimum.
  """
  # Maps attr name -> name of the input it was inferred from, so error
  # messages can point at the earlier argument that fixed the attr.
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = set(v.dtype.base_dtype for v in values)
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
        observed = ", ".join(observed_types)

        prefix = ("Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                  (input_name, op_type_name, observed))
        if input_arg.number_attr:
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtype.name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type {dtype.name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError as err:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          raise err
        else:
          raise TypeError(
              f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
              f"'{input_arg.name}' of op '{op_type_name}', got "
              f"{repr(values)} of type '{type(values).__name__}' instead. "
              f"Error: {err}")
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}")
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")
      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)
697
698
699def _ExtractRemainingAttrs(op_type_name, op_def, keywords,
700                           default_type_attr_map, attrs):
701  """Extracts the remaining attributes into `attrs` in _apply_op_helper."""
702  for attr in op_def.attr:
703    # Skip attrs that have already had their values inferred
704    if attr.name in attrs:
705      if attr.name in keywords:
706        raise TypeError(
707            f"Should not specify value for inferred attr '{attr.name}' for "
708            f"{op_type_name}.")
709      continue
710    if attr.name in keywords:
711      attrs[attr.name] = keywords.pop(attr.name)
712    elif attr.name + "_" in keywords:
713      # Attrs whose names match Python keywords have an extra '_'
714      # appended, so we must check for that as well.
715      attrs[attr.name] = keywords.pop(attr.name + "_")
716    elif attr.name in default_type_attr_map:
717      attrs[attr.name] = default_type_attr_map[attr.name]
718    else:
719      raise TypeError(f"No argument found for attr {attr.name} for "
720                      f"{op_type_name}")
721
722
def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper.

  Args:
    op_type_name: string, name of a registered Op.
    keywords: dict of user-supplied arguments whose (flattened) values are
      used to determine the enclosing graph.

  Returns:
    A tuple `(op_def, graph, producer)` where `producer` is the graph's
    GraphDef producer version.

  Raises:
    RuntimeError: If the op is not registered or the graph cannot be
      determined from the inputs.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    # Format the exception itself: Python 3 exceptions have no `.message`
    # attribute, so the previous `e.message` raised AttributeError here.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}")

  return op_def, g, producer
741
742
743def _CheckAllInputsUsed(op_type_name, keywords):
744  """Ensures all inputs passed into _apply_op_helper were used."""
745  if keywords:
746    all_keywords = ", ".join(sorted(keywords.keys()))
747    raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
748                    f"{all_keywords}.")
749
750
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Creates an op in the current graph from the named OpDef.

  Args:
    op_type_name: Registered name of the op to create.
    name: Optional name for the new op; defaults to `op_type_name`.
    **keywords: The op's inputs and attr values, keyed by the names declared
      in its OpDef.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`: the structure
    describing how the flat `outputs` list groups into the op's declared
    output args, the OpDef's `is_stateful` flag, the created Operation,
    and its output tensors (possibly replaced by op-callback results).
  """

  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name if name else op_type_name

  # Accumulators, filled either by the C++ fast path below or by the
  # pure-Python fallback extraction helpers.
  attrs, attr_protos = {}, {}
  default_type_attr_map, allowed_list_attr_map = {}, {}
  inputs, input_types, output_structure = [], [], []
  fallback = True

  # Fast path: when the inputs/attrs are simple enough for the pybind
  # implementation and the graph-building optimization flag is on, do the
  # input/attr processing in C++.
  if (_CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value()):
    fallback = False
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))

  if fallback:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      # Python fallback: consume entries from `keywords` into inputs/attrs,
      # then verify that nothing unexpected was left over.
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
813
814
def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.
  """
  attr_value = attr_value_pb2.AttrValue()

  # List-typed attrs must be given a list-like Python value.
  if attr_type.startswith("list(") and not _IsListValue(value):
    raise TypeError(f"Expected list for attr {arg_name}, obtained "
                    f"{type(value).__name__} instead.")

  # Maps a base attr type to (AttrValue field name, converter function,
  # whether the field is a proto message rather than a scalar).
  converters = {
      "string": ("s", _MakeStr, False),
      "int": ("i", _MakeInt, False),
      "float": ("f", _MakeFloat, False),
      "bool": ("b", _MakeBool, False),
      "type": ("type", _MakeType, False),
      "shape": ("shape", _MakeShape, True),
      "tensor": ("tensor", _MakeTensor, True),
      "func": ("func", _MakeFunc, True),
  }

  is_list = attr_type.startswith("list(") and attr_type.endswith(")")
  base_type = attr_type[len("list("):-1] if is_list else attr_type
  if base_type not in converters:
    raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")

  field, make, is_message = converters[base_type]
  if is_list:
    # Repeated fields (scalar or message) both support extend().
    getattr(attr_value.list, field).extend([make(x, arg_name) for x in value])
  elif is_message:
    getattr(attr_value, field).CopyFrom(make(value, arg_name))
  else:
    setattr(attr_value, field, make(value, arg_name))
  return attr_value
# LINT.ThenChange(//tensorflow/python/framework/op_def_library_pybind.cc)
870
871
# The following symbols are used by op_def_util.cc.
for _name, _obj in (
    ("tf.dtypes.DType", dtypes.DType),
    ("tf.dtypes.as_dtype", dtypes.as_dtype),
    ("tf.TensorShape", tensor_shape.TensorShape),
    ("tf.as_shape", tensor_shape.as_shape),
    ("tf.TensorProto", tensor_pb2.TensorProto),
    ("text_format.Parse", text_format.Parse),
    ("tf.convert_to_tensor", ops.convert_to_tensor),
):
  _pywrap_utils.RegisterPyObject(_name, _obj)
del _name, _obj
880