• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to hold a library of OpDefs and use it to create Brain operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from tensorflow.core.framework import attr_value_pb2
25from tensorflow.core.framework import op_def_pb2
26from tensorflow.core.framework import tensor_pb2
27from tensorflow.core.framework import tensor_shape_pb2
28from tensorflow.core.framework import types_pb2
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import compat
34from tensorflow.python.util import tf_contextlib
35
36
37def _Attr(op_def, name):
38  for attr in op_def.attr:
39    if attr.name == name:
40      return attr
41  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
42                  (op_def.name, name))
43
44
45def _AttrValue(attr_protos, name):
46  if name in attr_protos:
47    return attr_protos[name]
48  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
49                  (name, attr_protos))
50
51
def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises TypeError if `dtype` is not in `attr_def`'s allowed-values list."""
  if not attr_def.HasField("allowed_values"):
    # No constraint declared on this attr; anything goes.
    return
  allowed_list = attr_def.allowed_values.list.type
  if dtype in allowed_list:
    return
  raise TypeError(
      "Value passed to parameter '%s' has DataType %s not in list of "
      "allowed values: %s" %
      (param_name, dtypes.as_dtype(dtype).name,
       ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
61
62
63def _IsListParameter(arg):
64  if arg.number_attr:
65    return True
66  elif arg.type_list_attr:
67    return True
68  return False
69
70
def _NumTypeFields(arg):
  """Counts how many of the mutually-exclusive type fields are set on `arg`."""
  # A well-formed arg sets exactly one of: type, type_attr, type_list_attr.
  return (int(arg.type != types_pb2.DT_INVALID) +
          int(bool(arg.type_attr)) +
          int(bool(arg.type_list_attr)))
77
78
79def _IsListValue(v):
80  return isinstance(v, (list, tuple))
81
82
def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
  flat = []
  for x in l:
    if isinstance(x, (list, tuple)):
      flat.extend(x)
    else:
      flat.append(x)
  return flat
89
90
91def _Restructure(l, structure):
92  """Returns the elements of list l structured according to the given structure.
93
94  A structure is represented by a list whose elements are either
95  `None` or a non-negative integer. `None` corresponds to a single
96  element in the output list, and an integer N corresponds to a nested
97  list of length N.
98
99  The function returns a data structure whose shape is given by
100  `structure`, and whose elements are taken from `l`. If `structure`
101  is a singleton, the function returns the single data structure
102  implied by the 0th element of `structure`. For example:
103
104      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
105        -> ["foo", ["bar", "baz"], "qux"]
106
107      _Restructure(["foo"], [None]) -> "foo"
108
109      _Restructure(["foo"], [1]) -> ["foo"]
110
111      _Restructure([], [0]) -> []
112
113  Args:
114    l: A list.
115    structure: A list whose elements are either `None` or a non-negative
116      integer.
117
118  Returns:
119    The elements of `l`, restructured according to `structure`. If
120    `structure` is a list of length 1, this function returns the
121    single data structure implied by `structure[0]`.
122
123  """
124  result = []
125  current_index = 0
126  for element in structure:
127    if element is None:
128      result.append(l[current_index])
129      current_index += 1
130    else:
131      result.append(l[current_index:current_index+element])
132      current_index += element
133
134  if len(result) == 1:
135    return result[0]
136  else:
137    return tuple(result)
138
139
def _MakeFloat(v, arg_name):
  """Validates that `v` is a real number and returns it as a Python float."""
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
145
146
def _MakeInt(v, arg_name):
  """Coerces `v` to a Python int, rejecting strings and inconvertible values."""
  error = TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  # Strings would coerce via int("3"), but attr values must be real ints.
  if isinstance(v, six.string_types):
    raise error
  try:
    return int(v)
  except (ValueError, TypeError):
    raise error
156
157
def _MakeStr(v, arg_name):
  """Validates that `v` is string-like and returns it as bytes."""
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are encoded to bytes; bytes pass through unchanged.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
163
164
165def _MakeBool(v, arg_name):
166  if not isinstance(v, bool):
167    raise TypeError("Expected bool for argument '%s' not %s." %
168                    (arg_name, repr(v)))
169  return v
170
171
def _MakeType(v, attr_def):
  """Converts `v` to a datatype enum value satisfying `attr_def`'s constraints."""
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = base.as_datatype_enum
  _SatisfiesTypeConstraint(enum_value, attr_def, param_name=attr_def.name)
  return enum_value
181
182
def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `v` cannot be converted to a TensorShape.
    ValueError: If `v` describes an invalid shape.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Dimension names are legal in the proto but are dropped when the shape
    # round-trips through TensorShape, so warn that they won't be preserved.
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))
204
205
def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v
213
214
def _MakeFunc(v, arg_name):
  """Ensure v is a func."""
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  fn_attr = attr_value_pb2.NameAttrList()
  if isinstance(v, compat.bytes_or_text_types):
    # A bare name: reference a function assumed to exist in the graph.
    fn_attr.name = v
    return fn_attr
  if hasattr(v, "add_to_graph"):
    # A defined-function object: register it in the default graph, then
    # reference it by name.
    v.add_to_graph(ops.get_default_graph())
    fn_attr.name = v.name
    return fn_attr
  raise TypeError("Don't know how to convert {} to a func for "
                  "argument {}".format(v, arg_name))
