• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A library of common shape functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import six.moves
22
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.framework import cpp_shape_inference_pb2
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_util
29
30
def has_fully_defined_shape(tensor):
  """Reports whether `tensor` has a fully defined static shape.

  Instances of `ops.EagerTensor` are reported as fully defined without
  consulting their shape; any other tensor defers to
  `tensor.shape.is_fully_defined()`.

  Args:
    tensor: An `ops.EagerTensor`, or an object exposing a `shape` attribute.

  Returns:
    True for eager tensors, otherwise the result of `is_fully_defined()`.
  """
  if isinstance(tensor, ops.EagerTensor):
    return True
  return tensor.shape.is_fully_defined()
34
35
def rank(tensor):
  """Returns the static rank of `tensor`, or None if it is not a Tensor.

  Args:
    tensor: Any object; only `ops.Tensor` instances produce a rank.

  Returns:
    The value of `tensor._rank()` for an `ops.Tensor`, otherwise None.
  """
  if not isinstance(tensor, ops.Tensor):
    return None
  return tensor._rank()  # pylint: disable=protected-access
41
42
def scalar_shape(unused_op):
  """Shape function for ops whose single output is a scalar.

  Args:
    unused_op: The operation; ignored.

  Returns:
    A one-element list holding `tensor_shape.scalar()`.
  """
  scalar = tensor_shape.scalar()
  return [scalar]
46
47
def unchanged_shape(op):
  """Shape function for ops whose output has the same shape as input 0.

  Args:
    op: An operation with at least one input.

  Returns:
    A one-element list containing the static shape of `op.inputs[0]`.
  """
  first_input = op.inputs[0]
  return [first_input.get_shape()]
51
52
def unchanged_shape_with_rank(rank):
  """Builds a shape function that passes input 0's shape through at `rank`.

  Args:
    rank: The exact rank required of the input, and hence of the output.

  Returns:
    A callable taking an op and returning a one-element list holding
    `op.inputs[0].get_shape().with_rank(rank)`.
  """

  def _ShapeFunction(op):
    in_shape = op.inputs[0].get_shape()
    return [in_shape.with_rank(rank)]

  return _ShapeFunction
68
69
def unchanged_shape_with_rank_at_least(rank):
  """Builds a shape function requiring input 0 to have rank >= `rank`.

  Args:
    rank: A lower bound on the rank of the input and output.

  Returns:
    A callable taking an op and returning a one-element list holding
    `op.inputs[0].get_shape().with_rank_at_least(rank)`; the output has the
    same shape as the input.
  """

  def _ShapeFunction(op):
    in_shape = op.inputs[0].get_shape()
    return [in_shape.with_rank_at_least(rank)]

  return _ShapeFunction
85
86
def unchanged_shape_with_rank_at_most(rank):
  """Builds a shape function requiring input 0 to have rank <= `rank`.

  Args:
    rank: An upper bound on the rank of the input and output.

  Returns:
    A callable taking an op and returning a one-element list holding
    `op.inputs[0].get_shape().with_rank_at_most(rank)`; the output has the
    same shape as the input.
  """

  def _ShapeFunction(op):
    in_shape = op.inputs[0].get_shape()
    return [in_shape.with_rank_at_most(rank)]

  return _ShapeFunction
102
103
def matmul_shape(op):
  """Shape function for a MatMul op.

  Both operands are constrained to rank 2. The "transpose_a" / "transpose_b"
  attrs choose which axis of each operand is contracted; the remaining axis
  of each operand becomes an output axis. The two contracted axes must be
  compatible.

  Args:
    op: A MatMul Operation.

  Returns:
    A one-element list holding the [output_rows, output_cols] shape.

  Raises:
    Whatever `with_rank(2)` or `assert_is_compatible_with` raise (ValueError
    elsewhere in this module) for a non-matrix operand or mismatched inner
    dimensions.
  """
  a_shape = op.inputs[0].get_shape().with_rank(2)
  transpose_a = op.get_attr("transpose_a")
  b_shape = op.inputs[1].get_shape().with_rank(2)
  transpose_b = op.get_attr("transpose_b")

  # Axis index contracted for each operand; the other axis (1 - index) is
  # the free axis that survives into the output.
  a_inner_axis = 0 if transpose_a else 1
  b_inner_axis = 1 if transpose_b else 0
  a_shape[a_inner_axis].assert_is_compatible_with(b_shape[b_inner_axis])
  return [tensor_shape.TensorShape([a_shape[1 - a_inner_axis],
                                    b_shape[1 - b_inner_axis]])]
