# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helpers to add nodes invoking registered OpDefs to a graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


def _Attr(op_def, name):
  """Returns the AttrDef called `name` declared by `op_def`.

  Raises:
    TypeError: If `op_def` declares no attr called `name`.
  """
  for attr in op_def.attr:
    if attr.name == name:
      return attr
  raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
                  (op_def.name, name))


def _AttrValue(attr_protos, name):
  """Returns `attr_protos[name]`.

  Raises:
    TypeError: If `name` is not a key of `attr_protos`.
  """
  if name in attr_protos:
    return attr_protos[name]
  raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
                  (name, attr_protos))


def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Raises TypeError if `dtype` is not in `attr_def`'s allowed type list.

  Attrs without an `allowed_values` field accept any type.
  """
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      raise TypeError(
          "Value passed to parameter '%s' has DataType %s not in list of "
          "allowed values: %s" %
          (param_name, dtypes.as_dtype(dtype).name,
           ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))


def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
  """Raises ValueError if a list attr is shorter than its declared minimum."""
  if attr_def.has_minimum and length < attr_def.minimum:
    raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                     "less than minimum %d." %
                     (param_name, op_type_name, length, attr_def.minimum))


def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if `value` is not one of `attr_def`'s allowed strings."""
  if value not in attr_def.allowed_values.list.s:
    raise ValueError(
        "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
        (arg_name, op_type_name, compat.as_text(value), '", "'.join(
            map(compat.as_text, attr_def.allowed_values.list.s))))


