# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class to hold a library of OpDefs and use it to create Brain operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


def _Attr(op_def, name):
  """Returns the attr definition called `name` from `op_def.attr`.

  Args:
    op_def: An `OpDef` proto.
    name: The attr name to look up.

  Returns:
    The matching entry of `op_def.attr`.

  Raises:
    TypeError: If `op_def` declares no attr called `name`.
  """
  for attr in op_def.attr:
    if attr.name == name:
      return attr
  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
                  (op_def.name, name))


def _AttrValue(attr_protos, name):
  """Returns `attr_protos[name]`.

  Raises:
    TypeError: If `name` is not a key of `attr_protos`.
  """
  if name in attr_protos:
    return attr_protos[name]
  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
                  (name, attr_protos))


def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Checks `dtype` against the `allowed_values` of a "type" attr.

  Does nothing when `attr_def` has no `allowed_values` field.

  Args:
    dtype: A `DataType` enum value (or anything comparable with the entries
      of `attr_def.allowed_values.list.type`).
    attr_def: The attr definition whose constraint is checked.
    param_name: Parameter name used in the error message.

  Raises:
    TypeError: If `dtype` is not in the allowed list.
  """
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      raise TypeError(
          "Value passed to parameter '%s' has DataType %s not in list of "
          "allowed values: %s" %
          (param_name, dtypes.as_dtype(dtype).name,
           ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))


def _IsListParameter(arg):
  """Returns True if `arg` (an OpDef arg) takes a list of tensors."""
  # An arg is a list when its length comes from a number attr or its types
  # come from a list(type) attr.
  if arg.number_attr:
    return True
  elif arg.type_list_attr:
    return True
  return False


def _NumTypeFields(arg):
  """Counts how many of `type`, `type_attr`, `type_list_attr` `arg` sets."""
  num = 0
  if arg.type != types_pb2.DT_INVALID: num += 1
  if arg.type_attr: num += 1
  if arg.type_list_attr: num += 1
  return num


def _IsListValue(v):
  """Returns True if `v` is a Python list or tuple."""
  return isinstance(v, (list, tuple))


def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5].

  Only one level of list/tuple nesting is removed.
  """
  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
  l_of_l = [x if _IsListValue(x) else [x] for x in l]
  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
  return [item for sublist in l_of_l for item in sublist]


def _Restructure(l, structure):
  """Returns the elements of list l structured according to the given structure.

  A structure is represented by a list whose elements are either
  `None` or a non-negative integer. `None` corresponds to a single
  element in the output, and an integer N corresponds to a nested
  list of length N.

  The function returns a tuple whose shape is given by `structure`,
  and whose elements are taken from `l`. If `structure` is a
  singleton, the function returns the single data structure implied
  by the 0th element of `structure` instead of a 1-tuple. For example:

    _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
      -> ("foo", ["bar", "baz"], "qux")

    _Restructure(["foo"], [None]) -> "foo"

    _Restructure(["foo"], [1]) -> ["foo"]

    _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`, as a
    tuple. If `structure` is a list of length 1, this function returns
    the single data structure implied by `structure[0]`.
  """
  result = []
  current_index = 0
  for element in structure:
    if element is None:
      result.append(l[current_index])
      current_index += 1
    else:
      # Nested outputs are returned as list slices of `l`.
      result.append(l[current_index:current_index+element])
      current_index += element

  if len(result) == 1:
    return result[0]
  else:
    return tuple(result)


def _MakeFloat(v, arg_name):
  """Returns `float(v)`; raises TypeError unless `v` is a real number."""
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def _MakeInt(v, arg_name):
  """Returns `int(v)`, rejecting strings and unconvertible values.

  Raises:
    TypeError: If `v` is a string or `int(v)` fails.
  """
  # int("3") would succeed, so strings are rejected explicitly.
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def _MakeStr(v, arg_name):
  """Returns `v` as bytes; raises TypeError unless `v` is bytes or text."""
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def _MakeBool(v, arg_name):
  """Returns `v`; raises TypeError unless `v` is a Python bool."""
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def _MakeType(v, attr_def):
  """Converts `v` to a base `DataType` enum that satisfies `attr_def`.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    attr_def: The "type" (or "list(type)") attr definition; its
      `allowed_values`, if any, are enforced.

  Returns:
    The `DataType` enum of the base (non-reference) dtype of `v`.

  Raises:
    TypeError: If `v` is not convertible or not an allowed type.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  i = v.as_datatype_enum
  _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)
  return i


def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `tensor_shape.as_shape` raises TypeError for `v`.
    ValueError: If `tensor_shape.as_shape` raises ValueError for `v`.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Protos are passed through unchanged; named dimensions only warn.
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))


def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto.

  Only `TensorProto` instances are accepted; no conversion is attempted.

  Raises:
    TypeError: If `v` is not a `TensorProto`.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'" %
      (repr(v), arg_name))


def _MakeFunc(v, arg_name):
  """Ensure v is a func.

  Accepts a `NameAttrList` (returned as is), a function name string, or an
  object with an `add_to_graph` method (it is added to the default graph and
  its `name` is used).

  Raises:
    TypeError: If `v` is none of the accepted kinds.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  fn_attr = attr_value_pb2.NameAttrList()
  if isinstance(v, compat.bytes_or_text_types):
    fn_attr.name = v
  elif hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    fn_attr.name = v.name
  else:
    raise TypeError("Don't know how to convert {} to a func for "
                    "argument {}".format(v, arg_name))
  return fn_attr


