# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class to hold a library of OpDefs and use it to create Brain operations."""

from google.protobuf import text_format
from tensorflow.core.config import flags
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_library_pybind
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


def _Attr(op_def, name):
  """Returns the attr definition called `name` declared by `op_def`.

  Raises:
    TypeError: If `op_def` declares no attr with that name.
  """
  for attr in op_def.attr:
    if attr.name == name:
      return attr
  raise TypeError(f"Inconsistent OpDef for '{op_def.name}', missing attr "
                  f"'{name}'")


def _AttrValue(attr_protos, name, op_type_name):
  """Returns `attr_protos[name]`.

  Raises:
    TypeError: If `name` is not a key of `attr_protos`.
  """
  if name in attr_protos:
    return attr_protos[name]
  raise TypeError(f"Inconsistent OpDef for '{op_type_name}', missing attr "
                  f"'{name}' from '{attr_protos}'.")


def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Checks `dtype` against the allowed type list of `attr_def`, if any.

  Does nothing when `attr_def` has no `allowed_values` field.

  Raises:
    TypeError: If `dtype` is not in `attr_def.allowed_values.list.type`.
  """
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      # The human-readable list is only needed for the error message, so it is
      # built lazily instead of on every (usually successful) check.
      allowed_values = ", ".join(dtypes.as_dtype(x).name for x in allowed_list)
      raise TypeError(
          f"Value passed to parameter '{param_name}' has DataType "
          f"{dtypes.as_dtype(dtype).name} not in list of allowed values: "
          f"{allowed_values}")


def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
  """Raises ValueError if `length` is below `attr_def`'s declared minimum."""
  if attr_def.has_minimum and length < attr_def.minimum:
    raise ValueError(f"Attr '{param_name}' of '{op_type_name}' Op passed list "
                     f"of length {length} less than minimum "
                     f"{attr_def.minimum}.")


def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if `value` is not one of `attr_def`'s allowed strings."""
  if value not in attr_def.allowed_values.list.s:
    allowed_values = '", "'.join(
        map(compat.as_text, attr_def.allowed_values.list.s))
    raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed string "
                     f"'{compat.as_text(value)}' not in: \"{allowed_values}\".")


