1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A library of common shape functions.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import six.moves 22 23from tensorflow.python import pywrap_tensorflow 24from tensorflow.python.framework import cpp_shape_inference_pb2 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_util 29 30 31def has_fully_defined_shape(tensor): 32 """Returns true if tensor has a fully defined shape.""" 33 return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined() 34 35 36def rank(tensor): 37 """Return a rank if it is a tensor, else return None.""" 38 if isinstance(tensor, ops.Tensor): 39 return tensor._rank() # pylint: disable=protected-access 40 return None 41 42 43def scalar_shape(unused_op): 44 """Shape function for ops that output a scalar value.""" 45 return [tensor_shape.scalar()] 46 47 48def unchanged_shape(op): 49 """Shape function for ops that output a tensor like their first input.""" 50 return [op.inputs[0].get_shape()] 51 52 53def unchanged_shape_with_rank(rank): 54 """Returns a shape function for ops that constrain the rank of their input. 55 56 Args: 57 rank: The exact rank of the input and output. 58 59 Returns: 60 A shape function for ops that output a tensor of the same size as their 61 input, with a particular rank. 62 """ 63 64 def _ShapeFunction(op): 65 return [op.inputs[0].get_shape().with_rank(rank)] 66 67 return _ShapeFunction 68 69 70def unchanged_shape_with_rank_at_least(rank): 71 """Returns a shape function for ops that constrain the rank of their input. 72 73 Args: 74 rank: A lower bound on the rank of the input and output. 75 76 Returns: 77 A shape function for ops that output a tensor of the same size as their 78 input, with a particular rank. 79 """ 80 81 def _ShapeFunction(op): 82 return [op.inputs[0].get_shape().with_rank_at_least(rank)] 83 84 return _ShapeFunction 85 86 87def unchanged_shape_with_rank_at_most(rank): 88 """Returns a shape function for ops that constrain the rank of their input. 89 90 Args: 91 rank: An upper bound on the rank of the input and output. 92 93 Returns: 94 A shape function for ops that output a tensor of the same size as their 95 input, with a particular rank. 96 """ 97 98 def _ShapeFunction(op): 99 return [op.inputs[0].get_shape().with_rank_at_most(rank)] 100 101 return _ShapeFunction 102 103 104def matmul_shape(op): 105 """Shape function for a MatMul op.""" 106 a_shape = op.inputs[0].get_shape().with_rank(2) 107 transpose_a = op.get_attr("transpose_a") 108 b_shape = op.inputs[1].get_shape().with_rank(2) 109 transpose_b = op.get_attr("transpose_b") 110 output_rows = a_shape[1] if transpose_a else a_shape[0] 111 output_cols = b_shape[0] if transpose_b else b_shape[1] 112 inner_a = a_shape[0] if transpose_a else a_shape[1] 113 inner_b = b_shape[1] if transpose_b else b_shape[0] 114 inner_a.assert_is_compatible_with(inner_b) 115 return [tensor_shape.TensorShape([output_rows, output_cols])] 116 117 118def get_conv_output_size(input_size, filter_size, strides, padding_type): 119 """Returns the spatial size of a n-d convolution/pooling output.""" 120 input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size]) 121 filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size]) 122 strides = [int(x) for x in strides] 123 124 if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size): 125 return input_size 126 127 if any(x is not None and y is not None and x > y for x, y in 128 zip(filter_size, input_size)): 129 raise ValueError("Filter must not be larger than the input: " 130 "Filter: %r Input: %r" % (filter_size, input_size)) 131 132 if padding_type == b"VALID": 133 134 def _valid(in_dim, k_dim, s_dim): 135 if in_dim is not None and k_dim is not None: 136 return (in_dim - k_dim + s_dim) // s_dim 137 else: 138 return None 139 140 output_size = [ 141 _valid(in_dim, k_dim, s_dim) 142 for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides) 143 ] 144 elif padding_type == b"SAME": 145 146 def _same(in_dim, s_dim): 147 if in_dim is not None: 148 return (in_dim + s_dim - 1) // s_dim 149 else: 150 return None 151 152 output_size = [_same(in_dim, s_dim) 153 for in_dim, s_dim in zip(input_size, strides)] 154 else: 155 raise ValueError("Invalid padding: %r" % padding_type) 156 157 return tuple(output_size) 158 159 160def get2d_conv_output_size(input_height, input_width, filter_height, 161 filter_width, row_stride, col_stride, padding_type): 162 """Returns the number of rows and columns in a convolution/pooling output.""" 163 return get_conv_output_size((input_height, input_width), 164 (filter_height, filter_width), 165 (row_stride, col_stride), padding_type) 166 167 168def conv2d_shape(op): 169 """Shape function for a Conv2D op. 170 171 This op has two inputs: 172 173 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 174 * filter, a 4D tensor with shape = [filter_rows, filter_cols, 175 depth_in, depth_out] 176 177 The output is a 4D tensor with shape = [batch_size, out_rows, 178 out_cols, depth_out], where out_rows and out_cols depend on the 179 value of the op's "padding" and "strides" attrs. 180 181 Args: 182 op: A Conv2D Operation. 183 184 Returns: 185 A list containing the Shape of the Conv2D output. 186 187 Raises: 188 ValueError: If the shapes of the input or filter are incompatible. 189 """ 190 input_shape = op.inputs[0].get_shape().with_rank(4) 191 filter_shape = op.inputs[1].get_shape().with_rank(4) 192 193 try: 194 data_format = op.get_attr("data_format") 195 except ValueError: 196 data_format = None 197 198 if data_format == b"NCHW": 199 # Convert input shape to the default NHWC for inference. 200 input_shape = [input_shape[0], input_shape[2], input_shape[3], 201 input_shape[1]] 202 203 batch_size = input_shape[0] 204 in_rows = input_shape[1] 205 in_cols = input_shape[2] 206 207 filter_rows = filter_shape[0] 208 filter_cols = filter_shape[1] 209 depth_out = filter_shape[3] 210 # Check that the input depths are compatible. 211 input_shape[3].assert_is_compatible_with(filter_shape[2]) 212 213 if data_format == b"NCHW": 214 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 215 else: 216 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 217 218 if stride_b != 1 or stride_d != 1: 219 raise ValueError("Current implementation does not yet support " 220 "strides in the batch and depth dimensions.") 221 # TODO(mrry,shlens): Raise an error if the stride would cause 222 # information in the input to be ignored. This will require a change 223 # in the kernel implementation. 224 padding = op.get_attr("padding") 225 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 226 filter_cols, stride_r, stride_c, 227 padding) 228 229 output_shape = [batch_size, out_rows, out_cols, depth_out] 230 if data_format == b"NCHW": 231 # Convert output shape back to NCHW. 232 output_shape = [output_shape[0], output_shape[3], output_shape[1], 233 output_shape[2]] 234 return [tensor_shape.TensorShape(output_shape)] 235 236 237def depthwise_conv2d_native_shape(op): 238 """Shape function for a DepthwiseConv2D op. 239 240 This op has two inputs: 241 242 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 243 * filter, a 4D tensor with shape = [filter_rows, filter_cols, 244 depth_in, depthwise_multiplier] 245 246 The output is a 4D tensor with shape = [batch_size, out_rows, 247 out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend 248 on the value of the op's "padding" and "strides" attrs. 249 250 Args: 251 op: A DepthwiseConv2dNative Operation. 252 253 Returns: 254 A list containing the Shape of the DepthwiseConv2DNative output. 255 256 Raises: 257 ValueError: If the shapes of the input or filter are incompatible. 258 """ 259 input_shape = op.inputs[0].get_shape().with_rank(4) 260 filter_shape = op.inputs[1].get_shape().with_rank(4) 261 262 batch_size = input_shape[0] 263 in_rows = input_shape[1] 264 in_cols = input_shape[2] 265 266 filter_rows = filter_shape[0] 267 filter_cols = filter_shape[1] 268 depth_out = filter_shape[3] * filter_shape[2] 269 # Check that the input depths are compatible. 270 input_shape[3].assert_is_compatible_with(filter_shape[2]) 271 272 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 273 if stride_b != 1 or stride_d != 1: 274 raise ValueError("Current implementation does not yet support " 275 "strides in the batch and depth dimensions.") 276 if stride_r != stride_c: 277 # TODO(shlens): Add support for this. 278 raise ValueError("Current implementation only supports equal length " 279 "strides in the row and column dimensions.") 280 281 # TODO(mrry,shlens): Raise an error if the stride would cause 282 # information in the input to be ignored. This will require a change 283 # in the kernel implementation. 284 stride = stride_r 285 padding = op.get_attr("padding") 286 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 287 filter_cols, stride, stride, 288 padding) 289 290 return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] 291 292 293def separable_conv2d_shape(op): 294 """Shape function for a SeparableConv2D op. 295 296 This op has three inputs: 297 298 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 299 300 * depthwise_filter, a 4D tensor with shape = [filter_rows, 301 filter_cols, depth_in, depth_multiplier] 302 303 * pointwise_filter, a 4D tensor with shape = [1, 1, depth_in * 304 depth_multiplier, depth_out] 305 306 The output is a 4D tensor with shape = [batch_size, out_rows, 307 out_cols, depth_out], where out_rows and out_cols depend on the 308 value of the op's "padding" and "strides" attrs. 309 310 Args: 311 op: A SeparableConv2D Operation. 312 313 Returns: 314 A list containing the Shape of the SeparableConv2D output. 315 316 Raises: 317 ValueError: If the shapes of the input or filter are incompatible. 318 """ 319 input_shape = op.inputs[0].get_shape().with_rank(4) 320 depthwise_filter_shape = op.inputs[1].get_shape().merge_with( 321 tensor_shape.TensorShape([None, None, input_shape[3], None])) 322 pointwise_depth_in = depthwise_filter_shape[2] * depthwise_filter_shape[3] 323 324 pointwise_filter_shape = op.inputs[2].get_shape().merge_with( 325 tensor_shape.TensorShape([1, 1, pointwise_depth_in, None])) 326 327 batch_size = input_shape[0] 328 in_rows = input_shape[1] 329 in_cols = input_shape[2] 330 331 filter_rows = depthwise_filter_shape[0] 332 filter_cols = depthwise_filter_shape[1] 333 depth_out = pointwise_filter_shape[3] 334 335 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 336 if stride_b != 1 or stride_d != 1: 337 raise ValueError("Current implementation does not yet support " 338 "strides in the batch and depth dimensions.") 339 if stride_r != stride_c: 340 # TODO(shlens): Add support for this. 341 raise ValueError("Current implementation only supports equal length " 342 "strides in the row and column dimensions.") 343 344 # TODO(mrry,shlens): Raise an error if the stride would cause 345 # information in the input to be ignored. This will require a change 346 # in the kernel implementation. 347 stride = stride_r 348 padding = op.get_attr("padding") 349 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, 350 filter_cols, stride, stride, 351 padding) 352 353 return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] 354 355 356def avg_pool_shape(op): 357 """Shape function for an AvgPool op. 358 359 This op has one input: 360 361 * input, a 4D tensor with shape = [batch_size, rows, cols, depth] 362 363 The output is a 4D tensor with shape = [batch_size, out_rows, 364 out_cols, depth_out], where out_rows and out_cols depend on the 365 value of the op's "ksize", "strides", and "padding" attrs. 366 367 Args: 368 op: An AvgPool Operation. 369 370 Returns: 371 A single-element list containing the Shape of the AvgPool output. 372 373 Raises: 374 ValueError: If the shape of the input is invalid or incompatible with 375 the values of the attrs. 376 """ 377 input_shape = op.inputs[0].get_shape().with_rank(4) 378 try: 379 data_format = op.get_attr("data_format") 380 except ValueError: 381 data_format = None 382 383 if data_format == b"NCHW": 384 # Convert input shape to the default NHWC for inference. 385 input_shape = [input_shape[0], input_shape[2], input_shape[3], 386 input_shape[1]] 387 388 if data_format == b"NCHW": 389 ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize") 390 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 391 else: 392 ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize") 393 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 394 395 batch_size = input_shape[0] 396 in_rows = input_shape[1] 397 in_cols = input_shape[2] 398 depth = input_shape[3] 399 400 if ksize_b != 1 or ksize_d != 1: 401 raise ValueError("Current implementation does not support pooling " 402 "in the batch and depth dimensions.") 403 if stride_b != 1 or stride_d != 1: 404 raise ValueError("Current implementation does not support strides " 405 "in the batch and depth dimensions.") 406 407 # TODO(mrry,shlens): Raise an error if the stride would cause 408 # information in the input to be ignored. This will require a change 409 # in the kernel implementation. 410 padding = op.get_attr("padding") 411 412 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, 413 ksize_c, stride_r, stride_c, 414 padding) 415 416 output_shape = [batch_size, out_rows, out_cols, depth] 417 if data_format == b"NCHW": 418 # Convert output shape back to NCHW. 419 output_shape = [output_shape[0], output_shape[3], output_shape[1], 420 output_shape[2]] 421 return [tensor_shape.TensorShape(output_shape)] 422 423 424def max_pool_shape(op): 425 """Shape function for a MaxPool op. 426 427 This op has one input: 428 429 * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] 430 431 The output is a 4D tensor with shape = [batch_size, out_rows, 432 out_cols, depth_out], where out_rows, out_cols, and depth_out depend 433 on the value of the op's "ksize", "strides", and "padding" attrs. 434 435 Args: 436 op: A MaxPool Operation. 437 438 Returns: 439 A single-element list containing the Shape of the MaxPool output. 440 441 Raises: 442 ValueError: If the shape of the input is invalid or incompatible with 443 the values of the attrs. 444 """ 445 input_shape = op.inputs[0].get_shape().with_rank(4) 446 try: 447 data_format = op.get_attr("data_format") 448 except ValueError: 449 data_format = None 450 451 if data_format == b"NCHW": 452 # Convert input shape to the default NHWC for inference. 453 input_shape = [input_shape[0], input_shape[2], input_shape[3], 454 input_shape[1]] 455 456 if data_format == b"NCHW": 457 ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize") 458 stride_b, stride_d, stride_r, stride_c = op.get_attr("strides") 459 else: 460 ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize") 461 stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") 462 463 batch_size = input_shape[0] 464 in_rows = input_shape[1] 465 in_cols = input_shape[2] 466 depth = input_shape[3] 467 468 if ksize_b != 1: 469 raise ValueError("Current implementation does not support pooling " 470 "in the batch dimension.") 471 if stride_b != 1: 472 raise ValueError("Current implementation does not support strides " 473 "in the batch dimension.") 474 475 if not ((ksize_r == 1 and ksize_c == 1) or ksize_d == 1): 476 raise ValueError("MaxPooling supports exactly one of pooling across depth " 477 "or pooling across width/height.") 478 479 # TODO(mrry,shlens): Raise an error if the stride would cause 480 # information in the input to be ignored. This will require a change 481 # in the kernel implementation. 482 if ksize_d == 1: 483 padding = op.get_attr("padding") 484 out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, 485 ksize_c, stride_r, stride_c, 486 padding) 487 output_shape = [batch_size, out_rows, out_cols, depth] 488 else: 489 if depth % ksize_d > 0: 490 raise ValueError("Depthwise max pooling requires the depth window " 491 "to evenly divide the input depth.") 492 if stride_d != ksize_d: 493 raise ValueError("Depthwise max pooling requires the depth window " 494 "to equal the depth stride.") 495 output_shape = [batch_size, in_rows, in_cols, depth // ksize_d] 496 497 if data_format == b"NCHW": 498 # Convert output shape back to NCHW. 499 output_shape = [output_shape[0], output_shape[3], output_shape[1], 500 output_shape[2]] 501 return [tensor_shape.TensorShape(output_shape)] 502 503 504def no_outputs(unused_op): 505 """Shape function for use with ops that have no outputs.""" 506 return [] 507 508 509def unknown_shape(op): 510 """Shape function for use with ops whose output shapes are unknown.""" 511 return [tensor_shape.unknown_shape() for _ in op.outputs] 512 513 514def _broadcast_shape_helper(shape_x, shape_y): 515 """Helper functions for is_broadcast_compatible and broadcast_shape. 516 517 Args: 518 shape_x: A `TensorShape` 519 shape_y: A `TensorShape` 520 521 Returns: 522 Returns None if the shapes are not broadcast compatible, 523 a list of the broadcast dimensions otherwise. 524 """ 525 # To compute the broadcasted dimensions, we zip together shape_x and shape_y, 526 # and pad with 1 to make them the same length. 527 broadcasted_dims = reversed(list(six.moves.zip_longest( 528 reversed(shape_x.dims), 529 reversed(shape_y.dims), 530 fillvalue=tensor_shape.Dimension(1)))) 531 # Next we combine the dimensions according to the numpy broadcasting rules. 532 # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html 533 return_dims = [] 534 for (dim_x, dim_y) in broadcasted_dims: 535 if dim_x.value is None or dim_y.value is None: 536 # One or both dimensions is unknown. If either dimension is greater than 537 # 1, we assume that the program is correct, and the other dimension will 538 # be broadcast to match it. 539 # TODO(mrry): If we eliminate the shape checks in C++, we must still 540 # assert that the unknown dim is either 1 or the same as the known dim. 541 if dim_x.value is not None and dim_x.value > 1: 542 return_dims.append(dim_x) 543 elif dim_y.value is not None and dim_y.value > 1: 544 return_dims.append(dim_y) 545 else: 546 return_dims.append(None) 547 elif dim_x.value == 1: 548 # We will broadcast dim_x to dim_y. 549 return_dims.append(dim_y) 550 elif dim_y.value == 1: 551 # We will broadcast dim_y to dim_x. 552 return_dims.append(dim_x) 553 elif dim_x.value == dim_y.value: 554 # The dimensions are compatible, so output is the same size in that 555 # dimension. 556 return_dims.append(dim_x.merge_with(dim_y)) 557 else: 558 return None 559 return return_dims 560 561 562def is_broadcast_compatible(shape_x, shape_y): 563 """Returns True if `shape_x` and `shape_y` are broadcast compatible. 564 565 Args: 566 shape_x: A `TensorShape` 567 shape_y: A `TensorShape` 568 569 Returns: 570 True if a shape exists that both `shape_x` and `shape_y` can be broadcasted 571 to. False otherwise. 572 """ 573 if shape_x.ndims is None or shape_y.ndims is None: 574 return False 575 return _broadcast_shape_helper(shape_x, shape_y) is not None 576 577 578def broadcast_shape(shape_x, shape_y): 579 """Returns the broadcasted shape between `shape_x` and `shape_y`. 580 581 Args: 582 shape_x: A `TensorShape` 583 shape_y: A `TensorShape` 584 585 Returns: 586 A `TensorShape` representing the broadcasted shape. 587 588 Raises: 589 ValueError: If the two shapes can not be broadcasted. 590 """ 591 if shape_x.ndims is None or shape_y.ndims is None: 592 return tensor_shape.unknown_shape() 593 return_dims = _broadcast_shape_helper(shape_x, shape_y) 594 if return_dims is None: 595 raise ValueError("Incompatible shapes for broadcasting: %s and %s" 596 % (shape_x, shape_y)) 597 return tensor_shape.TensorShape(return_dims) 598 599 600def call_cpp_shape_fn(op, require_shape_fn=True): 601 """A shape function that delegates to the registered C++ shape function. 602 603 Args: 604 op: the node in the graph for which to compute output shapes. 605 require_shape_fn: If true, and the C++ shape function is not registered 606 in the current binary then an exception is raised; otherwise, if the 607 C++ shape function is not registered then unknown_shape is used. 608 609 Returns: 610 A dictionary with the following keys: 611 shapes: A TensorShape list of the output shapes of the op, as computed 612 using the C++ shape inference function registered for the op. 613 handle_shapes: A TensorShape list of the shapes for handle outputs, if 614 any. 615 handle_dtypes: A list of DataType enums for the handle outputs, if any. 616 617 Raises: 618 ValueError: If the C++ shape function returned an error (e.g. because the 619 shapes of the inputs are of the wrong rank or otherwise incompatible 620 according to the shape function). 621 RuntimeError: If the C++ shape function is not registered and 622 <require_shape_fn> is True. 623 """ 624 if op.type == "Const": 625 # To avoid serializing large constants, we special-case constant 626 # here, even though it has a C++ shape function. When Python 627 # calls the C / C-API directly, we should be able to remove this. 628 return { 629 "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)], 630 "handle_data": [None] 631 } 632 633 input_tensors_needed = [] 634 input_tensors_as_shapes_needed = [] 635 636 while True: 637 res = _call_cpp_shape_fn_impl(op, input_tensors_needed, 638 input_tensors_as_shapes_needed, 639 require_shape_fn) 640 if not isinstance(res, dict): 641 # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op). 642 return res 643 644 # See if we need to evaluate some inputs. 645 if not res["inputs_needed"]: 646 return res 647 p = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded() 648 p = p.FromString(res["inputs_needed"]) 649 changed = False 650 for idx in p.input_tensors_needed: 651 if idx not in input_tensors_needed: 652 input_tensors_needed.append(idx) 653 changed = True 654 for idx in p.input_tensors_as_shapes_needed: 655 if idx not in input_tensors_as_shapes_needed: 656 input_tensors_as_shapes_needed.append(idx) 657 changed = True 658 if not changed: 659 return res 660 661 662def _call_cpp_shape_fn_impl( 663 op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn): 664 """Core implementation of call_cpp_shape_fn.""" 665 graph_def_version = op.graph.graph_def_versions.producer 666 node_def_str = op.node_def.SerializeToString() 667 668 def tensor_to_inference_result(t): 669 r = cpp_shape_inference_pb2.CppShapeInferenceResult() 670 r.shape.CopyFrom(t.get_shape().as_proto()) 671 # pylint: disable=protected-access 672 if t._handle_data is not None: 673 r.handle_data.CopyFrom(t._handle_data) 674 # pylint: enable=protected-access 675 return r.SerializeToString() 676 input_shapes = [tensor_to_inference_result(i) for i in op.inputs] 677 678 input_tensors = [None for i in input_shapes] 679 for idx in input_tensors_needed: 680 v = tensor_util.constant_value(op.inputs[idx]) 681 if v is not None: 682 input_tensors[idx] = np.asarray(v) 683 684 serialized_unknown_shape = ( 685 tensor_shape.TensorShape(None).as_proto().SerializeToString()) 686 arr = [serialized_unknown_shape for i in input_shapes] 687 for idx in input_tensors_as_shapes_needed: 688 s = tensor_util.constant_value_as_shape(op.inputs[idx]) 689 if s is not None: 690 arr[idx] = s.as_proto().SerializeToString() 691 input_tensors_as_shapes = arr 692 693 missing_shape_fn = False 694 try: 695 with errors.raise_exception_on_not_ok_status() as status: 696 output = pywrap_tensorflow.RunCppShapeInference( 697 graph_def_version, node_def_str, input_shapes, input_tensors, 698 input_tensors_as_shapes, status) 699 except errors.InvalidArgumentError as err: 700 if err.message.startswith("No shape inference function exists for op"): 701 missing_shape_fn = True 702 else: 703 raise ValueError(err.message) 704 705 if missing_shape_fn: 706 if require_shape_fn: 707 raise RuntimeError( 708 "No C++ shape function registered for standard op: %s" % op.type) 709 return unknown_shape(op) 710 711 output_shapes = output[:-1] 712 713 # Convert TensorShapeProto values in output_shapes. 714 result_protos = [ 715 cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s) 716 for s in output_shapes 717 ] 718 result = [r.shape for r in result_protos] 719 result_handle_data = [ 720 r.handle_data if r.handle_data.is_set else None for r in result_protos 721 ] 722 723 return { 724 "shapes": result, 725 "handle_data": result_handle_data, 726 "inputs_needed": output[-1] 727 } 728 729# pylint: disable=protected-access 730ops._set_call_cpp_shape_fn(call_cpp_shape_fn) 731# pylint: enable=protected-access 732