116
117
def get_conv_output_size(input_size, filter_size, strides, padding_type):
  """Returns the spatial size of a n-d convolution/pooling output.

  Args:
    input_size: Sequence of spatial input extents (ints, None or
      `Dimension`s); None means unknown.
    filter_size: Sequence of window extents, in the same form as
      `input_size`.
    strides: Sequence of per-dimension strides convertible to int.
    padding_type: b"VALID" or b"SAME". The equivalent text strings "VALID"
      and "SAME" are accepted as well.

  Returns:
    A tuple of output extents, with None wherever an extent cannot be
    computed statically.

  Raises:
    ValueError: If a known filter extent exceeds the matching known input
      extent, or if `padding_type` is not a recognized padding.
  """
  input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
  filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size])
  strides = [int(x) for x in strides]

  # An all-ones input convolved with an all-ones window is returned as is.
  if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size):
    return input_size

  if any(x is not None and y is not None and x > y for x, y in
         zip(filter_size, input_size)):
    raise ValueError("Filter must not be larger than the input: "
                     "Filter: %r Input: %r" % (filter_size, input_size))

  # Op attrs arrive as bytes, but direct callers often pass text (e.g.
  # "SAME"). Under Python 3, text never compares equal to bytes, so normalize
  # it here; the original value is still used in the error message.
  padding = padding_type
  if isinstance(padding, six.text_type):
    padding = padding.encode("utf-8")

  if padding == b"VALID":

    def _valid(in_dim, k_dim, s_dim):
      if in_dim is not None and k_dim is not None:
        # ceil((in_dim - k_dim + 1) / s_dim) written with integer math.
        return (in_dim - k_dim + s_dim) // s_dim
      else:
        return None

    output_size = [
        _valid(in_dim, k_dim, s_dim)
        for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides)
    ]
  elif padding == b"SAME":

    def _same(in_dim, s_dim):
      if in_dim is not None:
        # ceil(in_dim / s_dim) written with integer math.
        return (in_dim + s_dim - 1) // s_dim
      else:
        return None

    output_size = [_same(in_dim, s_dim)
                   for in_dim, s_dim in zip(input_size, strides)]
  else:
    raise ValueError("Invalid padding: %r" % padding_type)

  return tuple(output_size)
158
159
def get2d_conv_output_size(input_height, input_width, filter_height,
                           filter_width, row_stride, col_stride, padding_type):
  """Returns the number of rows and columns in a convolution/pooling output.

  This is the 2-D special case of `get_conv_output_size`.

  Returns:
    A 2-tuple (out_rows, out_cols); an entry is None when it is unknown.
  """
  input_hw = (input_height, input_width)
  filter_hw = (filter_height, filter_width)
  strides_hw = (row_stride, col_stride)
  return get_conv_output_size(input_hw, filter_hw, strides_hw, padding_type)