class _OpInfo(object):
  """All per-Op state we would like to precompute/validate."""

  def __init__(self, op_def):
    """Stores `op_def` after checking that its args are well-formed.

    Every input/output arg must set exactly one of `type`, `type_attr`,
    `type_list_attr`, and any attr it references must have the matching
    attr type ("type", "list(type)" or "int").

    Raises:
      TypeError: If any of those checks fails.
    """
    self.op_def = op_def
    # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
    # here, instead of these checks.
    for arg in list(op_def.input_arg) + list(op_def.output_arg):
      num_type_fields = _NumTypeFields(arg)
      if num_type_fields != 1:
        raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
                        (arg.name, op_def.name, num_type_fields))
      if arg.type_attr:
        attr_type = _Attr(op_def, arg.type_attr).type
        if attr_type != "type":
          raise TypeError("Attr '%s' of '%s' used as a type_attr "
                          "but has type %s" %
                          (arg.type_attr, op_def.name, attr_type))
      if arg.type_list_attr:
        attr_type = _Attr(op_def, arg.type_list_attr).type
        if attr_type != "list(type)":
          # Report the offending type_list_attr (type_attr is empty here).
          raise TypeError(
              "Attr '%s' of '%s' used as a type_list_attr but has type %s" %
              (arg.type_list_attr, op_def.name, attr_type))
      if arg.number_attr:
        attr_type = _Attr(op_def, arg.number_attr).type
        if attr_type != "int":
          raise TypeError(
              "Attr '%s' of '%s' used as a number_attr but has type %s" %
              (arg.number_attr, op_def.name, attr_type))


# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if not inputs:
    yield
  else:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
# pylint: enable=g-doc-return-or-yield


class OpDefLibrary(object):
  """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""

  def __init__(self):
    # Maps op name -> _OpInfo.
    self._ops = {}

  # pylint: disable=invalid-name
  def add_op(self, op_def):
    """Register an OpDef. May call apply_op with the name afterwards.

    Raises:
      TypeError: If `op_def` is not an `op_def_pb2.OpDef` or is malformed.
      ValueError: If `op_def` has no name.
      RuntimeError: If an op with the same name is already registered.
    """
    if not isinstance(op_def, op_def_pb2.OpDef):
      raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
                      (op_def, type(op_def)))
    if not op_def.name:
      raise ValueError("%s missing name." % op_def)
    if op_def.name in self._ops:
      raise RuntimeError("Op name %s registered twice." % op_def.name)
    self._ops[op_def.name] = _OpInfo(op_def)

  def add_op_list(self, op_list):
    """Register the OpDefs from an OpList.

    Raises:
      TypeError: If `op_list` is not an `op_def_pb2.OpList`.
    """
    if not isinstance(op_list, op_def_pb2.OpList):
      raise TypeError("%s is %s, not an op_def_pb2.OpList" %
                      (op_list, type(op_list)))
    for op_def in op_list.op:
      self.add_op(op_def)

  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs. A stateful op whose outputs restructure
      to an empty list also returns the Operation.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    output_structure, is_stateful, op = self._apply_op_helper(
        op_type_name, name, **keywords)
    if output_structure:
      outputs = op.outputs
      res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
      if isinstance(res, list) and not res and is_stateful:
        return op
      else:
        return res
    else:
      return op

  def _apply_op_helper(self, op_type_name, name=None, **keywords):
    """Implementation of apply_op.

    Converts inputs to tensors, infers attrs from them, converts all attrs
    to `AttrValue` protos and creates the op in the inputs' graph.

    Returns:
      A tuple `(output_structure, is_stateful, op)` where `output_structure`
      is the structure list consumed by `_Restructure`, `is_stateful` is
      `op_def.is_stateful` and `op` is the created `Operation`.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pylint: enable=protected-access
    except AssertionError as e:
      # Format the exception itself: Python 3 exceptions have no `.message`.
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.internal_convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.internal_convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              # `dtype` may still be a raw DataType enum (e.g. for ref
              # inputs, which skip the as_dtype() conversion above), so
              # normalize it before reading `.name`.
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtypes.as_dtype(dtype).name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtypes.as_dtype(dtype).name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError(
                  "%s that are invalid. Tensors: %s" % (prefix, values))

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except TypeError as err:
            if dtype is None:
              raise  # Re-raise unchanged, keeping the original traceback.
            else:
              raise TypeError(
                  "Expected %s passed to parameter '%s' of op '%s', got %s of "
                  "type '%s' instead. Error: %s" %
                  (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                   repr(values), type(values).__name__, err))
          except ValueError:
            # What type does convert_to_tensor think it has?
            try:
              observed = ops.internal_convert_to_tensor(
                  values, as_ref=input_arg.is_ref).dtype.name
            except ValueError as err:
              raise ValueError(
                  "Tried to convert '%s' to a tensor and failed. Error: %s" %
                  (input_name, err))
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any(bt != base_types[0] for bt in base_types):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr),
                                       param_name=input_name)
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr),
                                       param_name=input_name)
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
            raise TypeError(
                ("'%s' Op requires that input '%s' be a mutable tensor "
                 "(e.g.: a tf.Variable)") % (op_type_name, input_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          # An explicit None selects the OpDef's default value.
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          # Mark the (possibly empty) list as set on the proto.
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          attr_value.func.CopyFrom(_MakeFunc(value, key))
        elif attr_def.type == "list(func)":
          attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_structure = []
      for arg in op_def.output_arg:
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          output_structure.append(n)
        elif arg.type_attr:
          # Lookup only validates that the attr is present.
          _AttrValue(attr_protos, arg.type_attr)
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          output_structure.append(len(t.list.type))
        else:
          output_structure.append(None)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      # NOTE(review): `inputs` is flat (list args contribute several
      # tensors) while `op_def.input_arg` has one entry per arg, so this
      # pairing misaligns when a list input precedes a ref input — confirm
      # the intended colocation set before changing it.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, dtypes=None, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
      return output_structure, op_def.is_stateful, op

# pylint: enable=invalid-name