229
230
class _OpInfo(object):
  """All per-Op state we would like to precompute/validate."""

  def __init__(self, op_def):
    """Validates and stores `op_def`.

    Args:
      op_def: An `op_def_pb2.OpDef` proto describing the op.

    Raises:
      TypeError: If any input or output arg of `op_def` is malformed, or
        references an attr of the wrong type.
    """
    self.op_def = op_def
    # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
    # here, instead of these checks.
    for arg in list(op_def.input_arg) + list(op_def.output_arg):
      # Exactly one of type / type_attr / type_list_attr must be set.
      num_type_fields = _NumTypeFields(arg)
      if num_type_fields != 1:
        raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
                        (arg.name, op_def.name, num_type_fields))
      if arg.type_attr:
        attr_type = _Attr(op_def, arg.type_attr).type
        if attr_type != "type":
          raise TypeError("Attr '%s' of '%s' used as a type_attr "
                          "but has type %s" %
                          (arg.type_attr, op_def.name, attr_type))
      if arg.type_list_attr:
        attr_type = _Attr(op_def, arg.type_list_attr).type
        if attr_type != "list(type)":
          # Bug fix: report the offending type_list_attr; the original
          # message mistakenly interpolated arg.type_attr here.
          raise TypeError(
              "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
              (arg.type_list_attr, op_def.name, attr_type))
      if arg.number_attr:
        attr_type = _Attr(op_def, arg.number_attr).type
        if attr_type != "int":
          raise TypeError(
              "Attr '%s' of '%s' used as a number_attr but has type %s" %
              (arg.number_attr, op_def.name, attr_type))
261
262
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if inputs:
    # `ops.colocate_with()` accepts only a single op or tensor, so peel off
    # the head element and recurse to stack one colocation scope per item.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
  else:
    yield
# pylint: enable=g-doc-return-or-yield
282
283
class OpDefLibrary(object):
  """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""

  def __init__(self):
    # Maps op name -> _OpInfo (the validated OpDef plus precomputed state).
    self._ops = {}

  # pylint: disable=invalid-name
  def add_op(self, op_def):
    """Register an OpDef. May call apply_op with the name afterwards."""
    if not isinstance(op_def, op_def_pb2.OpDef):
      raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
                      (op_def, type(op_def)))
    if not op_def.name:
      raise ValueError("%s missing name." % op_def)
    if op_def.name in self._ops:
      raise RuntimeError("Op name %s registered twice." % op_def.name)
    self._ops[op_def.name] = _OpInfo(op_def)

  def add_op_list(self, op_list):
    """Register the OpDefs from an OpList."""
    if not isinstance(op_list, op_def_pb2.OpList):
      raise TypeError("%s is %s, not an op_def_pb2.OpList" %
                      (op_list, type(op_list)))
    for op_def in op_list.op:
      self.add_op(op_def)

  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    output_structure, is_stateful, op = self._apply_op_helper(
        op_type_name, name, **keywords)
    if output_structure:
      outputs = op.outputs
      res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
      if isinstance(res, list) and not res and is_stateful:
        # A stateful op that produced no tensors is returned as the
        # Operation itself so callers can still sequence/run it.
        return op
      else:
        return res
    else:
      return op

  def _apply_op_helper(self, op_type_name, name=None, **keywords):
    """Implementation of apply_op that returns output_structure, op.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name.

    Returns:
      A triple `(output_structure, is_stateful, op)` where
      `output_structure` is a list suitable for `_Restructure`,
      `is_stateful` is the OpDef's is_stateful bit, and `op` is the
      created `Operation`.

    Raises:
      RuntimeError: If `op_type_name` is unregistered or the graph cannot
        be determined from the inputs.
      TypeError: If inputs or attrs have the wrong type.
      ValueError: If inputs or attrs have invalid values.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      # Bug fix: format the exception itself; AssertionError has no
      # `.message` attribute in Python 3.
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.internal_convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.internal_convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError(
                  "%s that are invalid. Tensors: %s" % (prefix, values))

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except TypeError as err:
            if dtype is None:
              raise err
            else:
              raise TypeError(
                  "Expected %s passed to parameter '%s' of op '%s', got %s of "
                  "type '%s' instead. Error: %s" %
                  (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                   repr(values), type(values).__name__, err))
          except ValueError:
            # What type does convert_to_tensor think it has?
            try:
              observed = ops.internal_convert_to_tensor(
                  values, as_ref=input_arg.is_ref).dtype.name
            except ValueError as err:
              raise ValueError(
                  "Tried to convert '%s' to a tensor and failed. Error: %s" %
                  (input_name, err))
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any(bt != base_types[0] for bt in base_types):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr),
                                       param_name=input_name)
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr),
                                       param_name=input_name)
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
            raise TypeError(
                ("'%s' Op requires that input '%s' be a mutable tensor "
                 "(e.g.: a tf.Variable)") % (op_type_name, input_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          attr_value.func.CopyFrom(_MakeFunc(value, key))
        elif attr_def.type == "list(func)":
          attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_structure = []
      for arg in op_def.output_arg:
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          output_structure.append(n)
        elif arg.type_attr:
          # Called for its validation side effect (raises if the attr is
          # missing); a type_attr output is always a single tensor.
          _AttrValue(attr_protos, arg.type_attr)
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          output_structure.append(len(t.list.type))
        else:
          output_structure.append(None)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
      return output_structure, op_def.is_stateful, op
790
791# pylint: enable=invalid-name
792