1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to hold a library of OpDefs and use it to create Brain operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from google.protobuf import text_format
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import tensor_pb2
27from tensorflow.core.framework import tensor_shape_pb2
28from tensorflow.core.framework import types_pb2
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import op_callbacks
31from tensorflow.python.framework import op_def_registry
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.util import _pywrap_utils
36from tensorflow.python.util import compat
37from tensorflow.python.util import tf_contextlib
38
39
40def _Attr(op_def, name):
41  for attr in op_def.attr:
42    if attr.name == name:
43      return attr
44  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
45                  (op_def.name, name))
46
47
48def _AttrValue(attr_protos, name):
49  if name in attr_protos:
50    return attr_protos[name]
51  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
52                  (name, attr_protos))
53
54
def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises TypeError if `dtype` is not among `attr_def`'s allowed values."""
  if not attr_def.HasField("allowed_values"):
    # No constraint declared for this attr; anything goes.
    return
  allowed = attr_def.allowed_values.list.type
  if dtype in allowed:
    return
  raise TypeError(
      "Value passed to parameter '%s' has DataType %s not in list of "
      "allowed values: %s" %
      (param_name, dtypes.as_dtype(dtype).name,
       ", ".join(dtypes.as_dtype(x).name for x in allowed)))
64
65
66def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
67  if attr_def.has_minimum and length < attr_def.minimum:
68    raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
69                     "less than minimum %d." %
70                     (param_name, op_type_name, length, attr_def.minimum))
71
72
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError when a string attr is not among its allowed values."""
  allowed = attr_def.allowed_values.list.s
  if value in allowed:
    return
  raise ValueError(
      "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
      (arg_name, op_type_name, compat.as_text(value), '", "'.join(
          map(compat.as_text, allowed))))
79
80
81def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
82  if value < attr_def.minimum:
83    raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." %
84                     (arg_name, op_type_name, value, attr_def.minimum))
85
86
87def _IsListParameter(arg):
88  if arg.number_attr:
89    return True
90  elif arg.type_list_attr:
91    return True
92  return False
93
94
def _NumTypeFields(arg):
  """Counts how many of the type / type_attr / type_list_attr fields are set."""
  return sum([
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  ])
101
102
103def _IsListValue(v):
104  return isinstance(v, (list, tuple))
105
106
107def _Flatten(l):
108  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
109  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
110  l_of_l = [x if _IsListValue(x) else [x] for x in l]
111  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
112  return [item for sublist in l_of_l for item in sublist]
113
114
115def _Restructure(l, structure):
116  """Returns the elements of list l structured according to the given structure.
117
118  A structure is represented by a list whose elements are either
119  `None` or a non-negative integer. `None` corresponds to a single
120  element in the output list, and an integer N corresponds to a nested
121  list of length N.
122
123  The function returns a data structure whose shape is given by
124  `structure`, and whose elements are taken from `l`. If `structure`
125  is a singleton, the function returns the single data structure
126  implied by the 0th element of `structure`. For example:
127
128      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
129        -> ["foo", ["bar", "baz"], "qux"]
130
131      _Restructure(["foo"], [None]) -> "foo"
132
133      _Restructure(["foo"], [1]) -> ["foo"]
134
135      _Restructure([], [0]) -> []
136
137  Args:
138    l: A list.
139    structure: A list whose elements are either `None` or a non-negative
140      integer.
141
142  Returns:
143    The elements of `l`, restructured according to `structure`. If
144    `structure` is a list of length 1, this function returns the
145    single data structure implied by `structure[0]`.
146
147  """
148  result = []
149  current_index = 0
150  for element in structure:
151    if element is None:
152      result.append(l[current_index])
153      current_index += 1
154    else:
155      result.append(l[current_index:current_index+element])
156      current_index += element
157
158  if len(result) == 1:
159    return result[0]
160  else:
161    return tuple(result)
162
163
def _MakeFloat(v, arg_name):
  """Coerces `v` to float, rejecting anything that is not a real number."""
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
169
170
def _MakeInt(v, arg_name):
  """Coerces `v` to int; strings and non-convertible values raise TypeError."""
  # Strings would happily pass int("3") but are not valid int attrs here.
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
180
181
def _MakeStr(v, arg_name):
  """Validates that `v` is a string and returns it encoded as bytes."""
  if isinstance(v, compat.bytes_or_text_types):
    # Convert unicode strings to bytes.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
187
188
189def _MakeBool(v, arg_name):
190  if not isinstance(v, bool):
191    raise TypeError("Expected bool for argument '%s' not %s." %
192                    (arg_name, repr(v)))
193  return v
194
195
def _MakeType(v, arg_name):
  """Converts `v` to a DataType enum value; raises TypeError on failure."""
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return base.as_datatype_enum
203
204
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `v` cannot be converted to a TensorShape.
    ValueError: If `v` is shape-like but has invalid values.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Named dimensions are dropped by downstream consumers; warn once.
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
226
227
def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v
235
236
def _MakeFunc(v, arg_name):
  """Ensure v is a func.

  Accepts a NameAttrList (returned as-is), a function name as a string, or a
  function-like object exposing `add_to_graph` (which is registered with the
  default graph first).
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    return attr_value_pb2.NameAttrList(name=v)
  if hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    if hasattr(v, "_as_name_attr_list"):
      return v._as_name_attr_list  # pylint: disable=protected-access
    return attr_value_pb2.NameAttrList(name=v.name)
  raise TypeError("Don't know how to convert {} to a func for "
                  "argument {}".format(v, arg_name))