166
167
def conv2d_shape(op):
  """Shape function for a Conv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in], or
    [batch_size, depth_in, rows, cols] when the "data_format" attr is NCHW.
  * filter, a 4D tensor with shape = [filter_rows, filter_cols, depth_in,
    depth_out].

  The output is a 4D tensor with shape = [batch_size, out_rows, out_cols,
  depth_out] (NCHW-ordered when the input is), where out_rows and out_cols
  depend on the value of the op's "padding" and "strides" attrs.

  Args:
    op: A Conv2D Operation.

  Returns:
    A list containing the Shape of the Conv2D output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, or if
      a stride other than 1 is requested in the batch or depth dimension.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  # Ops without a "data_format" attr are treated as NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  # Pull the input dimensions out by role, whichever layout is in use.
  if is_nchw:
    batch_size, depth_in, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
  else:
    batch_size, in_rows, in_cols, depth_in = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])

  # The input depth must agree with the filter's depth_in axis.
  depth_in.assert_is_compatible_with(filter_shape[2])

  strides = op.get_attr("strides")
  if is_nchw:
    stride_b, stride_d, stride_r, stride_c = strides
  else:
    stride_b, stride_r, stride_c, stride_d = strides

  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, filter_shape[0], filter_shape[1], stride_r, stride_c,
      op.get_attr("padding"))

  depth_out = filter_shape[3]
  if is_nchw:
    output_dims = [batch_size, depth_out, out_rows, out_cols]
  else:
    output_dims = [batch_size, out_rows, out_cols, depth_out]
  return [tensor_shape.TensorShape(output_dims)]
235
236
def depthwise_conv2d_native_shape(op):
  """Shape function for a DepthwiseConv2D op.

  This op has two inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in], or
    [batch_size, depth_in, rows, cols] when the op's "data_format" attr is
    NCHW.
  * filter, a 4D tensor with shape = [filter_rows, filter_cols, depth_in,
    depthwise_multiplier].

  The output is a 4D tensor with shape = [batch_size, out_rows, out_cols,
  depth_in*depthwise_multiplier] (NCHW-ordered when the input is), where
  out_rows and out_cols depend on the value of the op's "padding" and
  "strides" attrs.

  Args:
    op: A DepthwiseConv2dNative Operation.

  Returns:
    A list containing the Shape of the DepthwiseConv2DNative output.

  Raises:
    ValueError: If the shapes of the input or filter are incompatible, if a
      stride other than 1 is requested in the batch or depth dimension, or if
      the row and column strides differ.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  filter_shape = op.inputs[1].get_shape().with_rank(4)

  # As in conv2d_shape: ops without a "data_format" attr are treated as NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  if data_format == b"NCHW":
    # Convert input shape to the default NHWC for inference.
    input_shape = [input_shape[0], input_shape[2], input_shape[3],
                   input_shape[1]]

  batch_size = input_shape[0]
  in_rows = input_shape[1]
  in_cols = input_shape[2]

  filter_rows = filter_shape[0]
  filter_cols = filter_shape[1]
  # Check that the input depths are compatible. Merging (rather than only
  # asserting) lets a depth known on either side propagate to the output.
  depth_in = input_shape[3].merge_with(filter_shape[2])
  depth_out = depth_in * filter_shape[3]

  if data_format == b"NCHW":
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  stride = stride_r
  padding = op.get_attr("padding")
  out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows,
                                              filter_cols, stride, stride,
                                              padding)

  output_shape = [batch_size, out_rows, out_cols, depth_out]
  if data_format == b"NCHW":
    # Convert output shape back to NCHW.
    output_shape = [output_shape[0], output_shape[3], output_shape[1],
                    output_shape[2]]
  return [tensor_shape.TensorShape(output_shape)]
291
292
def separable_conv2d_shape(op):
  """Shape function for a SeparableConv2D op.

  This op has three inputs:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in]
  * depthwise_filter, a 4D tensor with shape = [filter_rows, filter_cols,
    depth_in, depth_multiplier]
  * pointwise_filter, a 4D tensor with shape = [1, 1,
    depth_in * depth_multiplier, depth_out]

  The output is a 4D tensor with shape = [batch_size, out_rows, out_cols,
  depth_out], where out_rows and out_cols depend on the value of the op's
  "padding" and "strides" attrs.

  Args:
    op: A SeparableConv2D Operation.

  Returns:
    A list containing the Shape of the SeparableConv2D output.

  Raises:
    ValueError: If the shapes of the input or filters are incompatible, if a
      stride other than 1 is requested in the batch or depth dimension, or if
      the row and column strides differ.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  # The depthwise filter's third axis is tied to the input depth.
  dw_shape = op.inputs[1].get_shape().merge_with(
      tensor_shape.TensorShape([None, None, input_shape[3], None]))
  # The pointwise filter is 1x1 and consumes depth_in * depth_multiplier
  # channels.
  pw_shape = op.inputs[2].get_shape().merge_with(
      tensor_shape.TensorShape([1, 1, dw_shape[2] * dw_shape[3], None]))

  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
  if stride_b != 1 or stride_d != 1:
    raise ValueError("Current implementation does not yet support "
                     "strides in the batch and depth dimensions.")
  if stride_r != stride_c:
    # TODO(shlens): Add support for this.
    raise ValueError("Current implementation only supports equal length "
                     "strides in the row and column dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  # Row and column strides are equal here, so stride_r serves for both.
  out_rows, out_cols = get2d_conv_output_size(
      input_shape[1], input_shape[2], dw_shape[0], dw_shape[1],
      stride_r, stride_r, op.get_attr("padding"))

  return [tensor_shape.TensorShape(
      [input_shape[0], out_rows, out_cols, pw_shape[3]])]
354
355
def avg_pool_shape(op):
  """Shape function for an AvgPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth], or
    [batch_size, depth, rows, cols] when the "data_format" attr is NCHW.

  The output is a 4D tensor with shape = [batch_size, out_rows, out_cols,
  depth] in the same layout as the input, where out_rows and out_cols depend
  on the value of the op's "ksize", "strides", and "padding" attrs.

  Args:
    op: An AvgPool Operation.

  Returns:
    A single-element list containing the Shape of the AvgPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  # Ops without a "data_format" attr are treated as NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  # Read the input dims and the 4-element "ksize"/"strides" attrs by role,
  # whichever layout the op uses.
  if is_nchw:
    batch_size, depth, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    batch_size, in_rows, in_cols, depth = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if (ksize_b, ksize_d) != (1, 1):
    raise ValueError("Current implementation does not support pooling "
                     "in the batch and depth dimensions.")
  if (stride_b, stride_d) != (1, 1):
    raise ValueError("Current implementation does not support strides "
                     "in the batch and depth dimensions.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  out_rows, out_cols = get2d_conv_output_size(
      in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c,
      op.get_attr("padding"))

  if is_nchw:
    output_dims = [batch_size, depth, out_rows, out_cols]
  else:
    output_dims = [batch_size, out_rows, out_cols, depth]
  return [tensor_shape.TensorShape(output_dims)]
