/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

namespace {

void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      int dilation_width_factor, int dilation_height_factor,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
  const int input_width = input_shape.dims(2);
  const int input_height = input_shape.dims(1);
  const int batch = input_shape.dims(0);

  CHECK_GE(input_width, 1);
  CHECK_GE(input_height, 1);
  CHECK_GE(batch, 1);
  CHECK_GE(kwidth, 1);
  CHECK_GE(kheight, 1);
  CHECK_GE(stride_width, 1);
  CHECK_GE(stride_height, 1);
  CHECK_GE(dilation_width_factor, 1);
  CHECK_GE(dilation_height_factor, 1);

  int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
  int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
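  // e.g. a 3-wide kernel with dilation_width_factor = 2 covers an effective
  // width of 2 * (3 - 1) + 1 = 5 input pixels.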

  int output_height = 0;
  int output_width = 0;
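  // The integer expressions below are ceiling divisions:
  //   kValid: output = ceil((input - dilated_kernel + 1) / stride)
  //   kSame:  output = ceil(input / stride)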
  if (padding_type == PaddingType::kValid) {
    output_height =
        (input_height + stride_height - dilated_kheight) / stride_height;
    output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
  } else if (padding_type == PaddingType::kSame) {
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  } else {
    LOG(FATAL) << "Only supporting SAME or VALID padding";
  }

  fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
                                       dilated_kheight - input_height) /
                                          2);
  fixed_padding->width = std::max(
      0,
      ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);

  // Actually had to debug a situation where those were negative due to bad
  // propagation of placeholder -1 sizes in TensorFlowReshape.
  CHECK_GT(output_width, 0);
  CHECK_GT(output_height, 0);
  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
}

void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
                                     const Shape& input_shape_y,
                                     Array* output_array) {
  // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
  // It zips together the two input shapes and pads with 1 to make them the
  // same length. For each dimension we broadcast if either dimension is 1 and
  // otherwise expect them to match.
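  // e.g. zipping shapes [2, 1, 5] and [7, 1] broadcasts to [2, 7, 5].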
  int rank_x = input_shape_x.dimensions_count();
  int rank_y = input_shape_y.dimensions_count();
  int rank_out = std::max(rank_x, rank_y);
  std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
  dims_out->clear();
  dims_out->reserve(rank_out);
  for (int i = 0; i < rank_out; ++i) {
    int dim_x = i < (rank_out - rank_x)
                    ? 1
                    : input_shape_x.dims(i - (rank_out - rank_x));
    bool dim_y_is_one = i < (rank_out - rank_y);
    int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
    if (dim_x == -1 || dim_y == -1) {
      // One or both dimensions are unknown.
      QCHECK(false) << "Shapes must be specified";
    } else if (dim_x == 1 || dim_y == 1) {
      // One of the dimensions is 1: broadcast it to the other.
      if (dim_x == 1 && !dim_y_is_one) {
        // dim_x is 1: broadcast it up to dim_y.
        dims_out->push_back(dim_y);
      } else {
        // dim_y is 1: broadcast it up to dim_x.
        DCHECK_EQ(dim_y, 1);
        dims_out->push_back(dim_x);
      }
    } else {
      // Expect the dimensions to match.
      CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
      dims_out->push_back(dim_x);
    }
  }
  CHECK(output_array->has_shape());
}

void ProcessConvOperator(Model* model, ConvOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK(input_shape.dimensions_count() == 4)
      << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
      << "\" is " << input_shape.dimensions_count() << "D.";

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  auto& output_array = model->GetArray(op->outputs[0]);
  const int output_depth = weights_shape.dims(0);
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->dilation_width_factor,
                   op->dilation_height_factor, op->padding.type,
                   output_array.mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
  CHECK_EQ(output_array.shape().dimensions_count(), 4);

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const auto& output_shape = output_array.shape();
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = model->GetArray(op->outputs[1]);
    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                  output_shape.dims(2),
                                  input_depth * kheight * kwidth});
  }
}

void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
  // TransposeConv is unique in that it is specifically given the output shape
  // as a 1D array on its OUTPUT_SHAPE input. Theoretically then, resolving the
  // output shape is as easy as waiting for this input to be resolved. However,
  // we also have to calculate the padding, which requires the weights shape.
  // So, we might as well calculate the output shape and ensure it matches the
  // specified one.

  // SPECIFIED OUTPUT SHAPE
  // The below is the specified, or prescribed output shape, _given_ to the
  // operator as an input.
  auto& specified_output_shape_array =
      model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
  if (!specified_output_shape_array.has_shape() ||
      !specified_output_shape_array.buffer) {
    // Yield until the specified output shape is resolved as a constant.
    return;
  }

  CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
      << "TransposeConv input_dims must be int32";

  CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
        specified_output_shape_array.shape().dims(0) == 4)
      << "TransposeConv requires a 1D, 4 element array on its 0th input "
         "specifying the output shape. \""
      << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
      << toco::ShapeToString(specified_output_shape_array.shape());

  // COMPUTE PADDING
  // We require the weights shape to calculate padding.
  const auto& weights_array =
      model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
  if (!weights_array.has_shape()) {
    // Yield until weights dims have been resolved.
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4)
      << "TransposeConv weights must have 4 input dimensions. Input weights \""
      << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
      << toco::ShapeToString(weights_shape) << ".";

  // Compute padding.
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  op->padding.GetOrCreateFixedPadding();
  if (op->padding.type == PaddingType::kValid) {
    op->padding.fixed->height = 0;
    op->padding.fixed->width = 0;
  } else if (op->padding.type == PaddingType::kSame) {
    op->padding.fixed->height = (kheight - 1) / 2;
    op->padding.fixed->width = (kwidth - 1) / 2;
  } else {
    LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
  }
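  // e.g. a 3x3 kernel with kSame padding yields a fixed height and width
  // padding of (3 - 1) / 2 = 1.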

  // VALIDATE some dimensions and set the output shape.
  const auto& input_array =
      model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4)
      << "TransposeConv input shape must have 4 dimensions. Input \""
      << op->inputs[TransposeConvOperator::DATA_INPUT] << "\" had shape "
      << toco::ShapeToString(input_shape) << ".";
  CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
      << "Input shape depth and weight depth do not agree";

  // Set the output shape according to the specified output shape.
  std::vector<int32> const& specified_output_shape =
      specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto& output_array = model->GetArray(op->outputs[0]);
  *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = model->GetArray(op->outputs[1]);
    im2col_array.copy_shape(
        Shape{specified_output_shape[0], specified_output_shape[1],
              specified_output_shape[2], input_depth * kheight * kwidth});
  }
}

void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int input_depth = input_shape.dims(3);
  const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
  // instead it has to be inferred from the weights dims. However, once we are
  // here, weights dims have already been converted to our own internal format,
  // where the multiplier is no longer readily apparent. So instead we get it
  // as the quotient of output and input depths. We only want to do that when
  // depth_multiplier has the zero value: any other value is validated by the
  // CHECK_EQ below.
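  // e.g. input_depth = 8 and output_depth = 24 imply depth_multiplier = 3.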
  if (!op->depth_multiplier) {
    op->depth_multiplier = output_depth / input_depth;
  }
  CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
      << "input/output depths and depth_multiplier don't match";

  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->dilation_width_factor,
                   op->dilation_height_factor, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(depth % (block_size * block_size), 0);
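  // e.g. with block_size = 2, an input of shape [1, 4, 4, 8] becomes
  // [1, 8, 8, 2]: each spatial dim grows 2x while depth shrinks 4x.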

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height * block_size, width * block_size,
                         depth / block_size / block_size}));
}

void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(width % block_size, 0);
  QCHECK_EQ(height % block_size, 0);
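  // e.g. with block_size = 2, an input of shape [1, 8, 8, 2] becomes
  // [1, 4, 4, 8]: the inverse of the DepthToSpace example above.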

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height / block_size, width / block_size,
                         depth * block_size * block_size}));
}

void ProcessOpWithShapeInput(Model* model, Operator* op) {
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  auto& dims_array = model->GetArray(op->inputs[0]);
  if (!dims_array.has_shape()) {
    // Yield until the dims shape has been resolved.
    return;
  }
  if (!dims_array.buffer) {
    // Yield until the dims are constant.
    return;
  }
  CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
  CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
      << "dims vector can be no larger than 4 values";

  std::vector<int32> const& dims =
      dims_array.GetBuffer<ArrayDataType::kInt32>().data;
  *(output_array.mutable_shape()->mutable_dims()) = dims;
}

void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 1);

  const auto& weights_array = model->GetArray(op->inputs[1]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();

  const int weights_output_depth = weights_shape.dims(0);
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const int input_overall_size = RequiredBufferSizeForShape(input_shape);
  const int matmul_repeats = input_overall_size / weights_shape.dims(1);
  CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
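  // e.g. an input of shape [2, 3, 10] against weights of shape [16, 10]:
  // input_overall_size = 60, matmul_repeats = 6, output shape [6, 16].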

  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
}

void ProcessTensorFlowReshapeOperator(Model* model,
                                      TensorFlowReshapeOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& shape_array = model->GetArray(op->inputs[1]);
  if (!shape_array.has_shape()) {
    // Yield until the target_shape shape has been resolved.
    return;
  }
  if (!shape_array.buffer) {
    // Yield until the target_shape is constant.
    return;
  }
  CHECK(shape_array.data_type == ArrayDataType::kInt32)
      << "Reshape dims must be int32";

  // shape_data is the raw array of ints describing the shape
  // in the TensorFlow node. We intentionally make a copy here, rather than
  // modify wildcards in-place below, because in some graphs, the same shape
  // array with a wildcard may be referenced from multiple Reshape nodes, where
  // the wildcard needs to be resolved to distinct values.
  std::vector<int32> shape_data =
      shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  // The Reshape shape may have a wildcard dim, encoded as -1.
  bool has_wildcard = false;
  int wildcard_index = 0;
  int product_non_wildcard_dims = 1;
  for (int i = 0; i < shape_data.size(); i++) {
    if (shape_data[i] == -1) {
      CHECK(!has_wildcard);
      has_wildcard = true;
      wildcard_index = i;
    } else {
      product_non_wildcard_dims *= shape_data[i];
    }
  }

  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
  if (has_wildcard) {
    CHECK_GE(input_flat_size, product_non_wildcard_dims)
        << "Array not large enough to fill the requested dimensions for "
           "Reshape op with output \""
        << op->outputs[0] << "\". Are your input shapes correct?";
    shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
  }
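  // e.g. an input with flat size 24 reshaped to [-1, 4] resolves the
  // wildcard to 24 / 4 = 6, giving shape [6, 4].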

  if (shape_data.size() == 1 && shape_data[0] == 0) {
    // We have reshaped a scalar, so preserve as a scalar.
    shape_data.clear();
  }

  auto& output_shape = *output_array.mutable_shape();
  *output_shape.mutable_dims() = shape_data;
  CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
      << "Input cannot be reshaped to requested dimensions for Reshape op with "
         "output \""
      << op->outputs[0] << "\". Are your input shapes correct?";
}

void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
  const auto& input_array = model->GetArray(op->inputs[input_index]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) {
    return;
  }

  output_array.copy_shape(input_array.shape());
}

void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input0_array = model->GetArray(op->inputs[0]);
  const auto& input1_array = model->GetArray(op->inputs[1]);
  // Yield until input dims have been resolved.
  if (!input0_array.has_shape() || !input1_array.has_shape()) {
    return;
  }
  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
                                  &output_array);
}

void ProcessSelectOperator(Model* model, SelectOperator* op) {
  // Yield until all input dims have been resolved.
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      return;
    }
  }

  // Select's output shape matches that of its second and third inputs.
  const auto& input1_array = model->GetArray(op->inputs[1]);
  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(input1_array.shape());
}

void ProcessAddNOperator(Model* model, Operator* op) {
  // Yield until all input dims have been resolved.
  //
  // TODO(myenik): Since AddN does not support broadcasting, maybe we could
  // actually use this to improve shape propagation by propagating the shape of
  // one input to all other inputs once it is resolved instead of just the
  // output, since all inputs must be the same size and shape for a well-formed
  // graph.
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      return;
    }
  }

  // AddN does not support broadcasting, all inputs must be the same shape, so
  // we just take the first input shape and apply it to the output.
  const auto& input0_array = model->GetArray(op->inputs[0]);
  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(input0_array.shape());
}

bool KeepDims(const Operator& op) {
  switch (op.type) {
    case OperatorType::kReduceMin:  //  Reduction Min
      return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
    case OperatorType::kReduceMax:  //  Reduction Max
      return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
    case OperatorType::kSum:
      return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
    case OperatorType::kReduceProd:
      return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
    case OperatorType::kMean:
      return static_cast<const MeanOperator&>(op).keep_dims;
    case OperatorType::kAny:
      return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
    default:
      LOG(FATAL) << "Not a reduction operator!";
      return false;
  }
}

void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
  CHECK_LE(op->inputs.size(), 2);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    return;
  }
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const bool keep_dims = KeepDims(*op);
  if (op->inputs.size() == 2) {
    // There is a reduction_indices input.
    const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
    if (!reduction_indices_array.buffer) {
      return;
    }
    CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);

    int input_rank = input_shape.dimensions_count();
    std::set<int32> true_indices;
    const auto& reduction_indices =
        reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
    for (int i = 0; i < reduction_indices.size(); ++i) {
      const int32 reduction_index = reduction_indices[i];
      if (reduction_index < -input_rank || reduction_index >= input_rank) {
        CHECK(false) << "Invalid reduction dimension " << reduction_index
                     << " for input with " << input_rank << " dimensions";
      }
      int32 wrapped_index = reduction_index;
      if (wrapped_index < 0) {
        wrapped_index += input_rank;
      }
      true_indices.insert(wrapped_index);
    }
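    // e.g. for a rank-3 input, reduction_indices {-1} wraps to {2}: the last
    // dimension is dropped, or kept as 1 when keep_dims is set.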

    auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
    mutable_dims->clear();
    for (int i = 0; i < input_rank; ++i) {
      if (true_indices.count(i) > 0) {
        if (keep_dims) {
          mutable_dims->emplace_back(1);
        }
      } else {
        mutable_dims->emplace_back(input_shape.dims(i));
      }
    }
  } else {
    // No reduction_indices means complete reduction to a single scalar.
    if (keep_dims) {
      output_array.copy_shape(input_shape);
    } else {
      output_array.copy_shape(Shape({}));
    }
  }
}

void ProcessSliceOperator(Model* model, SliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until the Slice params have been resolved.
  if (op->begin.empty()) return;

  // Yield until input dims have been resolved.
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) return;
  const Shape& input_shape = input_array.shape();

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  CHECK_EQ(input_shape.dims().size(), op->size.size());
  CHECK_EQ(op->begin.size(), op->size.size());

  std::vector<int> output_dims;
  for (int i = 0; i < op->begin.size(); ++i) {
    int size = op->size[i];
    if (size == -1) {
      size = input_array.shape().dims(i) - op->begin[i];
    }
    output_dims.push_back(size);
  }
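  // e.g. input [5, 10] with begin [1, 2] and size [3, -1]: the -1 extends to
  // the end of the dimension, so output_dims = [3, 8].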

  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const string& output_name = op->outputs[0];
  Shape* output_shape = model->GetArray(output_name).mutable_shape();
  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
              output_shape);
}

void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
  // Yield until input dims have been resolved.
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    if (!input_array.has_shape()) {
      return;
    }
  }
  auto& output_array = model->GetArray(op->outputs[0]);
  // Use the first non-empty input as the basis for the output dimensions.
  for (const auto& input_name : op->inputs) {
    const auto& input_array = model->GetArray(input_name);
    if (input_array.shape().dimensions_count() > 0) {
      output_array.copy_shape(input_array.shape());
      // Negative axis means the count starts at the back of the dims().
      if (op->axis < 0) op->axis += input_array.shape().dims().size();
      break;
    }
  }
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count.
  int concat_size = 0;
  for (const auto& input_name : op->inputs) {
    auto& input_array = model->GetArray(input_name);
    CHECK(input_array.has_shape());
    if (input_array.shape().dimensions_count() == 0) {
      continue;
    }
    CHECK_EQ(input_array.shape().dimensions_count(),
             output_array.shape().dimensions_count());
    const std::vector<int>& input_dims = input_array.shape().dims();
    CHECK_LT(op->axis, input_dims.size());
    concat_size += input_dims[op->axis];
  }
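  // e.g. concatenating [2, 3, 5] and [2, 4, 5] along axis 1 gives
  // concat_size = 7 and an output shape of [2, 7, 5].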
  // Write out the concat_size on the output array shape.
  auto& output_shape = *output_array.mutable_shape();
  auto& output_dims = *output_shape.mutable_dims();
  CHECK_LT(op->axis, output_shape.dimensions_count());
  output_dims[op->axis] = concat_size;
}

void ProcessRangeOperator(Model* model, RangeOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  const auto& start_array = model->GetArray(op->inputs[0]);
  if (!start_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& limit_array = model->GetArray(op->inputs[1]);
  if (!limit_array.has_shape()) {
    return;
  }
  const auto& delta_array = model->GetArray(op->inputs[2]);
  if (!delta_array.has_shape()) {
    return;
  }

  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    // Yield until inputs are constant.
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  if (!IsConstantParameterArray(*model, op->inputs[2])) {
    return;
  }

  const ArrayDataType& start_dtype = start_array.data_type;
  CHECK(start_dtype == ArrayDataType::kInt32 ||
        start_dtype == ArrayDataType::kFloat)
      << "Range op inputs must be int32 or float.";
  CHECK(limit_array.data_type == start_dtype)
      << "In Range op, limit tensor must have the same data type as start "
         "tensor.";
  CHECK(delta_array.data_type == start_dtype)
      << "In Range op, delta tensor must have the same data type as start "
         "tensor.";
  CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
      << "Range op inputs must be scalar.";
  CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
      << "Range op inputs must be scalar.";

  int size = 0;
  if (start_dtype == ArrayDataType::kInt32) {
    size = std::floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
                       start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
                      delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
  } else if (start_dtype == ArrayDataType::kFloat) {
    size = std::floor((limit_array.GetBuffer<ArrayDataType::kFloat>().data[0] -
                       start_array.GetBuffer<ArrayDataType::kFloat>().data[0]) /
                      delta_array.GetBuffer<ArrayDataType::kFloat>().data[0]);
  }
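  // e.g. start = 0, limit = 10, delta = 2 gives size = 5, matching the five
  // elements [0, 2, 4, 6, 8].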

  // Only set the output shape. Contents are set by ResolveConstantRange.
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({size});
}

void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const string& input_name = op->inputs[1];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // Yield until axis is constant.
  if (!IsConstantParameterArray(*model, op->inputs[0])) {
    return;
  }

  const auto& axis_array = model->GetArray(op->inputs[0]);

  // Yield until axis dims have been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32)
      << "Axis array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  if (axis < 0) {
    axis += input_shape.dimensions_count();
  }

  const int split_dim = input_shape.dims(axis);
  CHECK_EQ(split_dim % op->num_split, 0);
  const int split_depth = split_dim / op->num_split;
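  // e.g. splitting a dimension of 12 with num_split = 3 gives each output a
  // split_depth of 4 along that axis.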

  Shape output_shape = input_shape;
  (*output_shape.mutable_dims())[axis] = split_depth;

  CHECK_EQ(op->outputs.size(), op->num_split);
  for (const auto& output : op->outputs) {
    model->GetArray(output).copy_shape(output_shape);
  }
}

void ProcessTensorFlowSplitVOperator(Model* model,
                                     TensorFlowSplitVOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);

  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // Yield until size_splits is constant.
  if (!IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  const auto& size_array = model->GetArray(op->inputs[1]);
  // Yield until size_splits dims have been resolved.
  if (!size_array.has_shape()) {
    return;
  }
  const Shape& size_shape = size_array.shape();

  CHECK(size_array.data_type == ArrayDataType::kInt32 ||
        size_array.data_type == ArrayDataType::kInt64)
      << "size_splits must be int32 or int64";
  CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";

  std::vector<int64> size_splits_vector;
  if (size_array.data_type == ArrayDataType::kInt32) {
    for (const auto each_size :
         size_array.GetBuffer<ArrayDataType::kInt32>().data) {
      size_splits_vector.push_back(each_size);
    }
  } else {
    size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
  }

  // Yield until axis is constant.
  if (!IsConstantParameterArray(*model, op->inputs[2])) {
    return;
  }
  const auto& axis_array = model->GetArray(op->inputs[2]);
  // Yield until axis dims have been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  CHECK(axis_array.data_type == ArrayDataType::kInt32)
      << "Axis array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  if (axis < 0) {
    axis += input_shape.dimensions_count();
  }

  CHECK_EQ(op->num_split, size_splits_vector.size());

  int64_t minus_one_count = 0, size_splits_sum = 0;
  for (auto size : size_splits_vector) {
    if (size == -1) {
      ++minus_one_count;
    } else {
      size_splits_sum += size;
    }
  }

  const int input_size = input_shape.dims(axis);

  CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";

  if (minus_one_count == 1) {
    CHECK_LE(size_splits_sum, input_size);
    auto iter =
        std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
    *iter = input_size - size_splits_sum;
  } else {
    CHECK_EQ(size_splits_sum, input_size);
  }
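  // e.g. an input dimension of 10 with size_splits [2, -1, 3] resolves the
  // -1 entry to 10 - (2 + 3) = 5.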

  CHECK_EQ(op->outputs.size(), op->num_split);

  for (int i = 0; i < op->outputs.size(); ++i) {
    const auto& output = op->outputs[i];
    Shape output_shape = input_shape;
    (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
    model->GetArray(output).copy_shape(output_shape);
  }
}

void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = model->GetArray(input_name);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  if (input_shape.dimensions_count() < 4) {
    LOG(FATAL) << "missing dimensions for " << input_name;
  }
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, 1, 1, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->GetArray(op->inputs[0]).has_shape() ||
      !model->GetArray(op->inputs[1]).has_shape()) {
    return;
  }
  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();

  const string& output_size_name = op->inputs[1];
  const auto& output_size_array = model->GetArray(output_size_name);
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  if (!output_size_array.buffer) {
    return;
  }
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
                         output_shape[1], input_data_shape.dims(3)}));
}

void ProcessResizeNearestNeighborOperator(Model* model,
                                          ResizeNearestNeighborOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->GetArray(op->inputs[0]).has_shape() ||
      !model->GetArray(op->inputs[1]).has_shape()) {
    return;
  }
  const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();

  const string& output_size_name = op->inputs[1];
  const auto& output_size_array = model->GetArray(output_size_name);
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
  if (!output_size_array.buffer) {
    return;
  }
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
                         output_shape[1], input_data_shape.dims(3)}));
}

void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // Only required for compact LstmCell with default NUM_INPUTS of inputs.
  if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;

  const auto& input_array =
      model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
  // Yield until all input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 2);

  const auto& prev_activ_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_activ_array.has_shape()) {
    return;
  }
  const auto& prev_activ_shape = prev_activ_array.shape();
  CHECK_GE(prev_activ_shape.dimensions_count(), 2);

  const auto& weights_array =
      model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const auto& bias_array =
      model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
  // Yield until bias dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }
  const auto& bias_shape = bias_array.shape();
  CHECK_GE(bias_shape.dimensions_count(), 1);

  const auto& prev_state_array =
      model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
  // Yield until all input dims have been resolved.
  if (!prev_state_array.has_shape()) {
    return;
  }
  const auto& prev_state_shape = prev_state_array.shape();
  CHECK_GE(prev_state_shape.dimensions_count(), 2);

  const int fc_output_depth = weights_shape.dims(0);
  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
  CHECK_EQ(fc_output_depth % 4, 0);
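  // The fused weight matrix stacks the transforms for the four LSTM gates,
  // hence the per-gate depth is a quarter of the FC output depth.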
  const int depth = fc_output_depth / 4;

  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
  const int fc_input_depth = weights_shape.dims(1);
  CHECK_EQ(input_depth + depth, fc_input_depth);
  Shape output_shape(input_shape);
  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;

  // Set output dimensions.
  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
      .copy_shape(output_shape);
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
      .copy_shape(output_shape);

  Shape concat_temp_shape(input_shape);
  (*concat_temp_shape
        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
      fc_input_depth;
  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
      .copy_shape(concat_temp_shape);

  Shape activ_temp_shape(input_shape);
  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
      fc_output_depth;
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
      .copy_shape(activ_temp_shape);
}

void ProcessUnidirectionalSequenceLstmOperator(
    Model* model, UnidirectionalSequenceLstmOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);

  constexpr int kInputActivationStateTensor = 18;
  constexpr int kInputCellStateTensor = 19;

  // The TFLite interpreter does not support an array that is both variable
  // and contains a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around this
  // restriction; as a consequence, downstream applications should not read
  // the LSTM state as input to other operations.
  model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  const auto& recurrent_to_output_weights_array =
      model->GetArray(op->inputs[8]);
  // Yield until input dims have been resolved.
  if (!recurrent_to_output_weights_array.has_shape()) {
    return;
  }

  const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
  const int output_size = output_weights_shape.dims(1);

  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({timestamp, batch_size, output_size});
}

void ProcessUnidirectionalSequenceRnnOperator(
    Model* model, UnidirectionalSequenceRnnOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  constexpr int kHiddenStateTensor = 4;
  // The TFLite interpreter does not support an array that is both variable
  // and contains a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around this
  // restriction; as a consequence, downstream applications should not read
  // the RNN state as input to other operations.
  model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  const auto& bias_array = model->GetArray(op->inputs[3]);
  // Yield until input dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }

  const auto& bias_shape = bias_array.shape();
  const int output_size = bias_shape.dims(0);

  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({timestamp, batch_size, output_size});
}

void ProcessBidirectionalSequenceLstmOperator(
    Model* model, BidirectionalSequenceLstmOperator* op) {
  // We assume time major.
  auto& fw_output_array = model->GetArray(op->outputs[0]);
  auto& bw_output_array = model->GetArray(op->outputs[1]);
  if (fw_output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (fw_output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  constexpr int kBwRecurrentToOutputWeightsTensor = 25;
  const auto& recurrent_to_output_weights_array =
      model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
  // Yield until input dims have been resolved.
  if (!recurrent_to_output_weights_array.has_shape()) {
    return;
  }

  constexpr int kFwInputActivationStateTensor = 35;
  constexpr int kFwInputCellStateTensor = 36;
  constexpr int kBwInputActivationStateTensor = 37;
  constexpr int kBwInputCellStateTensor = 38;
  // b/115961645: reset the state buffers as a workaround, as above.
  model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();

  const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
  const int output_size = output_weights_shape.dims(1);

  Shape* fw_output_shape = fw_output_array.mutable_shape();
  if (op->merge_outputs) {
    fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
  } else {
    fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
    Shape* bw_output_shape = bw_output_array.mutable_shape();
    bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
  }
}

void ProcessBidirectionalSequenceRnnOperator(
    Model* model, BidirectionalSequenceRnnOperator* op) {
  // We assume time major.
  auto& fw_output_array = model->GetArray(op->outputs[0]);
  auto& bw_output_array = model->GetArray(op->outputs[1]);
  if (fw_output_array.has_shape()) {
    // Shape already propagated.
    return;
  }

  if (fw_output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes.
    return;
  }

  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const int batch_size = input_shape.dims(1);
  const int timestamp = input_shape.dims(0);

  constexpr int kFwWeightsTensor = 1;
  const auto& forward_weights_array =
      model->GetArray(op->inputs[kFwWeightsTensor]);
  // Yield until input dims have been resolved.
  if (!forward_weights_array.has_shape()) {
    return;
  }

  constexpr int kFwHiddenStateTensor = 4;
  constexpr int kBwHiddenStateTensor = 8;
  // b/115961645: reset the state buffers as a workaround, as above.
  model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
  model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();

  const auto& output_weights_shape = forward_weights_array.shape();
  const int output_size = output_weights_shape.dims(0);

  Shape* fw_output_shape = fw_output_array.mutable_shape();
  if (op->merge_outputs) {
    fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
  } else {
    fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
    Shape* bw_output_shape = bw_output_array.mutable_shape();
    bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
  }
}

void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  // This method only handles input dimensions of 4.
  if (input_shape.dimensions_count() != 4) {
    return;
  }
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = model->GetArray(op->inputs[1]);
  const auto& paddings_array = model->GetArray(op->inputs[2]);
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& paddings_array_shape = paddings_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);

  // We only support two dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions
  QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!paddings_array.buffer) {
    return;
  }
  QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
  const auto& paddings_data =
      paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
  int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
  int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
  QCHECK_EQ(height_with_paddings % block_height, 0);
  QCHECK_EQ(width_with_paddings % block_width, 0);
  int output_height = height_with_paddings / block_height;
  int output_width = width_with_paddings / block_width;
1329 
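  // Illustrative example: an input of shape {1, 4, 4, 1} with block shape
  // {2, 2} and zero paddings yields output_height = output_width = 2 and an
  // output shape of {4, 2, 2, 1}: the batch grows by block_height *
  // block_width while each spatial dimension shrinks by its block factor.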
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
                         output_height, output_width, input_shape.dims(3)}));
}

void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = model->GetArray(op->inputs[1]);
  const auto& crops_array = model->GetArray(op->inputs[2]);
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& crops_array_shape = crops_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(crops_array_shape.dimensions_count(), 2);

  // We only support two dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions
  QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!crops_array.buffer) {
    return;
  }
  QCHECK(crops_array.data_type == ArrayDataType::kInt32);
  const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
  const int crops_top = crops_data[0];
  const int crops_bottom = crops_data[1];
  const int crops_left = crops_data[2];
  const int crops_right = crops_data[3];
  const int output_height =
      input_height * block_height - crops_top - crops_bottom;
  const int output_width = input_width * block_width - crops_left - crops_right;
  QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);

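  // Illustrative example: this is the inverse of SpaceToBatchND above. An
  // input of shape {4, 2, 2, 1} with block shape {2, 2} and zero crops yields
  // an output shape of {1, 4, 4, 1}.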
  model->GetArray(op->outputs[0])
      .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
                         output_height, output_width, input_shape.dims(3)}));
}

void ProcessGatherOperator(Model* model, GatherOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  const auto& indices_array = model->GetArray(op->inputs[1]);
  auto& output_array = model->GetArray(op->outputs[0]);

  // Bail if we already know the output shape.
  if (output_array.has_shape()) {
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_array.has_shape() || !indices_array.has_shape()) {
    return;
  }

  // Yield until the axis has been resolved.
  if (!op->axis) {
    return;
  }
  int axis = op->axis.value();

  const auto& input_shape = input_array.shape();
  const auto& indices_shape = indices_array.shape();
  QCHECK_GE(input_shape.dimensions_count(), 1);
  op->input_rank = input_shape.dimensions_count();
  QCHECK_LT(axis, op->input_rank);

  // Copy the input dimensions to the output except for the axis dimension,
  // where the dimensions of indices_shape are substituted.
  auto output_dims = output_array.mutable_shape()->mutable_dims();
  for (int dim = 0; dim < axis; ++dim) {
    output_dims->push_back(input_shape.dims(dim));
  }
  for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
    output_dims->push_back(indices_shape.dims(dim));
  }
  for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
    output_dims->push_back(input_shape.dims(dim));
  }
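  // Illustrative example: input shape {4, 5, 6}, indices shape {2, 3} and
  // axis = 1 produce the output shape {4, 2, 3, 6}.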
}

void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  const auto& indices_array = model->GetArray(op->inputs[1]);
  auto& output_array = model->GetArray(op->outputs[0]);

  // Bail if we already know the output shape.
  if (output_array.has_shape()) {
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_array.has_shape() || !indices_array.has_shape()) {
    return;
  }

  const auto& input_shape = input_array.shape();
  const auto& indices_shape = indices_array.shape();
  QCHECK_GE(input_shape.dimensions_count(), 1);
  QCHECK_GE(indices_shape.dimensions_count(), 1);
  const int indices_nd =
      indices_shape.dims(indices_shape.dimensions_count() - 1);
  QCHECK_LE(indices_nd, input_shape.dimensions_count());

  auto output_dims = output_array.mutable_shape()->mutable_dims();
  for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
    output_dims->push_back(indices_shape.dims(dim));
  }
  for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
    output_dims->push_back(input_shape.dims(dim));
  }
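  // Illustrative example: input shape {4, 5, 6} and indices shape {2, 2}
  // (so indices_nd = 2) produce the output shape {2, 6}.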
}

void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
  const auto& input_values = model->GetArray(op->inputs[0]);
  const auto& input_k = model->GetArray(op->inputs[1]);
  auto& output_values = model->GetArray(op->outputs[0]);
  auto& output_indexes = model->GetArray(op->outputs[1]);

  // Bail if we already know the output shape.
  if (output_indexes.has_shape()) {
    QCHECK(output_values.has_shape());
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_values.has_shape() || !input_k.has_shape()) {
    return;
  }

  // If k is a constant, we can fix the last dimension; otherwise it stays
  // unknown.
  if (input_k.buffer) {
    const auto& input_values_shape = input_values.shape();
    auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
    auto output_values_dims = output_values.mutable_shape()->mutable_dims();
    for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
      output_indexes_dims->push_back(input_values_shape.dims(dim));
      output_values_dims->push_back(input_values_shape.dims(dim));
    }
    const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
    output_indexes_dims->push_back(k_value);
    output_values_dims->push_back(k_value);
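    // Illustrative example: input values of shape {3, 10} with k = 4 give
    // both outputs the shape {3, 4}.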
  }
}

void ProcessPadOperator(Model* model, PadOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = model->GetArray(op->inputs[0]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  if (op->left_padding.empty()) return;
  CHECK_EQ(op->left_padding.size(), op->right_padding.size());

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();
  CHECK_EQ(op->left_padding.size(), dims.size());

  for (int i = 0; i < op->left_padding.size(); ++i) {
    dims[i] += op->left_padding[i] + op->right_padding[i];
  }
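  // Illustrative example: input shape {1, 2, 3} with left padding {0, 1, 1}
  // and right padding {0, 1, 1} gives the output shape {1, 4, 5}.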

  output_array.copy_shape(output_shape);
}

void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = model->GetArray(op->inputs[0]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  if (op->left_padding.empty()) return;
  CHECK_EQ(op->left_padding.size(), op->right_padding.size());

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();
  CHECK_EQ(op->left_padding.size(), dims.size());

  for (int i = 0; i < op->left_padding.size(); ++i) {
    dims[i] += op->left_padding[i] + op->right_padding[i];
  }

  output_array.copy_shape(output_shape);
}

void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  // Only set the output shape. Array contents are set by
  // ResolveConstantShapeOrRank.
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({});
}

void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated
    return;
  }

  if (output_array.data_type == ArrayDataType::kNone) {
    // Yield until the output type has been set by PropagateArrayDataTypes
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  // Only set the output shape. Array contents are set by
  // ResolveConstantShapeOrRank.
  Shape* output_shape = output_array.mutable_shape();
  output_shape->ReplaceDims({input_array.shape().dimensions_count()});
}

void ProcessPackOperator(Model* model, PackOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated
    return;
  }

  std::unique_ptr<Shape> packed_shape;
  for (const auto& input : op->inputs) {
    const auto& input_array = model->GetArray(input);
    if (!input_array.has_shape()) {
      // Yield until all input dims have been resolved.
      return;
    }

    Shape shape = input_array.shape();
    if (!packed_shape) {
      packed_shape.reset(new Shape(shape));
    } else {
      CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
                                       "must have the same shape. Input \""
                                    << input << "\" is different.";
    }
  }

  int axis = op->axis;
  if (axis < 0) {
    // Handle negative axis
    axis += packed_shape->dims().size() + 1;
  }
  packed_shape->mutable_dims()->insert(
      packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
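  // Illustrative example: packing three inputs of shape {2, 3} along axis 1
  // (or axis -2, which normalizes to 1) gives the output shape {2, 3, 3}.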
  output_array.copy_shape(*packed_shape);
}

void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
  CHECK_GE(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated
    return;
  }

  if (op->start_indices.empty() || op->stop_indices.empty() ||
      op->strides.empty()) {
    // ResolveStridedSliceAttributes has not run yet.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }

  if (op->ellipsis_mask != 0) {
    // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce
    // log noise. However, the TensorFlow logging library does not appear to
    // support this.
    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
                 << "\". ellipsis_mask is not supported (mask="
                 << op->ellipsis_mask << ")";
    return;
  }
  if (op->new_axis_mask != 0) {
    LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
                 << "\". new_axis_mask is not supported (mask="
                 << op->new_axis_mask << ")";
    return;
  }

  int num_input_axes = input_array.shape().dimensions_count();
  CHECK_LE(op->start_indices.size(), num_input_axes)
      << "StridedSlice op with output \"" << op->outputs[0]
      << "\", requires no more than " << num_input_axes << " start indices";
  CHECK_LE(op->stop_indices.size(), num_input_axes)
      << "StridedSlice op with output \"" << op->outputs[0]
      << "\", requires no more than " << num_input_axes << " stop indices";
  CHECK_LE(op->strides.size(), num_input_axes)
      << "StridedSlice op with output \"" << op->outputs[0]
      << "\", requires no more than " << num_input_axes << " strides";
  for (int i = 0; i < op->strides.size(); i++) {
    CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
                                << " has stride=" << op->strides[i] << ".";
  }

  // Create output shape
  std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();

  // Compute output shape
  for (int axis = 0; axis < num_input_axes; ++axis) {
    const auto strided_slice_params =
        tflite::strided_slice::BuildStridedSliceParams(
            op->begin_mask, op->end_mask, op->shrink_axis_mask,
            op->start_indices, op->stop_indices, op->strides);
    int start_index = tflite::strided_slice::StartForAxis(
        strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
    int stop_index = tflite::strided_slice::StopForAxis(
        strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
        start_index);

    int dim_size = std::ceil(static_cast<float>(stop_index - start_index) /
                             op->strides[axis]);
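    // Illustrative example: start_index = 1, stop_index = 7 and stride = 2
    // select elements 1, 3 and 5, so dim_size = ceil(6 / 2) = 3.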

    CHECK_GT(dim_size, 0)
        << "Output size for an axis must be greater than 0. Axis " << axis
        << " computes to size " << dim_size
        << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
    if (op->shrink_axis_mask & (1 << axis)) {
      CHECK_EQ(dim_size, 1)
          << "Output size for an axis must compute to 1 when shrinking an "
             "axis. Axis "
          << axis << " computes to size " << dim_size
          << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
    } else {
      dims->push_back(dim_size);
    }
  }
}

void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = model->GetArray(op->inputs[0]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) return;

  const std::vector<int>& input_dims = input_array.shape().dims();
  std::vector<int> output_dims;

  std::vector<int> squeeze_dims;
  const int input_num_dims = input_dims.size();
  for (int i : op->squeeze_dims) {
    squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
  }
  for (int i = 0; i < input_num_dims; ++i) {
    if (input_dims[i] != 1 ||
        (!squeeze_dims.empty() &&
         std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
             squeeze_dims.end())) {
      output_dims.push_back(input_dims[i]);
    }
  }
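  // Illustrative example: for an input of shape {1, 2, 1, 3}, an empty
  // squeeze_dims drops every size-1 dimension, giving {2, 3}, while
  // squeeze_dims = {0} drops only the first, giving {2, 1, 3}.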
  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
  CHECK(op->inputs.size() == 4 || op->inputs.size() == 5);
  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) return;

  auto& weights_feature_array = model->GetArray(op->inputs[1]);
  if (!weights_feature_array.has_shape()) return;

  const auto& weights_time_array = model->GetArray(op->inputs[2]);
  if (!weights_time_array.has_shape()) return;

  const bool has_bias = (op->inputs.size() == 5);
  if (has_bias) {
    const auto& bias_array = model->GetArray(op->inputs[3]);
    if (!bias_array.has_shape()) return;
  }

  const int batch_size = input_array.shape().dims()[0];
  const int num_units = weights_feature_array.shape().dims()[0];
  const int memory_size = weights_time_array.shape().dims()[1];
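  // Illustrative example: batch_size = 2, num_units = 4 and memory_size = 10
  // give a state shape of {2, 40} and an output shape of {2, 4}.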

  auto& state_array = model->GetArray(op->outputs[0]);
  state_array.mutable_shape()->ReplaceDims(
      {batch_size, memory_size * num_units});

  auto& output_array = model->GetArray(op->outputs[1]);
  output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
}

void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& perm_array = model->GetArray(op->inputs[1]);
  if (!perm_array.has_shape()) {
    // Yield until the permutation shape has been resolved.
    return;
  }
  if (!perm_array.buffer) {
    // Yield until the permutation is constant.
    return;
  }
  CHECK(perm_array.data_type == ArrayDataType::kInt32)
      << "Transpose permutation input must be int32";

  std::vector<int32> const& perm =
      perm_array.GetBuffer<ArrayDataType::kInt32>().data;
  CHECK_EQ(perm.size(), input_shape.dimensions_count())
      << "Transpose permutation input " << op->inputs[1]
      << " must be same length as input dimensions";
  std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
  for (int i = 0; i < perm.size(); i++) {
    int axis = perm[i];
    CHECK_GE(axis, 0);
    CHECK_LT(axis, input_shape.dimensions_count());
    output_dims->push_back(input_shape.dims(axis));
  }
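  // Illustrative example: an input of shape {2, 3, 4} with perm = {2, 0, 1}
  // produces the output shape {4, 2, 3}.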
}

template <typename Op>
void ProcessArgMinMaxOperator(Model* model, Op* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const Array& axis_array = model->GetArray(op->inputs[1]);
  // Yield until input axis array shape has been resolved.
  if (!axis_array.has_shape()) {
    return;
  }

  const std::vector<int>& input_dims = input_array.shape().dims();

  CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
        axis_array.data_type == ArrayDataType::kInt64)
      << "axis_array must be int32 or int64";

  CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
      << "Axis array must be scalar.";

  int64 axis;
  if (axis_array.data_type == ArrayDataType::kInt32) {
    axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  } else {
    axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
  }

  std::vector<int> output_dims;

  output_dims.reserve(input_dims.size() - 1);
  for (int i = 0; i < input_dims.size(); ++i) {
    if (i != axis) {
      output_dims.push_back(input_dims[i]);
    }
  }
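  // Illustrative example: an input of shape {2, 3, 4} reduced along axis 1
  // gives the output shape {2, 4}.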

  const string& output_name = op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) {
    return;
  }
  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
  CHECK_EQ(op->inputs.size(), 4);

  const Array& output_shape_array = model->GetArray(op->inputs[1]);
  if (!output_shape_array.has_shape()) return;
  CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);

  // Output should not go over four dimensions.
  CHECK_LE(output_shape_array.shape().dims(0), 4);

  const string& output_name = op->outputs[0];
  Array& output_array = model->GetArray(output_name);
  if (output_array.has_shape()) return;

  CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
        output_shape_array.data_type == ArrayDataType::kInt64);
  if (output_shape_array.data_type == ArrayDataType::kInt32) {
    *output_array.mutable_shape()->mutable_dims() =
        output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  } else {
    const std::vector<int64>& output_shape_data =
        output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
    std::copy(
        output_shape_data.begin(), output_shape_data.end(),
        std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
  }
}

void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // We have already run.
    return;
  }

  const auto& input_array = model->GetArray(op->inputs[0]);
  if (!input_array.has_shape()) {
    // Yield until input dims have been resolved.
    return;
  }
  const auto& input_shape = input_array.shape();

  auto& multiples_array = model->GetArray(op->inputs[1]);
  if (!multiples_array.has_shape()) {
    // Yield until the multiples shape has been resolved.
    return;
  }
  if (!multiples_array.buffer) {
    // Yield until the multiples are constant.
    return;
  }
  CHECK(multiples_array.data_type == ArrayDataType::kInt32)
      << "Tile multiples input must be int32";

  std::vector<int32> const& multiples =
      multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
  CHECK_EQ(multiples.size(), input_shape.dimensions_count())
      << "Tile multiples input " << op->inputs[1]
      << " must be same length as input dimensions";

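  // Each output dimension is the corresponding input dimension scaled by its
  // multiple. Illustrative example: input shape {2, 3} with multiples {3, 2}
  // gives the output shape {6, 6}.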
  auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
  mutable_dims->resize(multiples.size());
  for (int i = 0; i < mutable_dims->size(); ++i) {
    (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
  }
}

void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
  CHECK_EQ(op->inputs.size(), 4);
  CHECK_EQ(op->outputs.size(), 1);
  auto& output_array = model->GetArray(op->outputs[0]);
  if (output_array.has_shape()) {
    // Shape already propagated
    return;
  }

  // Yield until indices dims have been resolved.
  const auto& indices_array =
      model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
  if (!indices_array.has_shape()) {
    return;
  }

  // Yield until depth is constant and dims have been resolved.
  if (!IsConstantParameterArray(*model,
                                op->inputs[OneHotOperator::DEPTH_INPUT])) {
    return;
  }
  const auto& depth_array =
      model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
  if (!depth_array.has_shape()) {
    return;
  }

  CHECK(depth_array.data_type == ArrayDataType::kInt32)
      << "Depth array must be int32.";
  CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
      << "Depth array must be scalar.";

  const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
  CHECK_GE(depth, 0) << "Depth must be non-negative.";

  const int indices_dims = indices_array.shape().dimensions_count();
  const int output_dims = indices_dims + 1;
  const int axis = op->axis == -1 ? indices_dims : op->axis;
  CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";

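  // The output shape is the indices shape with a new dimension of size
  // `depth` inserted at `axis`. Illustrative example: indices of shape
  // {4, 3} with depth = 5 and axis = -1 (which resolves to 2) give the
  // output shape {4, 3, 5}.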
  auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
  mutable_dims->resize(output_dims);
  for (int i = 0; i < output_dims; ++i) {
    int dim = 0;
    if (i < axis) {
      dim = indices_array.shape().dims(i);
    } else if (i == axis) {
      dim = depth;
    } else {
      dim = indices_array.shape().dims(i - 1);
    }
    (*mutable_dims)[i] = dim;
  }
}

void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  const auto& input_array = model->GetArray(op->inputs[0]);
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const std::vector<int>& input_dims = input_array.shape().dims();
  std::vector<int> output_dims;

  output_dims.reserve(input_dims.size() - 1);
  for (int i = 0; i < input_dims.size(); ++i) {
    if (i != op->axis) {
      output_dims.push_back(input_dims[i]);
    }
  }
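  // Illustrative example: unpacking an input of shape {2, 3, 4} along
  // axis 1 gives each of the three outputs the shape {2, 4}.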
  for (const string& output_name : op->outputs) {
    auto& output_array = model->GetArray(output_name);
    if (output_array.has_shape()) {
      return;
    }
    *output_array.mutable_shape()->mutable_dims() = output_dims;
  }
}

void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input_array = model->GetArray(op->inputs[0]);
  const auto& padding_matrix = model->GetArray(op->inputs[1]);

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  auto& output_array = model->GetArray(op->outputs[0]);
  // Return if the output shape is already computed or if the padding matrix
  // is not constant.
  if (output_array.has_shape() ||
      !IsConstantParameterArray(*model, op->inputs[1])) {
    return;
  }
  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();

  std::vector<int64_t> padding;
  if (padding_matrix.data_type == ArrayDataType::kInt32) {
    const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
    for (auto elem : data) {
      padding.push_back(static_cast<int64_t>(elem));
    }
  } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
    const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
    for (auto elem : data) {
      padding.push_back(elem);
    }
  } else {
    CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
          padding_matrix.data_type == ArrayDataType::kInt32);
  }
  CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
  CHECK_EQ(input_array.shape().dimensions_count(),
           padding_matrix.shape().dims(0));
  for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
    dims[i] += padding[i * 2] + padding[i * 2 + 1];
  }
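  // Illustrative example: input shape {2, 3} with the flattened padding
  // matrix {1, 1, 2, 2} (i.e. {{1, 1}, {2, 2}}) gives the output shape
  // {4, 7}.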

  output_array.copy_shape(output_shape);
}

void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
  const auto& input_array = model->GetArray(op->inputs[0]);
  // There are two outputs: the index tensor has the same shape as the input
  // array, while the shape of the unique-values tensor is unknown until
  // runtime.
  CHECK_EQ(op->outputs.size(), 2);
  auto& idx_output_array = model->GetArray(op->outputs[1]);

  // Yield until input dims have been resolved, or return if the output has
  // already been computed.
  if (!input_array.has_shape() || idx_output_array.has_shape()) {
    return;
  }
  idx_output_array.copy_shape(input_array.shape());
}

void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);
  auto& input_array = model->GetArray(op->inputs[0]);
  auto& output_array = model->GetArray(op->outputs[0]);
  // The input array must have a shape in order to proceed. Also,
  // bail out if the output shape has already been calculated.
  if (!input_array.has_shape() || output_array.has_shape()) {
    // We have already run
    return;
  }
  // Copy the input shape and repeat its last dimension. Working on a local
  // copy avoids mutating the input array's shape as a side effect.
  Shape output_shape = input_array.shape();
  std::vector<int>* dims = output_shape.mutable_dims();
  int dims_size = dims->size();
  // Scalars are not allowed.
  CHECK_GT(dims_size, 0);
  int last_dim = (*dims)[dims_size - 1];
  dims->push_back(last_dim);
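  // Illustrative example: an input of shape {2, 3} gives the output shape
  // {2, 3, 3}.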
  output_array.copy_shape(output_shape);
}

void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);
  auto& input_array = model->GetArray(op->inputs[0]);
  auto& output_array = model->GetArray(op->outputs[0]);
  // The shape of the input array must be known because that will
  // be the shape of the output array.
  if (!input_array.has_shape() || output_array.has_shape()) {
    // We have already run, or the input shape is not yet resolved.
    return;
  }

  output_array.copy_shape(input_array.shape());
}

}  // namespace

::tensorflow::Status PropagateFixedSizes::Run(Model* model,
                                              std::size_t op_index,
                                              bool* modified) {
  *modified = false;
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  std::unordered_map<string, std::vector<int>> old_output_dims;
  for (const auto& output : op->outputs) {
    if (model->GetArray(output).has_shape()) {
      old_output_dims[output] = model->GetArray(output).shape().dims();
    }
  }
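  // The snapshot above records the pre-transformation output shapes so that,
  // after dispatching below, we can detect whether any output dimension
  // changed and report that via *modified.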

  switch (op->type) {
    case OperatorType::kAbs:
    case OperatorType::kBatchNormalization:
    case OperatorType::kL2Normalization:
    case OperatorType::kDequantize:
    case OperatorType::kElu:
    case OperatorType::kHardSwish:
    case OperatorType::kRelu:
    case OperatorType::kRelu1:
    case OperatorType::kRelu6:
    case OperatorType::kPRelu:
    case OperatorType::kLeakyRelu:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kLog:
    case OperatorType::kLogistic:
    case OperatorType::kTanh:
    case OperatorType::kLocalResponseNormalization:
    case OperatorType::kIdentity:
    case OperatorType::kFakeQuant:
    case OperatorType::kNeg:
    case OperatorType::kRsqrt:
    case OperatorType::kSqrt:
    case OperatorType::kSquare:
    case OperatorType::kAll:
    case OperatorType::kAssert:
    case OperatorType::kCast:
    case OperatorType::kFloor:
    case OperatorType::kCeil:
    case OperatorType::kRound:
    case OperatorType::kExp:
    case OperatorType::kSin:
    case OperatorType::kCos:
    case OperatorType::kLogicalAnd:
    case OperatorType::kLogicalNot:
    case OperatorType::kLogicalOr:
    case OperatorType::kZerosLike:
    case OperatorType::kReverseV2:
    case OperatorType::kReverseSequence:
      ProcessSimpleOperator(model, op, 0);
      break;
    case OperatorType::kGather:
      ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
      break;
    case OperatorType::kGatherNd:
      ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
      break;
    case OperatorType::kTopK_V2:
      ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
      break;
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul:
    case OperatorType::kDiv:
    case OperatorType::kFloorDiv:
    case OperatorType::kFloorMod:
    case OperatorType::kLess:
    case OperatorType::kLessEqual:
    case OperatorType::kGreater:
    case OperatorType::kMaximum:  //  Element-wise Maximum
    case OperatorType::kMinimum:  //  Element-wise Minimum
    case OperatorType::kGreaterEqual:
    case OperatorType::kEqual:
    case OperatorType::kNotEqual:
    case OperatorType::kPow:
    case OperatorType::kSquaredDifference:
      ProcessSimpleBinaryOperator(model, op);
      break;
    case OperatorType::kAddN:
      ProcessAddNOperator(model, op);
      break;
    case OperatorType::kConv:
      ProcessConvOperator(model, static_cast<ConvOperator*>(op));
      break;
    case OperatorType::kTransposeConv:
      ProcessTransposeConvOperator(model,
                                   static_cast<TransposeConvOperator*>(op));
      break;
    case OperatorType::kDepthwiseConv:
      ProcessDepthwiseConvOperator(model,
                                   static_cast<DepthwiseConvOperator*>(op));
      break;
    case OperatorType::kDepthToSpace:
      ProcessDepthToSpaceOperator(model,
                                  static_cast<DepthToSpaceOperator*>(op));
      break;
    case OperatorType::kSpaceToDepth:
      ProcessSpaceToDepthOperator(model,
                                  static_cast<SpaceToDepthOperator*>(op));
      break;
    case OperatorType::kFill:
      CHECK_EQ(op->inputs.size(), 2);
      ProcessOpWithShapeInput(model, op);
      break;
    case OperatorType::kFullyConnected:
      ProcessFullyConnectedOperator(model,
                                    static_cast<FullyConnectedOperator*>(op));
      break;
    case OperatorType::kReshape:
      ProcessTensorFlowReshapeOperator(
          model, static_cast<TensorFlowReshapeOperator*>(op));
      break;
    case OperatorType::kAveragePool:
      ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
      break;
    case OperatorType::kMaxPool:
      ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
      break;
    case OperatorType::kL2Pool:
      ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
      break;
    case OperatorType::kReduceMin:  //  Reduction Min
    case OperatorType::kReduceMax:  //  Reduction Max
    case OperatorType::kSum:
    case OperatorType::kReduceProd:
    case OperatorType::kMean:
    case OperatorType::kAny:
      ProcessTensorFlowReductionOperator(model, op);
      break;
    case OperatorType::kSelect:
      ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
      break;
    case OperatorType::kSlice:
      ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
      break;

    case OperatorType::kSwitch:
      // We can't know the sizes of the outputs until we have resolved the
      // predicate, and once we have resolved the predicate, the whole
      // Switch node will get resolved away.
      // See ResolveTensorFlowSwitch.
      break;
    case OperatorType::kMerge:
      // No need to bother resolving TensorFlow Merge ops: other graph
      // transformations will remove them anyway.
      // See ResolveTensorFlowMerge.
      break;
    case OperatorType::kSplit:
      ProcessTensorFlowSplitOperator(model,
                                     static_cast<TensorFlowSplitOperator*>(op));
      break;
    case OperatorType::kSplitV:
      ProcessTensorFlowSplitVOperator(
          model, static_cast<TensorFlowSplitVOperator*>(op));
      break;
    case OperatorType::kSqueeze:
      ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
      break;
    case OperatorType::kConcat:
    case OperatorType::kConcatV2:
      // Unimplemented; hopefully another graph transformation will drop or
      // rewrite this node. Concretely, either ResolveTensorFlowConcat will
      // resolve it to a DepthConcatenation, or else we have a more general
      // non-depth concatenation that will hopefully be dropped, or else we
      // will abort for now.
      break;
    case OperatorType::kExpandDims:
      // Yield until ExpandDims is converted to Reshape
      break;
    case OperatorType::kRange:
      ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
      break;
    case OperatorType::kRank:
      ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
      break;
    case OperatorType::kShape:
      ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
      break;
    case OperatorType::kPack:
      ProcessPackOperator(model, static_cast<PackOperator*>(op));
      break;
    case OperatorType::kReorderAxes:
      ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
      break;
    case OperatorType::kConcatenation:
      ProcessConcatenationOperator(model,
                                   static_cast<ConcatenationOperator*>(op));
      break;
    case OperatorType::kResizeBilinear:
      ProcessResizeBilinearOperator(model,
                                    static_cast<ResizeBilinearOperator*>(op));
      break;
    case OperatorType::kResizeNearestNeighbor:
      ProcessResizeNearestNeighborOperator(
          model, static_cast<ResizeNearestNeighborOperator*>(op));
      break;
    case OperatorType::kUnidirectionalSequenceLstm:
      ProcessUnidirectionalSequenceLstmOperator(
          model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
      break;
    case OperatorType::kUnidirectionalSequenceRnn:
      ProcessUnidirectionalSequenceRnnOperator(
          model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
      break;
    case OperatorType::kBidirectionalSequenceLstm:
      ProcessBidirectionalSequenceLstmOperator(
          model, static_cast<BidirectionalSequenceLstmOperator*>(op));
      break;
    case OperatorType::kBidirectionalSequenceRnn:
      ProcessBidirectionalSequenceRnnOperator(
          model, static_cast<BidirectionalSequenceRnnOperator*>(op));
      break;
    case OperatorType::kLstmCell:
      ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
      break;
    case OperatorType::kBatchMatMul:
    case OperatorType::kMatMul:
      // MatMul operators are converted to FullyConnected, after which their
      // shapes are propagated.
      break;
    case OperatorType::kSpaceToBatchND:
      ProcessSpaceToBatchNDOperator(model,
                                    static_cast<SpaceToBatchNDOperator*>(op));
      break;
    case OperatorType::kBatchToSpaceND:
      ProcessBatchToSpaceNDOperator(model,
                                    static_cast<BatchToSpaceNDOperator*>(op));
      break;
    case OperatorType::kPad:
      ProcessPadOperator(model, static_cast<PadOperator*>(op));
      break;
    case OperatorType::kPadV2:
      ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
      break;
    case OperatorType::kStridedSlice:
      ProcessStridedSliceOperator(model,
                                  static_cast<StridedSliceOperator*>(op));
      break;
    case OperatorType::kArgMax:
      ProcessArgMinMaxOperator<ArgMaxOperator>(
          model, static_cast<ArgMaxOperator*>(op));
      break;
    case OperatorType::kArgMin:
      ProcessArgMinMaxOperator<ArgMinOperator>(
          model, static_cast<ArgMinOperator*>(op));
      break;
    case OperatorType::kUnsupported: {
      const auto* unsupported_op =
          static_cast<TensorFlowUnsupportedOperator*>(op);
      // The attribute may be unspecified; in that case, ignore it.
      if (unsupported_op->output_shapes.size() < op->outputs.size()) {
        return ::tensorflow::Status::OK();
      }
      for (int i = 0; i < op->outputs.size(); ++i) {
        const string& output = op->outputs[i];
        model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
      }
      break;
    }
    case OperatorType::kSvdf:
      ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
      break;
    case OperatorType::kTranspose:
      ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
      break;
    case OperatorType::kDynamicPartition:
    case OperatorType::kDynamicStitch:
      // DynamicPartition/DynamicStitch are currently only supported for
      // transforms that remove them, so we avoid propagating shapes through
      // them and let things settle once they've been removed.
      break;
    case OperatorType::kRandomUniform:
      CHECK_EQ(op->inputs.size(), 1);
      ProcessOpWithShapeInput(model, op);
      break;
    case OperatorType::kSparseToDense:
      ProcessSparseToDenseOperator(model,
                                   static_cast<SparseToDenseOperator*>(op));
      break;
    case OperatorType::kTile:
      ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
      break;
    case OperatorType::kOneHot:
      ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
      break;
    case OperatorType::kUnpack:
      ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
      break;
    case OperatorType::kMirrorPad:
      ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
      break;
    case OperatorType::kUnique:
      ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
      break;
    case OperatorType::kWhere:
      // The size of the output can only be known after evaluating the cond
      // tensor. Ignore shape propagation here and defer that to the
      // interpreter.
      break;
    case OperatorType::kMatrixDiag:
      ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
      break;
    case OperatorType::kMatrixSetDiag:
      ProcessMatrixSetDiagOperator(model,
                                   static_cast<MatrixSetDiagOperator*>(op));
      break;
    case OperatorType::kCTCBeamSearchDecoder:
      // The sizes of the outputs are only known at runtime based on the
      // input. Ignore shape propagation here and defer that to the
      // interpreter.
      break;
    case OperatorType::kMatrixSetDiagV2:
      // MatrixSetDiagV2 operators are converted to MatrixSetDiag,
      // after which their shapes are propagated.
      break;
    case OperatorType::kMatrixDiagV2:
      // MatrixDiagV2 operators are converted to MatrixDiag, after which their
      // shapes are propagated.
      break;
    case OperatorType::kMatrixDiagV3:
      // MatrixDiagV3 operators are converted to MatrixDiag, after which their
      // shapes are propagated.
      break;
    case OperatorType::kMatrixSetDiagV3:
      // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
      // their shapes are propagated.
      break;
    case OperatorType::kSegmentSum:
      break;
    default:
      // Unimplemented, another graph transformation should drop it.
      LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
  }

  // Set *modified to true if any output dim changed; leave it false if none
  // changed. Assumption: no transformation clears an output shape;
  // transformations only add shapes.
  for (const auto& output : op->outputs) {
    if (model->GetArray(output).has_shape() &&
        (old_output_dims[output] != model->GetArray(output).shape().dims())) {
      AddMessageF("Set shape of %s to [%s]", output,
                  absl::StrJoin(model->GetArray(output).shape().dims(), ","));
      *modified = true;
      return ::tensorflow::Status::OK();
    }
  }
  return ::tensorflow::Status::OK();
}

}  // namespace toco