253
254
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if not inputs:
    yield
    return
  # `ops.colocate_with()` accepts only a single op or tensor, so colocate
  # with the head of the list and recurse to handle the remainder.
  head, tail = inputs[0], inputs[1:]
  with ops.colocate_with(head), _MaybeColocateWith(tail):
    yield
# pylint: enable=g-doc-return-or-yield
274
275
def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  Every attr must either be inferred from an input or passed explicitly
  (never both). Passing None for an attr that has a default in the OpDef
  selects that default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    return op
  res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
  if is_stateful and isinstance(res, list) and not res:
    # A stateful op with an empty output list: return the Operation so the
    # caller can still depend on its side effects.
    return op
  return res
318
319
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, keyed by
      the argument names in the OpDef (with a trailing "_" appended when an
      OpDef name collides with a Python keyword or built-in).

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)` where
    `output_structure` is a list of None/int suitable for `_Restructure`,
    `is_stateful` is the OpDef's `is_stateful` field, `op` is the created
    `Operation`, and `outputs` is the list of output tensors (possibly
    replaced by op-callback outputs).

  Raises:
    RuntimeError: If `op_type_name` is unregistered, the graph cannot be
      determined from the inputs, or the op is deprecated in this GraphDef
      version.
    TypeError: If inputs or attrs have the wrong type or are missing.
    ValueError: If list lengths or attr values violate OpDef constraints.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Fix: `e.message` does not exist on Python 3 exceptions; referencing it
    # raised AttributeError here and masked the real error. Format `e` itself.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.") %
          (op_type_name, producer, deprecation_version,
           op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtype.name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtype.name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError as err:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                "type '%s' instead. Error: %s" %
                (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                 repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type %s that does not "
                "match type %s of argument '%s'." %
                (input_name, op_type_name, dtypes.as_dtype(attr_value).name,
                 dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which the
    # `op_callbacks` machinery requires to function properly. See
    # framework/op_callbacks.py for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
766
767
def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition
      for valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.
  """
  # Maps each base attr type to (converter, AttrValue field name, whether the
  # field is a sub-message that must be CopyFrom'd rather than assigned).
  converters = {
      "string": (_MakeStr, "s", False),
      "int": (_MakeInt, "i", False),
      "float": (_MakeFloat, "f", False),
      "bool": (_MakeBool, "b", False),
      "type": (_MakeType, "type", False),
      "shape": (_MakeShape, "shape", True),
      "tensor": (_MakeTensor, "tensor", True),
      "func": (_MakeFunc, "func", True),
  }

  attr_value = attr_value_pb2.AttrValue()

  if attr_type.startswith("list("):
    if not _IsListValue(value):
      raise TypeError("Expected list for attr " + arg_name)
    base = attr_type[len("list("):-1] if attr_type.endswith(")") else None
    if base in converters:
      make, field, _ = converters[base]
      # Repeated message fields copy on extend, so this works for both the
      # scalar and the sub-message element types.
      getattr(attr_value.list, field).extend(
          [make(x, arg_name) for x in value])
      return attr_value
  elif attr_type in converters:
    make, field, is_message = converters[attr_type]
    if is_message:
      getattr(attr_value, field).CopyFrom(make(value, arg_name))
    else:
      setattr(attr_value, field, make(value, arg_name))
    return attr_value

  raise TypeError("Unrecognized Attr type " + attr_type)
821
822
# The following symbols are used by op_def_util.cc.
# NOTE(review): RegisterPyObject appears to publish these Python objects by
# name for lookup from native code — confirm against op_def_util.cc.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
_pywrap_utils.RegisterPyObject("tf.TensorShape", tensor_shape.TensorShape)
_pywrap_utils.RegisterPyObject("tf.as_shape", tensor_shape.as_shape)
_pywrap_utils.RegisterPyObject("tf.TensorProto", tensor_pb2.TensorProto)
_pywrap_utils.RegisterPyObject("text_format.Parse", text_format.Parse)
_pywrap_utils.RegisterPyObject("tf.convert_to_tensor", ops.convert_to_tensor)
831