422
423
def max_pool_shape(op):
  """Shape function for a MaxPool op.

  This op has one input:

  * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in], or
    [batch_size, depth_in, rows, cols] when the "data_format" attr is NCHW.

  Pooling happens either spatially (depth window of 1) or across depth
  (spatial window of 1x1). The output is a 4D tensor in the input's layout,
  with [batch_size, out_rows, out_cols, depth_out] determined by the op's
  "ksize", "strides", and "padding" attrs.

  Args:
    op: A MaxPool Operation.

  Returns:
    A single-element list containing the Shape of the MaxPool output.

  Raises:
    ValueError: If the shape of the input is invalid or incompatible with
      the values of the attrs.
  """
  input_shape = op.inputs[0].get_shape().with_rank(4)
  # Ops without a "data_format" attr are treated as NHWC.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  is_nchw = data_format == b"NCHW"

  # Read the input dims and the 4-element "ksize"/"strides" attrs by role,
  # whichever layout the op uses.
  if is_nchw:
    batch_size, depth, in_rows, in_cols = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
    ksize_b, ksize_d, ksize_r, ksize_c = op.get_attr("ksize")
    stride_b, stride_d, stride_r, stride_c = op.get_attr("strides")
  else:
    batch_size, in_rows, in_cols, depth = (
        input_shape[0], input_shape[1], input_shape[2], input_shape[3])
    ksize_b, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
    stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")

  if ksize_b != 1:
    raise ValueError("Current implementation does not support pooling "
                     "in the batch dimension.")
  if stride_b != 1:
    raise ValueError("Current implementation does not support strides "
                     "in the batch dimension.")

  # A window that is non-trivial both spatially and in depth is rejected.
  # A window that is trivial in both is allowed and takes the spatial path.
  spatial_window_is_trivial = ksize_r == 1 and ksize_c == 1
  if ksize_d != 1 and not spatial_window_is_trivial:
    raise ValueError("MaxPooling supports exactly one of pooling across depth "
                     "or pooling across width/height.")

  # TODO(mrry,shlens): Raise an error if the stride would cause
  # information in the input to be ignored. This will require a change
  # in the kernel implementation.
  if ksize_d == 1:
    out_rows, out_cols = get2d_conv_output_size(
        in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c,
        op.get_attr("padding"))
    out_batch, out_h, out_w, out_depth = (
        batch_size, out_rows, out_cols, depth)
  else:
    # Depthwise pooling leaves the spatial extent untouched.
    if depth % ksize_d > 0:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to evenly divide the input depth.")
    if stride_d != ksize_d:
      raise ValueError("Depthwise max pooling requires the depth window "
                       "to equal the depth stride.")
    out_batch, out_h, out_w, out_depth = (
        batch_size, in_rows, in_cols, depth // ksize_d)

  if is_nchw:
    output_dims = [out_batch, out_depth, out_h, out_w]
  else:
    output_dims = [out_batch, out_h, out_w, out_depth]
  return [tensor_shape.TensorShape(output_dims)]
502
503
def no_outputs(unused_op):
  """Shape function for ops that produce no outputs.

  Args:
    unused_op: The operation; ignored.

  Returns:
    A new, empty list.
  """
  return list()
507
508
def unknown_shape(op):
  """Shape function for ops whose output shapes cannot be inferred.

  Args:
    op: The operation; only `op.outputs` is iterated.

  Returns:
    A list holding one `tensor_shape.unknown_shape()` per element of
    `op.outputs`.
  """
  return [tensor_shape.unknown_shape() for _unused_output in op.outputs]
512
513
def _broadcast_shape_helper(shape_x, shape_y):
  """Helper functions for is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape` with known rank.
    shape_y: A `TensorShape` with known rank.

  Returns:
    None if the shapes are not broadcast compatible; otherwise a list of the
    broadcast dimensions, where None marks a dimension that cannot be
    determined statically.
  """
  # Align trailing axes by walking both shapes from the right and padding the
  # shorter one with size-1 dimensions, then restore left-to-right order.
  one = tensor_shape.Dimension(1)
  aligned = list(six.moves.zip_longest(
      reversed(shape_x.dims), reversed(shape_y.dims), fillvalue=one))
  aligned.reverse()

  # Combine the pairs following the numpy broadcasting rules:
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  for dim_x, dim_y in aligned:
    x_val, y_val = dim_x.value, dim_y.value
    if x_val is None or y_val is None:
      # At least one side is unknown. A known side greater than 1 is assumed
      # to win (the unknown side is taken to be 1 or equal); otherwise the
      # result stays unknown.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if x_val is not None and x_val > 1:
        result.append(dim_x)
      elif y_val is not None and y_val > 1:
        result.append(dim_y)
      else:
        result.append(None)
    elif x_val == 1:
      result.append(dim_y)  # dim_x is stretched to match dim_y.
    elif y_val == 1:
      result.append(dim_x)  # dim_y is stretched to match dim_x.
    elif x_val == y_val:
      result.append(dim_x.merge_with(dim_y))
    else:
      return None
  return result
560
561
def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  A shape of unknown rank is never reported as compatible.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be
    broadcasted to. False otherwise.
  """
  both_ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not both_ranks_known:
    return False
  return _broadcast_shape_helper(shape_x, shape_y) is not None
576
577
def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  If either shape has unknown rank, the result is a fully unknown shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  if shape_x.ndims is not None and shape_y.ndims is not None:
    dims = _broadcast_shape_helper(shape_x, shape_y)
    if dims is not None:
      return tensor_shape.TensorShape(dims)
    raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                     % (shape_x, shape_y))
  return tensor_shape.unknown_shape()
598
599
def call_cpp_shape_fn(op, require_shape_fn=True):
  """A shape function that delegates to the registered C++ shape function.

  The C++ function may report ("inputs_needed") that it wants the constant
  value, or the value interpreted as a shape, of particular inputs. Those
  input indices are accumulated and inference is re-run until no new input
  is requested.

  Args:
    op: the node in the graph for which to compute output shapes.
    require_shape_fn: If true, and the C++ shape function is not registered
      in the current binary then an exception is raised; otherwise, if the
      C++ shape function is not registered then unknown_shape is used.

  Returns:
    Usually a dictionary with the following keys:
      shapes: A list of the op's output shapes. These are TensorShapeProto
        values from the C++ shape function, or a `TensorShape` for Const ops.
      handle_data: A list with, per output, the handle data reported by the
        C++ shape function, or None when it is not set.
      inputs_needed: The serialized CppShapeInferenceInputsNeeded message
        from the last inference run (absent for Const ops).
    If no C++ shape function is registered and `require_shape_fn` is False,
    the list returned by `unknown_shape(op)` is returned instead.

  Raises:
    ValueError: If the C++ shape function returned an error (e.g. because the
      shapes of the inputs are of the wrong rank or otherwise incompatible
      according to the shape function).
    RuntimeError: If the C++ shape function is not registered and
      <require_shape_fn> is True.
  """
  if op.type == "Const":
    # To avoid serializing large constants, we special-case constant
    # here, even though it has a C++ shape function.  When Python
    # calls the C / C-API directly, we should be able to remove this.
    return {
        "shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
        "handle_data": [None]
    }

  input_tensors_needed = []
  input_tensors_as_shapes_needed = []

  while True:
    res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
                                  input_tensors_as_shapes_needed,
                                  require_shape_fn)
    if not isinstance(res, dict):
      # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).
      return res

    # See if we need to evaluate some inputs.
    if not res["inputs_needed"]:
      return res
    # FromString is a class-level parser; no throwaway instance is needed.
    needed = cpp_shape_inference_pb2.CppShapeInferenceInputsNeeded.FromString(
        res["inputs_needed"])
    changed = False
    for idx in needed.input_tensors_needed:
      if idx not in input_tensors_needed:
        input_tensors_needed.append(idx)
        changed = True
    for idx in needed.input_tensors_as_shapes_needed:
      if idx not in input_tensors_as_shapes_needed:
        input_tensors_as_shapes_needed.append(idx)
        changed = True
    # Stop once a run asks for nothing new; re-running would not help.
    if not changed:
      return res
660
661
def _call_cpp_shape_fn_impl(
    op, input_tensors_needed, input_tensors_as_shapes_needed, require_shape_fn):
  """Core implementation of call_cpp_shape_fn.

  Args:
    op: The operation whose output shapes are inferred.
    input_tensors_needed: Indices of inputs whose constant values (when
      `tensor_util.constant_value` can compute them) are passed to the C++
      shape function.
    input_tensors_as_shapes_needed: Indices of inputs whose values are
      interpreted as shapes (via `tensor_util.constant_value_as_shape`) and
      passed to the C++ shape function.
    require_shape_fn: Whether a missing C++ shape function is an error.

  Returns:
    A dict with "shapes", "handle_data" and "inputs_needed" entries. If no
    C++ shape function exists and `require_shape_fn` is False, the result of
    `unknown_shape(op)` is returned instead.

  Raises:
    ValueError: If the C++ shape function reports an InvalidArgumentError
      other than a missing shape function.
    RuntimeError: If no C++ shape function exists and `require_shape_fn` is
      True.
  """
  graph_def_version = op.graph.graph_def_versions.producer
  node_def_str = op.node_def.SerializeToString()

  def tensor_to_inference_result(t):
    r = cpp_shape_inference_pb2.CppShapeInferenceResult()
    r.shape.CopyFrom(t.get_shape().as_proto())
    # pylint: disable=protected-access
    if t._handle_data is not None:
      r.handle_data.CopyFrom(t._handle_data)
    # pylint: enable=protected-access
    return r.SerializeToString()
  input_shapes = [tensor_to_inference_result(i) for i in op.inputs]

  # Constant input values; None where not requested or not computable.
  input_tensors = [None] * len(input_shapes)
  for idx in input_tensors_needed:
    v = tensor_util.constant_value(op.inputs[idx])
    if v is not None:
      input_tensors[idx] = np.asarray(v)

  # Inputs interpreted as shapes; default to a serialized fully-unknown shape.
  serialized_unknown_shape = (
      tensor_shape.TensorShape(None).as_proto().SerializeToString())
  input_tensors_as_shapes = [serialized_unknown_shape] * len(input_shapes)
  for idx in input_tensors_as_shapes_needed:
    s = tensor_util.constant_value_as_shape(op.inputs[idx])
    if s is not None:
      input_tensors_as_shapes[idx] = s.as_proto().SerializeToString()

  missing_shape_fn = False
  try:
    with errors.raise_exception_on_not_ok_status() as status:
      output = pywrap_tensorflow.RunCppShapeInference(
          graph_def_version, node_def_str, input_shapes, input_tensors,
          input_tensors_as_shapes, status)
  except errors.InvalidArgumentError as err:
    if err.message.startswith("No shape inference function exists for op"):
      missing_shape_fn = True
    else:
      raise ValueError(err.message)

  if missing_shape_fn:
    if require_shape_fn:
      raise RuntimeError(
          "No C++ shape function registered for standard op: %s" % op.type)
    return unknown_shape(op)

  # The last element of `output` is the serialized "inputs needed" message.
  # The preceding elements are serialized CppShapeInferenceResult messages.
  output_shapes = output[:-1]

  # Convert TensorShapeProto values in output_shapes.
  result_protos = [
      cpp_shape_inference_pb2.CppShapeInferenceResult.FromString(s)
      for s in output_shapes
  ]
  result = [r.shape for r in result_protos]
  result_handle_data = [
      r.handle_data if r.handle_data.is_set else None for r in result_protos
  ]

  return {
      "shapes": result,
      "handle_data": result_handle_data,
      "inputs_needed": output[-1]
  }
728
# Hand call_cpp_shape_fn to the ops module through its private setter.
ops._set_call_cpp_shape_fn(  # pylint: disable=protected-access
    call_cpp_shape_fn)
732