def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if `value` is below `attr_def.minimum`."""
  if value < attr_def.minimum:
    raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed {value} "
                     f"less than minimum {attr_def.minimum}.")


def _IsListParameter(arg):
  """Returns True if `arg` has a `number_attr` or a `type_list_attr`."""
  return bool(arg.number_attr or arg.type_list_attr)


def _NumTypeFields(arg):
  """Returns how many of `type`, `type_attr`, `type_list_attr` `arg` sets."""
  num = 0
  if arg.type != types_pb2.DT_INVALID: num += 1
  if arg.type_attr: num += 1
  if arg.type_list_attr: num += 1
  return num


def _IsListValue(v):
  """Returns True if `v` is a Python list or tuple."""
  return isinstance(v, (list, tuple))


def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
  l_of_l = [x if _IsListValue(x) else [x] for x in l]
  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
  return [item for sublist in l_of_l for item in sublist]


def _Restructure(l, structure):
  """Returns the elements of list l structured according to the given structure.

  A structure is represented by a list whose elements are either
  `None` or a non-negative integer. `None` corresponds to a single
  element in the output, and an integer N corresponds to a nested
  slice of `l` of length N.

  The function returns a tuple whose shape is given by `structure`, and
  whose elements are taken from `l`. If `structure` is a singleton, the
  function returns the single data structure implied by the 0th element
  of `structure`. For example:

    _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
      -> ("foo", ["bar", "baz"], "qux")

    _Restructure(["foo"], [None]) -> "foo"

    _Restructure(["foo"], [1]) -> ["foo"]

    _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`, as a tuple.
    If `structure` is a list of length 1, this function returns the
    single data structure implied by `structure[0]` instead.
  """
  result = []
  current_index = 0
  for element in structure:
    if element is None:
      result.append(l[current_index])
      current_index += 1
    else:
      # Slicing keeps the container type of `l` (e.g. a list stays a list).
      result.append(l[current_index:current_index+element])
      current_index += element

  if len(result) == 1:
    return result[0]
  else:
    return tuple(result)


def _MakeFloat(v, arg_name):
  """Returns `v` as a float; raises TypeError if it is not a real number."""
  if not isinstance(v, compat.real_types):
    raise TypeError(f"Expected float for argument '{arg_name}' not {repr(v)}.")
  return float(v)


def _MakeInt(v, arg_name):
  """Returns `int(v)`; raises TypeError for strings or unconvertible values."""
  # Strings are rejected explicitly because int("3") would otherwise succeed.
  if isinstance(v, str):
    raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")


def _MakeStr(v, arg_name):
  """Returns `v` encoded as bytes; raises TypeError if it is not a string."""
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError(f"Expected string for argument '{arg_name}' not {repr(v)}.")
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def _MakeBool(v, arg_name):
  """Returns `v` unchanged; raises TypeError if it is not a bool."""
  if not isinstance(v, bool):
    raise TypeError(f"Expected bool for argument '{arg_name}' not {repr(v)}.")
  return v


def _MakeType(v, arg_name):
  """Returns the DataType enum of the base dtype of `v`.

  Raises:
    TypeError: If `dtypes.as_dtype` cannot interpret `v`.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError(f"Expected DataType for argument '{arg_name}' not "
                    f"{repr(v)}.")
  return v.as_datatype_enum


def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto.

  Raises:
    TypeError: If `v` cannot be converted to a shape.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    for d in v.dim:
      if d.name:
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
        break
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except (TypeError, ValueError) as e:
    # Both failure modes are reported to callers uniformly as TypeError.
    raise TypeError(f"Error converting {repr(v)} (arg name = {arg_name}) to a "
                    f"TensorShape: {e}")


def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  raise TypeError(
      f"Don't know how to convert {repr(v)} to a TensorProto for argument "
      f"'{arg_name}'")


def _MakeFunc(v, arg_name):
  """Ensure v is a func.

  Accepts a `NameAttrList` (returned as-is), a function name string, or an
  object with an `add_to_graph` method (which is added to the default graph
  first).

  Raises:
    TypeError: If `v` is none of the above.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    fn_attr = attr_value_pb2.NameAttrList(name=v)
  elif hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    if hasattr(v, "_as_name_attr_list"):
      fn_attr = v._as_name_attr_list  # pylint: disable=protected-access
    else:
      fn_attr = attr_value_pb2.NameAttrList(name=v.name)
  else:
    raise TypeError(f"Don't know how to convert {repr(v)} to a func for "
                    f"argument {arg_name}")
  return fn_attr


# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if not inputs:
    yield
  else:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
# pylint: enable=g-doc-return-or-yield


def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, and optional
      parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs (or, for a stateful op, if its single
    list output is empty).

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if output_structure:
    res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
    if isinstance(res, list) and not res and is_stateful:
      return op
    else:
      return res
  else:
    return op


# This is temporary Python/C++ code duplication until all of it can be ported
# over to C++.
# LINT.IfChange
def _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos):
  """Extracts `attr_protos`. For use in _apply_op_helper.

  For every attr declared by `op_def`, converts `attrs[name]` into an
  `AttrValue` proto and validates it against the OpDef constraints (list
  length, allowed strings, int minimum, allowed types). If the value is None
  and the OpDef has a default, the default is copied instead. The result is
  stored in `attr_protos[name]`.
  """
  for attr_def in op_def.attr:
    key = attr_def.name
    value = attrs[key]

    if attr_def.HasField("default_value") and value is None:
      attr_value = attr_value_pb2.AttrValue()
      attr_value.CopyFrom(attr_def.default_value)
      attr_protos[key] = attr_value
      continue

    attr_value = value_to_attr_value(value, attr_def.type, key)
    if attr_def.type.startswith("list("):
      _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
    if attr_def.HasField("allowed_values"):
      if attr_def.type == "string":
        _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                           op_type_name)
      elif attr_def.type == "list(string)":
        for allowed_s in attr_value.list.s:
          _SatisfiesAllowedStringsConstraint(allowed_s, attr_def, key,
                                             op_type_name)
    if attr_def.has_minimum and attr_def.type == "int":
      _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, op_type_name)
    if attr_def.type == "type":
      _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
    if attr_def.type == "list(type)":
      for type_enum in attr_value.list.type:
        _SatisfiesTypeConstraint(type_enum, attr_def, key)

    attr_protos[key] = attr_value


def _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                            output_structure):
  """Extracts `output_structure`. For use in _apply_op_helper.

  Appends one entry per output arg of `op_def`: an int N when the output is a
  list of N tensors, or None when it is a single tensor. This is the format
  consumed by `_Restructure`.
  """
  for arg in op_def.output_arg:
    if arg.number_attr:
      n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
      output_structure.append(n)
    elif arg.type_attr:
      # The lookup is kept only to validate that the attr was set; a single
      # tensor output contributes `None` regardless of its type.
      _AttrValue(attr_protos, arg.type_attr, op_type_name)
      output_structure.append(None)
    elif arg.type_list_attr:
      t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
      output_structure.append(len(t.list.type))
    else:
      output_structure.append(None)


def _CanExtractAttrsFastPath(op_def, keywords):
  """Check if the fast path for _apply_op_helper is applicable.

  Returns True only when every input arg's keyword value is already an
  `ops.Tensor` and no attr of `op_def` is of type `func` or `list(func)`.
  """
  # Check if all inputs are already tf.Tensor
  for input_arg in op_def.input_arg:
    value = keywords.get(input_arg.name, None)
    if not isinstance(value, ops.Tensor):
      return False

  # Check that attrs are not `func` or `list(func)` type.
  for attr_def in op_def.attr:
    if attr_def.type == "func" or attr_def.type == "list(func)":
      return False

  return True


def _CheckOpDeprecation(op_type_name, op_def, producer):
  """Checks if the op is deprecated.

  Raises:
    NotImplementedError: If the op's deprecation version is set and
      `producer` is at or past it.
  """
  deprecation_version = op_def.deprecation.version
  if deprecation_version and producer >= deprecation_version:
    raise NotImplementedError(
        f"Op {op_type_name} is not available in GraphDef version {producer}. "
        f"It has been removed in version {deprecation_version}. "
        f"{op_def.deprecation.explanation}.")


def _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map):
  """Extracts the `default_type_attr_map` and `allowed_list_attr_map`.

  Only attrs of type "type" are considered: their default (as a DType) goes
  into `default_type_attr_map` and their allowed type list into
  `allowed_list_attr_map`, keyed by attr name.
  """
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type


def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  Pops each input of `op_def` from `keywords`, converts it to Tensor(s),
  infers type/number attrs into `attrs`, and appends the tensors to `inputs`
  and their dtypes to `input_types`.

  Raises:
    TypeError: On missing inputs, unconvertible values or type mismatches.
    ValueError: On list-length mismatches or conversion failures.
  """
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = set(v.dtype.base_dtype for v in values)
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
        observed = ", ".join(observed_types)

        prefix = ("Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                  (input_name, op_type_name, observed))
        if input_arg.number_attr:
          # For ref inputs `dtype` was not normalized above and may still be a
          # raw DataType enum, so go through as_dtype() to get a name.
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtypes.as_dtype(dtype).name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type "
                            f"{dtypes.as_dtype(dtype).name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          raise
        else:
          raise TypeError(
              f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
              f"'{input_arg.name}' of op '{op_type_name}', got "
              f"{repr(values)} of type '{type(values).__name__}' instead. "
              f"Error: {err}") from err
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}") from err
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")
      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)


def _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                           default_type_attr_map, attrs):
  """Extracts the remaining attributes into `attrs` in _apply_op_helper.

  Raises:
    TypeError: If an inferred attr was also passed explicitly, or if an attr
      has neither a keyword value nor a default type.
  """
  for attr in op_def.attr:
    # Skip attrs that have already had their values inferred
    if attr.name in attrs:
      if attr.name in keywords:
        raise TypeError(
            f"Should not specify value for inferred attr '{attr.name}' for "
            f"{op_type_name}.")
      continue
    if attr.name in keywords:
      attrs[attr.name] = keywords.pop(attr.name)
    elif attr.name + "_" in keywords:
      # Attrs whose names match Python keywords have an extra '_'
      # appended, so we must check for that as well.
      attrs[attr.name] = keywords.pop(attr.name + "_")
    elif attr.name in default_type_attr_map:
      attrs[attr.name] = default_type_attr_map[attr.name]
    else:
      raise TypeError(f"No argument found for attr {attr.name} for "
                      f"{op_type_name}")


def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper.

  Raises:
    RuntimeError: If the op is not registered, or no graph can be determined
      from the keyword values.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; formatting the
    # exception itself avoids masking this error with an AttributeError.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}") from e

  return op_def, g, producer


def _CheckAllInputsUsed(op_type_name, keywords):
  """Ensures all inputs passed into _apply_op_helper were used.

  Raises:
    TypeError: If any keyword is left in `keywords`.
  """
  if keywords:
    all_keywords = ", ".join(sorted(keywords.keys()))
    raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                    f"{all_keywords}.")


def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`.
  """

  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name if name else op_type_name

  attrs, attr_protos = {}, {}
  default_type_attr_map, allowed_list_attr_map = {}, {}
  inputs, input_types, output_structure = [], [], []
  fallback = True

  if (_CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value()):
    fallback = False
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))

  if fallback:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs


def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If a list attr is given a non-list value, if `value` cannot be
      converted, or if `attr_type` is not recognized.
  """
  attr_value = attr_value_pb2.AttrValue()

  if attr_type.startswith("list("):
    if not _IsListValue(value):
      raise TypeError(f"Expected list for attr {arg_name}, obtained "
                      f"{type(value).__name__} instead.")

  if attr_type == "string":
    attr_value.s = _MakeStr(value, arg_name)
  elif attr_type == "list(string)":
    attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
  elif attr_type == "int":
    attr_value.i = _MakeInt(value, arg_name)
  elif attr_type == "list(int)":
    attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
  elif attr_type == "float":
    attr_value.f = _MakeFloat(value, arg_name)
  elif attr_type == "list(float)":
    attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
  elif attr_type == "bool":
    attr_value.b = _MakeBool(value, arg_name)
  elif attr_type == "list(bool)":
    attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
  elif attr_type == "type":
    attr_value.type = _MakeType(value, arg_name)
  elif attr_type == "list(type)":
    attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
  elif attr_type == "shape":
    attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
  elif attr_type == "list(shape)":
    attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
  elif attr_type == "tensor":
    attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
  elif attr_type == "list(tensor)":
    attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
  elif attr_type == "func":
    attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
  elif attr_type == "list(func)":
    attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
  else:
    raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")
  return attr_value
# LINT.ThenChange(//tensorflow/python/framework/op_def_library_pybind.cc)


# The following symbols are used by op_def_util.cc.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
_pywrap_utils.RegisterPyObject("tf.TensorShape", tensor_shape.TensorShape)
_pywrap_utils.RegisterPyObject("tf.as_shape", tensor_shape.as_shape)
_pywrap_utils.RegisterPyObject("tf.TensorProto", tensor_pb2.TensorProto)
_pywrap_utils.RegisterPyObject("text_format.Parse", text_format.Parse)
_pywrap_utils.RegisterPyObject("tf.convert_to_tensor", ops.convert_to_tensor)