def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if int `value` is below `attr_def.minimum`."""
  if value < attr_def.minimum:
    raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." %
                     (arg_name, op_type_name, value, attr_def.minimum))


def _IsListParameter(arg):
  """Returns True if ArgDef `arg` takes a list (number_attr/type_list_attr)."""
  return bool(arg.number_attr or arg.type_list_attr)


def _NumTypeFields(arg):
  """Counts how many of `type`, `type_attr`, `type_list_attr` `arg` sets."""
  num = 0
  if arg.type != types_pb2.DT_INVALID: num += 1
  if arg.type_attr: num += 1
  if arg.type_list_attr: num += 1
  return num


def _IsListValue(v):
  """Returns True if `v` is a Python list or tuple."""
  return isinstance(v, (list, tuple))


def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
  l_of_l = [x if _IsListValue(x) else [x] for x in l]
  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
  return [item for sublist in l_of_l for item in sublist]


def _Restructure(l, structure):
  """Returns the elements of list l structured according to the given structure.

  A structure is represented by a list whose elements are either
  `None` or a non-negative integer. `None` corresponds to a single
  element in the output, and an integer N corresponds to a nested
  list of length N.

  The function returns a tuple whose shape is given by `structure`, and
  whose elements are taken from `l`. If `structure` is a singleton, the
  function returns the single data structure implied by the 0th element
  of `structure` instead of a 1-tuple. For example:

    _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
      -> ("foo", ["bar", "baz"], "qux")

    _Restructure(["foo"], [None]) -> "foo"

    _Restructure(["foo"], [1]) -> ["foo"]

    _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`, as a tuple.
    If `structure` is a list of length 1, this function returns the
    single data structure implied by `structure[0]`.
  """
  result = []
  current_index = 0
  for element in structure:
    if element is None:
      result.append(l[current_index])
      current_index += 1
    else:
      result.append(l[current_index:current_index+element])
      current_index += element

  if len(result) == 1:
    return result[0]
  else:
    return tuple(result)


def _MakeFloat(v, arg_name):
  """Returns `v` as a float; raises TypeError if it is not a real number."""
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def _MakeInt(v, arg_name):
  """Returns `int(v)`; strings and unconvertible values raise TypeError."""
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def _MakeStr(v, arg_name):
  """Returns `v` encoded as bytes; raises TypeError if it is not a string."""
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def _MakeBool(v, arg_name):
  """Returns `v` unchanged; raises TypeError unless it is a `bool`."""
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def _MakeType(v, arg_name):
  """Returns the DataType enum of `v`'s base dtype.

  Raises:
    TypeError: If `dtypes.as_dtype` cannot interpret `v`.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v.as_datatype_enum


def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  A TensorShapeProto input is returned unchanged (a warning is logged if any
  of its dimensions is named); other inputs go through
  `tensor_shape.as_shape`.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError, ValueError: If `v` cannot be converted to a TensorShape.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s" % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s" % (arg_name, e))


def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto; raises TypeError for anything else."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'" %
      (repr(v), arg_name))


def _MakeFunc(v, arg_name):
  """Ensure v is a func, returning it as a `NameAttrList`.

  Accepts a `NameAttrList` (returned as-is), a string function name, or an
  object with an `add_to_graph` method. Such an object is added to the
  default graph and described by its `_as_name_attr_list` attribute when
  present, otherwise by its `name`.

  Raises:
    TypeError: If `v` is none of the accepted kinds.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    fn_attr = attr_value_pb2.NameAttrList(name=v)
  elif hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    if hasattr(v, "_as_name_attr_list"):
      fn_attr = v._as_name_attr_list  # pylint: disable=protected-access
    else:
      fn_attr = attr_value_pb2.NameAttrList(name=v.name)
  else:
    raise TypeError("Don't know how to convert {} to a func for "
                    "argument {}".format(v, arg_name))
  return fn_attr


# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if not inputs:
    yield
  else:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
# pylint: enable=g-doc-return-or-yield


def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name,
      and optional parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if output_structure:
    res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
    # A stateful op whose only output is an empty list returns the op itself.
    if isinstance(res, list) and not res and is_stateful:
      return op
    else:
      return res
  else:
    return op


def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that also exposes intermediate results.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op (defaults to
      `op_type_name`).
    **keywords: input Tensor and attr arguments specified by name. Names
      clashing with Python keywords may carry a trailing "_".

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`:
      output_structure: list of `None`/int describing how `outputs` group
        into the op's output args (see `_Restructure`).
      is_stateful: the `is_stateful` field of the Op's OpDef.
      op: the created `Operation`.
      outputs: `op.outputs`, or the values returned by op callbacks when
        they return something other than None.

  Raises:
    RuntimeError: If the Op is not registered or the graph cannot be
      determined from the arguments.
    NotImplementedError: If the Op was removed at or before the graph's
      producer version.
    TypeError, ValueError: If inputs or attrs are missing, unexpected,
      mistyped, or violate constraints declared in the OpDef.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Exceptions have no `.message` attribute on Python 3; use str(e).
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, str(e)))

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.") %
          (op_type_name, producer, deprecation_version,
           op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  # One flag per entry of the flattened `inputs` list, recording whether the
  # ArgDef it came from is a ref input. Needed because list inputs expand to
  # several tensors, so `inputs` cannot be zipped with `op_def.input_arg`.
  input_is_ref = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            # `dtype` may still be a raw DataType enum here (ref inputs skip
            # the as_dtype() call above), so normalize it before naming it.
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtypes.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtypes.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
        input_is_ref.extend([input_arg.is_ref] * len(values))
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                "type '%s' instead. Error: %s" %
                (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                 repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
        input_is_ref.append(input_arg.is_ref)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type %s that does not "
                "match type %s of argument '%s'." %
                (input_name, op_type_name, dtypes.as_dtype(attr_value).name,
                 dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      # None selects the OpDef's default, when one is declared.
      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for string_value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(string_value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for type_value in attr_value.list.type:
          _SatisfiesTypeConstraint(type_value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        # Called only to check the attr exists (raises TypeError otherwise).
        _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    # `input_is_ref` is aligned with the flattened `inputs` list.
    must_colocate_inputs = [val for is_ref, val in zip(input_is_ref, inputs)
                            if is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors can be decoupled from the `op` itself, which lets the
    # `op_callbacks` replace them. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs


def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If `attr_type` is unrecognized, a list type is given a
      non-list `value`, or an element cannot be converted to `attr_type`.
    ValueError: Propagated from `_MakeShape` for unconvertible shapes.
  """
  attr_value = attr_value_pb2.AttrValue()

  if attr_type.startswith("list("):
    if not _IsListValue(value):
      raise TypeError("Expected list for attr " + arg_name)

  if attr_type == "string":
    attr_value.s = _MakeStr(value, arg_name)
  elif attr_type == "list(string)":
    attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
  elif attr_type == "int":
    attr_value.i = _MakeInt(value, arg_name)
  elif attr_type == "list(int)":
    attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
  elif attr_type == "float":
    attr_value.f = _MakeFloat(value, arg_name)
  elif attr_type == "list(float)":
    attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
  elif attr_type == "bool":
    attr_value.b = _MakeBool(value, arg_name)
  elif attr_type == "list(bool)":
    attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
  elif attr_type == "type":
    attr_value.type = _MakeType(value, arg_name)
  elif attr_type == "list(type)":
    attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
  elif attr_type == "shape":
    attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
  elif attr_type == "list(shape)":
    attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
  elif attr_type == "tensor":
    attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
  elif attr_type == "list(tensor)":
    attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
  elif attr_type == "func":
    attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
  elif attr_type == "list(func)":
    attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
  else:
    raise TypeError("Unrecognized Attr type " + attr_type)
  return attr_value


# The following symbols are used by op_def_util.cc.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
_pywrap_utils.RegisterPyObject("tf.TensorShape", tensor_shape.TensorShape)
_pywrap_utils.RegisterPyObject("tf.as_shape", tensor_shape.as_shape)
_pywrap_utils.RegisterPyObject("tf.TensorProto", tensor_pb2.TensorProto)
_pywrap_utils.RegisterPyObject("text_format.Parse", text_format.Parse)
_pywrap_utils.RegisterPyObject("tf.convert_to_tensor", ops.convert_to_tensor)