1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
30
31 namespace toco {
32
33 namespace {
34
ComputeConvSizes(const Shape & input_shape,int output_depth,int kwidth,int kheight,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,PaddingType padding_type,Shape * output_shape,FixedPadding * fixed_padding)35 void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
36 int kheight, int stride_width, int stride_height,
37 int dilation_width_factor, int dilation_height_factor,
38 PaddingType padding_type, Shape* output_shape,
39 FixedPadding* fixed_padding) {
40 const int input_width = input_shape.dims(2);
41 const int input_height = input_shape.dims(1);
42 const int batch = input_shape.dims(0);
43
44 CHECK_GE(input_width, 1);
45 CHECK_GE(input_height, 1);
46 CHECK_GE(batch, 1);
47 CHECK_GE(kwidth, 1);
48 CHECK_GE(kheight, 1);
49 CHECK_GE(stride_width, 1);
50 CHECK_GE(stride_height, 1);
51 CHECK_GE(dilation_width_factor, 1);
52 CHECK_GE(dilation_height_factor, 1);
53
54 int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
55 int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
56
57 int output_height = 0;
58 int output_width = 0;
59 if (padding_type == PaddingType::kValid) {
60 output_height =
61 (input_height + stride_height - dilated_kheight) / stride_height;
62 output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
63 } else if (padding_type == PaddingType::kSame) {
64 output_height = (input_height + stride_height - 1) / stride_height;
65 output_width = (input_width + stride_width - 1) / stride_width;
66 } else {
67 LOG(FATAL) << "Only supporting SAME or VALID padding";
68 }
69
70 fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
71 dilated_kheight - input_height) /
72 2);
73 fixed_padding->width = std::max(
74 0,
75 ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);
76
77 // Actually had to debug a situation where those were negative due to bad
78 // propagation of placeholder -1 sizes in TensorFlowReshape.
79 CHECK_GT(output_width, 0);
80 CHECK_GT(output_height, 0);
81 output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
82 }
83
ComputeBinaryOperatorOutputSize(const Shape & input_shape_x,const Shape & input_shape_y,Array * output_array)84 void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
85 const Shape& input_shape_y,
86 Array* output_array) {
87 // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
88 // It zips together the two input shapes and pads with 1 to make them the
89 // same length. For each dimension we broadcast if either dimension is 1 and
90 // otherwise expect them to match.
91 int rank_x = input_shape_x.dimensions_count();
92 int rank_y = input_shape_y.dimensions_count();
93 int rank_out = std::max(rank_x, rank_y);
94 std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
95 dims_out->clear();
96 dims_out->reserve(rank_out);
97 for (int i = 0; i < rank_out; ++i) {
98 int dim_x = i < (rank_out - rank_x)
99 ? 1
100 : input_shape_x.dims(i - (rank_out - rank_x));
101 bool dim_y_is_one = i < (rank_out - rank_y);
102 int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
103 if (dim_x == -1 || dim_y == -1) {
104 // One or both dimensions is unknown.
105 QCHECK(false) << "Shapes must be specified";
106 } else if (dim_x == 1 || dim_y == 1) {
107 // Broadcast one dimension to the other that is 1.
108 if (dim_x == 1 && !dim_y_is_one) {
109 // Broadcast dim_y to dim_x (1).
110 dims_out->push_back(dim_y);
111 } else {
112 // Broadcast dim_x to dim_y (1).
113 DCHECK_EQ(dim_y, 1);
114 dims_out->push_back(dim_x);
115 }
116 } else {
117 // Expect the dimensions to match.
118 CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
119 dims_out->push_back(dim_x);
120 }
121 }
122 CHECK(output_array->has_shape());
123 }
124
ProcessConvOperator(Model * model,ConvOperator * op)125 void ProcessConvOperator(Model* model, ConvOperator* op) {
126 const auto& input_array = model->GetArray(op->inputs[0]);
127 // Yield until input dims have been resolved.
128 if (!input_array.has_shape()) {
129 return;
130 }
131 const auto& input_shape = input_array.shape();
132 CHECK(input_shape.dimensions_count() == 4)
133 << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
134 << "\" is " << input_shape.dimensions_count() << "D.";
135
136 const auto& weights_array = model->GetArray(op->inputs[1]);
137 // Yield until weights dims have been resolved.
138 if (!weights_array.has_shape()) {
139 return;
140 }
141 const auto& weights_shape = weights_array.shape();
142 CHECK_EQ(weights_shape.dimensions_count(), 4);
143
144 auto& output_array = model->GetArray(op->outputs[0]);
145 const int output_depth = weights_shape.dims(0);
146 const int kheight = weights_shape.dims(1);
147 const int kwidth = weights_shape.dims(2);
148 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
149 op->stride_height, op->dilation_width_factor,
150 op->dilation_height_factor, op->padding.type,
151 output_array.mutable_shape(),
152 &op->padding.GetOrCreateFixedPadding());
153 CHECK_EQ(output_array.shape().dimensions_count(), 4);
154
155 // Set im2col array dimensions if there is one.
156 if (op->outputs.size() == 2) {
157 const auto& output_shape = output_array.shape();
158 const int input_depth = weights_shape.dims(3);
159 auto& im2col_array = model->GetArray(op->outputs[1]);
160 im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
161 output_shape.dims(2),
162 input_depth * kheight * kwidth});
163 }
164 }
165
ProcessTransposeConvOperator(Model * model,TransposeConvOperator * op)166 void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
167 // TransposeConv is unique in that it is specifically given the output shape
168 // as a 1D array on it's 1st input. Theoretically then, resolving the output
169 // shape is as easy as waiting for this input to be resolved. However, we also
170 // have to calculate the padding which requires the weights shape. So, we
171 // might as well calculate the output shape and ensure it matches the
172 // specified one
173
174 // SPECIFIED OUTPUT SHAPE
175 // The below is the specified, or prescribed output shape, _given_ to the
176 // operator as an input.
177 auto& specified_output_shape_array =
178 model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
179 if (!specified_output_shape_array.has_shape() ||
180 !specified_output_shape_array.buffer) {
181 // Yield until the specified output shape is resolved as a constant
182 return;
183 }
184
185 CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
186 << "TransposeConv input_dims must be int32";
187
188 CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
189 specified_output_shape_array.shape().dims(0) == 4)
190 << "TransposeConv requires a 1D, 4 element array on it's 0th input "
191 "specifying the output shape. \""
192 << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
193 << toco::ShapeToString(specified_output_shape_array.shape());
194
195 // COMPUTE PADDING
196 // We require the weights shape to calculate padding.
197 const auto& weights_array =
198 model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
199 if (!weights_array.has_shape()) {
200 // Yield until weights dims have been resolved.
201 return;
202 }
203 const auto& weights_shape = weights_array.shape();
204 CHECK_EQ(weights_shape.dimensions_count(), 4)
205 << "TransposeConv weights must have 4 input dimensions. Input weights \""
206 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
207 << toco::ShapeToString(weights_shape) << ".";
208
209 // Compute padding
210 const int kheight = weights_shape.dims(1);
211 const int kwidth = weights_shape.dims(2);
212 op->padding.GetOrCreateFixedPadding();
213 if (op->padding.type == PaddingType::kValid) {
214 op->padding.fixed->height = 0;
215 op->padding.fixed->width = 0;
216 } else if (op->padding.type == PaddingType::kSame) {
217 op->padding.fixed->height = (kheight - 1) / 2;
218 op->padding.fixed->width = (kwidth - 1) / 2;
219 } else {
220 LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
221 }
222
223 // VALIDATE some dimensions and set the output shape.
224 const auto& input_array =
225 model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
226 if (!input_array.has_shape()) {
227 // Yield until input dims have been resolved.
228 return;
229 }
230 const auto& input_shape = input_array.shape();
231 CHECK_EQ(input_shape.dimensions_count(), 4)
232 << "TransposeConv input shape must have 4 dimensions. Input \""
233 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
234 << toco::ShapeToString(weights_shape) << ".";
235 CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
236 << "Input shape depth and weight depth do not agree";
237
238 // Set the output shape according to the specified output shape.
239 std::vector<int32> const& specified_output_shape =
240 specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
241 auto& output_array = model->GetArray(op->outputs[0]);
242 *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
243
244 // Set im2col array dimensions if there is one.
245 if (op->outputs.size() == 2) {
246 const int input_depth = weights_shape.dims(3);
247 auto& im2col_array = model->GetArray(op->outputs[1]);
248 im2col_array.copy_shape(
249 Shape{specified_output_shape[0], specified_output_shape[1],
250 specified_output_shape[2], input_depth * kheight * kwidth});
251 }
252 }
253
ProcessDepthwiseConvOperator(Model * model,DepthwiseConvOperator * op)254 void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
255 const auto& input_array = model->GetArray(op->inputs[0]);
256 // Yield until input dims have been resolved.
257 if (!input_array.has_shape()) {
258 return;
259 }
260 const auto& input_shape = input_array.shape();
261 CHECK_EQ(input_shape.dimensions_count(), 4);
262
263 const auto& weights_array = model->GetArray(op->inputs[1]);
264 // Yield until weights dims have been resolved.
265 if (!weights_array.has_shape()) {
266 return;
267 }
268 const auto& weights_shape = weights_array.shape();
269 CHECK_EQ(weights_shape.dimensions_count(), 4);
270
271 const string& output_name = op->outputs[0];
272 const int input_depth = input_shape.dims(3);
273 const int output_depth = weights_shape.dims(3);
274 // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
275 // instead it has to be inferred from the weights dims. However, once we are
276 // here, weights dims have already been converted to our own internal format,
277 // where the multiplier is no longer readily apparent. So instead we get it
278 // as the quotient of output and input depths. We only want to do that when
279 // depth_multiplier had the zero value: any other value should be checked
280 // as done by the next if() below.
281 if (!op->depth_multiplier) {
282 op->depth_multiplier = output_depth / input_depth;
283 }
284 CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
285 << "input/output depths and depth_multiplier don't match";
286
287 const int kheight = weights_shape.dims(1);
288 const int kwidth = weights_shape.dims(2);
289 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
290 op->stride_height, op->dilation_width_factor,
291 op->dilation_height_factor, op->padding.type,
292 model->GetArray(output_name).mutable_shape(),
293 &op->padding.GetOrCreateFixedPadding());
294 }
295
ProcessDepthToSpaceOperator(Model * model,DepthToSpaceOperator * op)296 void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
297 const auto& input_array = model->GetArray(op->inputs[0]);
298 // Yield until input dims have been resolved.
299 if (!input_array.has_shape()) {
300 return;
301 }
302 const auto& input_shape = input_array.shape();
303 CHECK_EQ(input_shape.dimensions_count(), 4);
304
305 const string& output_name = op->outputs[0];
306 const int block_size = op->block_size;
307 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
308 const int batch = input_shape.dims(0);
309 const int height = input_shape.dims(1);
310 const int width = input_shape.dims(2);
311 const int depth = input_shape.dims(3);
312 QCHECK_EQ(depth % (block_size * block_size), 0);
313
314 model->GetArray(output_name)
315 .copy_shape(Shape({batch, height * block_size, width * block_size,
316 depth / block_size / block_size}));
317 }
318
ProcessSpaceToDepthOperator(Model * model,SpaceToDepthOperator * op)319 void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
320 const auto& input_array = model->GetArray(op->inputs[0]);
321 // Yield until input dims have been resolved.
322 if (!input_array.has_shape()) {
323 return;
324 }
325 const auto& input_shape = input_array.shape();
326 CHECK_EQ(input_shape.dimensions_count(), 4);
327
328 const string& output_name = op->outputs[0];
329 const int block_size = op->block_size;
330 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
331 const int batch = input_shape.dims(0);
332 const int height = input_shape.dims(1);
333 const int width = input_shape.dims(2);
334 const int depth = input_shape.dims(3);
335 QCHECK_EQ(width % block_size, 0);
336 QCHECK_EQ(height % block_size, 0);
337
338 model->GetArray(output_name)
339 .copy_shape(Shape({batch, height / block_size, width / block_size,
340 depth * block_size * block_size}));
341 }
342
ProcessOpWithShapeInput(Model * model,Operator * op)343 void ProcessOpWithShapeInput(Model* model, Operator* op) {
344 CHECK_EQ(op->outputs.size(), 1);
345 auto& output_array = model->GetArray(op->outputs[0]);
346 if (output_array.has_shape()) {
347 // We have already run
348 return;
349 }
350
351 auto& dims_array = model->GetArray(op->inputs[0]);
352 if (!dims_array.has_shape()) {
353 // Yield until dims shape been resolved.
354 return;
355 }
356 if (!dims_array.buffer) {
357 // Yield until the dims are constant
358 return;
359 }
360 CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
361 CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
362 << "dims vector can be no larger than 4 values";
363
364 std::vector<int32> const& dims =
365 dims_array.GetBuffer<ArrayDataType::kInt32>().data;
366 *(output_array.mutable_shape()->mutable_dims()) = dims;
367 }
368
ProcessFullyConnectedOperator(Model * model,FullyConnectedOperator * op)369 void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
370 const auto& input_array = model->GetArray(op->inputs[0]);
371 // Yield until input dims have been resolved.
372 if (!input_array.has_shape()) {
373 return;
374 }
375 const auto& input_shape = input_array.shape();
376 CHECK_GE(input_shape.dimensions_count(), 1);
377
378 const auto& weights_array = model->GetArray(op->inputs[1]);
379 // Yield until weights dims have been resolved.
380 if (!weights_array.has_shape()) {
381 return;
382 }
383 const auto& weights_shape = weights_array.shape();
384
385 const int weights_output_depth = weights_shape.dims(0);
386 CHECK_EQ(weights_shape.dimensions_count(), 2);
387
388 const int input_overall_size = RequiredBufferSizeForShape(input_shape);
389 const int matmul_repeats = input_overall_size / weights_shape.dims(1);
390 CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
391
392 auto& output_array = model->GetArray(op->outputs[0]);
393 output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
394 }
395
ProcessTensorFlowReshapeOperator(Model * model,TensorFlowReshapeOperator * op)396 void ProcessTensorFlowReshapeOperator(Model* model,
397 TensorFlowReshapeOperator* op) {
398 auto& output_array = model->GetArray(op->outputs[0]);
399 if (output_array.has_shape()) {
400 // We have already run
401 return;
402 }
403
404 const auto& input_array = model->GetArray(op->inputs[0]);
405 if (!input_array.has_shape()) {
406 // Yield until input dims have been resolved.
407 return;
408 }
409 const auto& input_shape = input_array.shape();
410
411 auto& shape_array = model->GetArray(op->inputs[1]);
412 if (!shape_array.has_shape()) {
413 // Yield until target_shape shape been resolved.
414 return;
415 }
416 if (!shape_array.buffer) {
417 // Yield until the target_shape is constant
418 return;
419 }
420 CHECK(shape_array.data_type == ArrayDataType::kInt32)
421 << "Reshape dims must be int32";
422
423 // shape_data is the raw array of ints describing the shape
424 // in the TensorFlow node. We intentionally make a copy here, rather than
425 // modify wildcards in-place below, because in some graphs, the same shape
426 // array with a wildcard may be referenced from multiple Reshape nodes, where
427 // the wildcard needs to resolved to distinct values.
428 std::vector<int32> shape_data =
429 shape_array.GetBuffer<ArrayDataType::kInt32>().data;
430 // The Reshape shape may have a wildcard dim, encoded as -1.
431 bool has_wildcard = false;
432 int wildcard_index = 0;
433 int product_non_wildcard_dims = 1;
434 for (int i = 0; i < shape_data.size(); i++) {
435 if (shape_data[i] == -1) {
436 CHECK(!has_wildcard);
437 has_wildcard = true;
438 wildcard_index = i;
439 } else {
440 product_non_wildcard_dims *= shape_data[i];
441 }
442 }
443
444 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
445 if (has_wildcard) {
446 CHECK_GE(input_flat_size, product_non_wildcard_dims)
447 << "Array not large enough to fill the requested dimensions for "
448 "Reshape op with output \""
449 << op->outputs[0] << "\". Are your input shapes correct?";
450 shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
451 }
452
453 if (shape_data.size() == 1 && shape_data[0] == 0) {
454 // We have reshaped a scalar, so preserve as a scalar.
455 shape_data.clear();
456 }
457
458 auto& output_shape = *output_array.mutable_shape();
459 *output_shape.mutable_dims() = shape_data;
460 CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
461 << "Input cannot be reshaped to requested dimensions for Reshape op with "
462 "output \""
463 << op->outputs[0] << "\". Are your input shapes correct?";
464 }
465
ProcessSimpleOperator(Model * model,Operator * op,int input_index)466 void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
467 const auto& input_array = model->GetArray(op->inputs[input_index]);
468 // Yield until input dims have been resolved.
469 if (!input_array.has_shape()) {
470 return;
471 }
472
473 const string& output_name = op->outputs[0];
474 auto& output_array = model->GetArray(output_name);
475 if (output_array.has_shape()) {
476 return;
477 }
478
479 output_array.copy_shape(input_array.shape());
480 }
481
ProcessSimpleBinaryOperator(Model * model,Operator * op)482 void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
483 CHECK_EQ(op->inputs.size(), 2);
484 const auto& input0_array = model->GetArray(op->inputs[0]);
485 const auto& input1_array = model->GetArray(op->inputs[1]);
486 // Yield until input dims have been resolved.
487 if (!input0_array.has_shape() || !input1_array.has_shape()) {
488 return;
489 }
490 const string& output_name = op->outputs[0];
491 auto& output_array = model->GetArray(output_name);
492 ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
493 &output_array);
494 }
495
ProcessSelectOperator(Model * model,SelectOperator * op)496 void ProcessSelectOperator(Model* model, SelectOperator* op) {
497 // Yield until all input dims have been resolved.
498 for (const auto& input : op->inputs) {
499 const auto& input_array = model->GetArray(input);
500 if (!input_array.has_shape()) {
501 return;
502 }
503 }
504
505 // Select's output matches the second and third output.
506 const auto& input1_array = model->GetArray(op->inputs[1]);
507 auto& output_array = model->GetArray(op->outputs[0]);
508 output_array.copy_shape(input1_array.shape());
509 }
510
ProcessAddNOperator(Model * model,Operator * op)511 void ProcessAddNOperator(Model* model, Operator* op) {
512 // Yield until all input dims have been resolved.
513 //
514 // TODO(myenik): Since AddN does not support broadcasting, maybe we could
515 // actually use this to improve shape propagation by propagating the shape of
516 // one input to all other inputs once it is resolved instead of just the
517 // output, since all inputs must be the same size and shape for a well-formed
518 // graph.
519 for (const auto& input : op->inputs) {
520 const auto& input_array = model->GetArray(input);
521 if (!input_array.has_shape()) {
522 return;
523 }
524 }
525
526 // AddN does not support broadcasting, all inputs must be the same shape, so
527 // we just take the first input shape and apply it to the output.
528 const auto& input0_array = model->GetArray(op->inputs[0]);
529 auto& output_array = model->GetArray(op->outputs[0]);
530 output_array.copy_shape(input0_array.shape());
531 }
532
KeepDims(const Operator & op)533 bool KeepDims(const Operator& op) {
534 switch (op.type) {
535 case OperatorType::kReduceMin: // Reduction Min
536 return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
537 case OperatorType::kReduceMax: // Reduction Max
538 return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
539 case OperatorType::kSum:
540 return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
541 case OperatorType::kReduceProd:
542 return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
543 case OperatorType::kMean:
544 return static_cast<const MeanOperator&>(op).keep_dims;
545 case OperatorType::kAny:
546 return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
547 default:
548 LOG(FATAL) << "Not a reduction operator!";
549 return false;
550 }
551 }
552
ProcessTensorFlowReductionOperator(Model * model,Operator * op)553 void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
554 CHECK_LE(op->inputs.size(), 2);
555 auto& output_array = model->GetArray(op->outputs[0]);
556 if (output_array.has_shape()) {
557 return;
558 }
559 const auto& input_array = model->GetArray(op->inputs[0]);
560 if (!input_array.has_shape()) {
561 return;
562 }
563 const auto& input_shape = input_array.shape();
564 const bool keep_dims = KeepDims(*op);
565 if (op->inputs.size() == 2) {
566 // There is a reduction_indices input.
567 const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
568 if (!reduction_indices_array.buffer) {
569 return;
570 }
571 CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
572
573 int input_rank = input_shape.dimensions_count();
574 std::set<int32> true_indices;
575 const auto& reduction_indices =
576 reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
577 for (int i = 0; i < reduction_indices.size(); ++i) {
578 const int32 reduction_index = reduction_indices[i];
579 if (reduction_index < -input_rank || reduction_index >= input_rank) {
580 CHECK(false) << "Invalid reduction dimension " << reduction_index
581 << " for input with " << input_rank << " dimensions";
582 }
583 int32 wrapped_index = reduction_index;
584 if (wrapped_index < 0) {
585 wrapped_index += input_rank;
586 }
587 true_indices.insert(wrapped_index);
588 }
589
590 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
591 mutable_dims->clear();
592 for (int i = 0; i < input_rank; ++i) {
593 if (true_indices.count(i) > 0) {
594 if (keep_dims) {
595 mutable_dims->emplace_back(1);
596 }
597 } else {
598 mutable_dims->emplace_back(input_shape.dims(i));
599 }
600 }
601 } else {
602 // No reduction_indices means complete reduction to a single scalar.
603 if (keep_dims) {
604 output_array.copy_shape(input_shape);
605 } else {
606 output_array.copy_shape(Shape({}));
607 }
608 }
609 }
610
ProcessSliceOperator(Model * model,SliceOperator * op)611 void ProcessSliceOperator(Model* model, SliceOperator* op) {
612 CHECK_EQ(op->inputs.size(), 3);
613 CHECK_EQ(op->outputs.size(), 1);
614
615 // Yield until the Slice params have been resolved.
616 if (op->begin.empty()) return;
617
618 // Yield until input dims have been resolved.
619 const auto& input_array = model->GetArray(op->inputs[0]);
620 if (!input_array.has_shape()) return;
621 const Shape& input_shape = input_array.shape();
622
623 auto& output_array = model->GetArray(op->outputs[0]);
624 if (output_array.has_shape()) return;
625
626 CHECK_EQ(input_shape.dims().size(), op->size.size());
627 CHECK_EQ(op->begin.size(), op->size.size());
628
629 std::vector<int> output_dims;
630 for (int i = 0; i < op->begin.size(); ++i) {
631 int size = op->size[i];
632 if (size == -1) {
633 size = input_array.shape().dims(i) - op->begin[i];
634 }
635 output_dims.push_back(size);
636 }
637
638 *output_array.mutable_shape()->mutable_dims() = output_dims;
639 }
640
ProcessReorderAxesOperator(Model * model,ReorderAxesOperator * op)641 void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
642 const string& input_name = op->inputs[0];
643 const auto& input_array = model->GetArray(input_name);
644 // Yield until input dims have been resolved.
645 if (!input_array.has_shape()) {
646 return;
647 }
648 const auto& input_shape = input_array.shape();
649 const string& output_name = op->outputs[0];
650 Shape* output_shape = model->GetArray(output_name).mutable_shape();
651 ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
652 output_shape);
653 }
654
ProcessConcatenationOperator(Model * model,ConcatenationOperator * op)655 void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
656 // Yield until input dims have been resolved.
657 for (const auto& input_name : op->inputs) {
658 auto& input_array = model->GetArray(input_name);
659 if (!input_array.has_shape()) {
660 return;
661 }
662 }
663 auto& output_array = model->GetArray(op->outputs[0]);
664 // Use first non-empty input as basis for output dimensions.
665 for (const auto& input_name : op->inputs) {
666 const auto& input_array = model->GetArray(input_name);
667 if (input_array.shape().dimensions_count() > 0) {
668 output_array.copy_shape(input_array.shape());
669 // Negative axis means the count starts at the back of the dims().
670 if (op->axis < 0) op->axis += input_array.shape().dims().size();
671 break;
672 }
673 }
674 // Determine the concat size, and enfore that all inputs have
675 // the same dimensions count.
676 int concat_size = 0;
677 for (const auto& input_name : op->inputs) {
678 auto& input_array = model->GetArray(input_name);
679 CHECK(input_array.has_shape());
680 if (input_array.shape().dimensions_count() == 0) {
681 continue;
682 }
683 CHECK_EQ(input_array.shape().dimensions_count(),
684 output_array.shape().dimensions_count());
685 const std::vector<int>& input_dims = input_array.shape().dims();
686 CHECK_LT(op->axis, input_dims.size());
687 concat_size += input_dims[op->axis];
688 }
689 // Write out the concat_size on the output array shape.
690 auto& output_shape = *output_array.mutable_shape();
691 auto& output_dims = *output_shape.mutable_dims();
692 CHECK_LT(op->axis, output_shape.dimensions_count());
693 output_dims[op->axis] = concat_size;
694 }
695
ProcessRangeOperator(Model * model,RangeOperator * op)696 void ProcessRangeOperator(Model* model, RangeOperator* op) {
697 CHECK_EQ(op->inputs.size(), 3);
698 const auto& start_array = model->GetArray(op->inputs[0]);
699 if (!start_array.has_shape()) {
700 // Yield until input dims have been resolved.
701 return;
702 }
703 const auto& limit_array = model->GetArray(op->inputs[1]);
704 if (!limit_array.has_shape()) {
705 return;
706 }
707 const auto& delta_array = model->GetArray(op->inputs[2]);
708 if (!delta_array.has_shape()) {
709 return;
710 }
711
712 if (!IsConstantParameterArray(*model, op->inputs[0])) {
713 // Yield until inputs are constant.
714 return;
715 }
716 if (!IsConstantParameterArray(*model, op->inputs[1])) {
717 return;
718 }
719 if (!IsConstantParameterArray(*model, op->inputs[2])) {
720 return;
721 }
722
723 const ArrayDataType& start_dtype = start_array.data_type;
724 CHECK(start_dtype == ArrayDataType::kInt32 ||
725 start_dtype == ArrayDataType::kFloat)
726 << "Range op inputs must be int32 or float.";
727 CHECK(limit_array.data_type == start_dtype)
728 << "In Range op, limit tensor must have the same data type as start "
729 "tensor.";
730 CHECK(delta_array.data_type == start_dtype)
731 << "In Range op, delta tensor must have the same data type as start "
732 "tensor.";
733 CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
734 << "Range op inputs must be scalar.";
735 CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
736 << "Range op inputs must be scalar.";
737 CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
738 << "Range op inputs must be scalar.";
739
740 int size = 0;
741 if (start_dtype == ArrayDataType::kInt32) {
742 size = std::floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
743 start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
744 delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
745 } else if (start_dtype == ArrayDataType::kFloat) {
746 size = std::floor((limit_array.GetBuffer<ArrayDataType::kFloat>().data[0] -
747 start_array.GetBuffer<ArrayDataType::kFloat>().data[0]) /
748 delta_array.GetBuffer<ArrayDataType::kFloat>().data[0]);
749 }
750
751 // Only set the output shape. Contents are set by ResolveConstantRange.
752 CHECK_EQ(op->outputs.size(), 1);
753 auto& output_array = model->GetArray(op->outputs[0]);
754 Shape* output_shape = output_array.mutable_shape();
755 output_shape->ReplaceDims({size});
756 }
757
ProcessTensorFlowSplitOperator(Model * model,TensorFlowSplitOperator * op)758 void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
759 CHECK_EQ(op->inputs.size(), 2);
760 const string& input_name = op->inputs[1];
761 const auto& input_array = model->GetArray(input_name);
762 // Yield until input dims have been resolved.
763 if (!input_array.has_shape()) {
764 return;
765 }
766 const Shape& input_shape = input_array.shape();
767
768 // Yield until axis is constant.
769 if (!IsConstantParameterArray(*model, op->inputs[0])) {
770 return;
771 }
772
773 const auto& axis_array = model->GetArray(op->inputs[0]);
774
775 // Yield until axis dims have been resolved.
776 if (!axis_array.has_shape()) {
777 return;
778 }
779
780 CHECK(axis_array.data_type == ArrayDataType::kInt32)
781 << "Axis array must be int32.";
782 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
783 << "Axis array must be scalar.";
784
785 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
786 if (axis < 0) {
787 axis += input_shape.dimensions_count();
788 }
789
790 const int split_dim = input_shape.dims(axis);
791 CHECK_EQ(split_dim % op->num_split, 0);
792 const int split_depth = split_dim / op->num_split;
793
794 Shape output_shape = input_shape;
795 (*output_shape.mutable_dims())[axis] = split_depth;
796
797 CHECK_EQ(op->outputs.size(), op->num_split);
798 for (const auto& output : op->outputs) {
799 model->GetArray(output).copy_shape(output_shape);
800 }
801 }
802
ProcessTensorFlowSplitVOperator(Model * model,TensorFlowSplitVOperator * op)803 void ProcessTensorFlowSplitVOperator(Model* model,
804 TensorFlowSplitVOperator* op) {
805 CHECK_EQ(op->inputs.size(), 3);
806
807 const auto& input_array = model->GetArray(op->inputs[0]);
808 // Yield until input dims have been resolved.
809 if (!input_array.has_shape()) {
810 return;
811 }
812 const Shape& input_shape = input_array.shape();
813
814 // Yield until size_splits is constant.
815 if (!IsConstantParameterArray(*model, op->inputs[1])) {
816 return;
817 }
818 const auto& size_array = model->GetArray(op->inputs[1]);
819 // Yield until size_splits dims have been resolved.
820 if (!size_array.has_shape()) {
821 return;
822 }
823 const Shape& size_shape = size_array.shape();
824
825 CHECK(size_array.data_type == ArrayDataType::kInt32 ||
826 size_array.data_type == ArrayDataType::kInt64)
827 << "size_splits must be int32, int64";
828 CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";
829
830 std::vector<int64> size_splits_vector;
831 if (size_array.data_type == ArrayDataType::kInt32) {
832 for (const auto each_size :
833 size_array.GetBuffer<ArrayDataType::kInt32>().data) {
834 size_splits_vector.push_back(each_size);
835 }
836 } else {
837 size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
838 }
839
840 // Yield until axis is constant.
841 if (!IsConstantParameterArray(*model, op->inputs[2])) {
842 return;
843 }
844 const auto& axis_array = model->GetArray(op->inputs[2]);
845 // Yield until axis dims have been resolved.
846 if (!axis_array.has_shape()) {
847 return;
848 }
849
850 CHECK(axis_array.data_type == ArrayDataType::kInt32)
851 << "Axis array must be int32.";
852 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
853 << "Axis array must be scalar.";
854
855 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
856 if (axis < 0) {
857 axis += input_shape.dimensions_count();
858 }
859
860 CHECK_EQ(op->num_split, size_splits_vector.size());
861
862 int64_t minus_one_count = 0, size_splits_sum = 0;
863 for (auto size : size_splits_vector) {
864 if (size == -1) {
865 ++minus_one_count;
866 } else {
867 size_splits_sum += size;
868 }
869 }
870
871 const int input_size = input_shape.dims(axis);
872
873 CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";
874
875 if (minus_one_count == 1) {
876 CHECK_LE(size_splits_sum, input_size);
877 auto iter =
878 std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
879 *iter = input_size - size_splits_sum;
880 } else {
881 CHECK_EQ(size_splits_sum, input_size);
882 }
883
884 CHECK_EQ(op->outputs.size(), op->num_split);
885
886 for (int i = 0; i < op->outputs.size(); ++i) {
887 const auto& output = op->outputs[i];
888 Shape output_shape = input_shape;
889 (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
890 model->GetArray(output).copy_shape(output_shape);
891 }
892 }
893
ProcessAveragePoolOperator(Model * model,AveragePoolOperator * op)894 void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
895 const string& input_name = op->inputs[0];
896 const auto& input_array = model->GetArray(input_name);
897 // Yield until input dims have been resolved.
898 if (!input_array.has_shape()) {
899 return;
900 }
901 const auto& input_shape = input_array.shape();
902 CHECK_EQ(input_shape.dimensions_count(), 4);
903 const string& output_name = op->outputs[0];
904 const int output_depth = input_shape.dims(3);
905 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
906 op->stride_width, op->stride_height, 1, 1, op->padding.type,
907 model->GetArray(output_name).mutable_shape(),
908 &op->padding.GetOrCreateFixedPadding());
909 }
910
ProcessMaxPoolOperator(Model * model,MaxPoolOperator * op)911 void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
912 const string& input_name = op->inputs[0];
913 const auto& input_array = model->GetArray(input_name);
914 // Yield until input dims have been resolved.
915 if (!input_array.has_shape()) {
916 return;
917 }
918 const auto& input_shape = input_array.shape();
919 CHECK_EQ(input_shape.dimensions_count(), 4);
920 const string& output_name = op->outputs[0];
921 const int output_depth = input_shape.dims(3);
922 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
923 op->stride_width, op->stride_height, 1, 1, op->padding.type,
924 model->GetArray(output_name).mutable_shape(),
925 &op->padding.GetOrCreateFixedPadding());
926 }
927
ProcessL2PoolOperator(Model * model,L2PoolOperator * op)928 void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
929 const string& input_name = op->inputs[0];
930 const auto& input_array = model->GetArray(input_name);
931 // Yield until input dims have been resolved.
932 if (!input_array.has_shape()) {
933 return;
934 }
935 const auto& input_shape = input_array.shape();
936 if (input_shape.dimensions_count() < 4) {
937 LOG(FATAL) << "missing dimensions for " << input_name;
938 }
939 const string& output_name = op->outputs[0];
940 const int output_depth = input_shape.dims(3);
941 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
942 op->stride_width, op->stride_height, 1, 1, op->padding.type,
943 model->GetArray(output_name).mutable_shape(),
944 &op->padding.GetOrCreateFixedPadding());
945 }
946
ProcessResizeBilinearOperator(Model * model,ResizeBilinearOperator * op)947 void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
948 CHECK_EQ(op->inputs.size(), 2);
949 CHECK_EQ(op->outputs.size(), 1);
950
951 if (!model->GetArray(op->inputs[0]).has_shape() ||
952 !model->GetArray(op->inputs[1]).has_shape()) {
953 return;
954 }
955 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
956
957 const string& output_size_name = op->inputs[1];
958 const auto& output_size_array = model->GetArray(output_size_name);
959 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
960 CHECK(output_size_array.has_shape());
961 const auto& output_size_shape = output_size_array.shape();
962 CHECK_EQ(output_size_shape.dimensions_count(), 1);
963 CHECK_EQ(output_size_shape.dims(0), 2);
964 if (!output_size_array.buffer) {
965 return;
966 }
967 std::vector<int32> output_shape =
968 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
969 model->GetArray(op->outputs[0])
970 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
971 output_shape[1], input_data_shape.dims(3)}));
972 }
973
ProcessResizeNearestNeighborOperator(Model * model,ResizeNearestNeighborOperator * op)974 void ProcessResizeNearestNeighborOperator(Model* model,
975 ResizeNearestNeighborOperator* op) {
976 CHECK_EQ(op->inputs.size(), 2);
977 CHECK_EQ(op->outputs.size(), 1);
978
979 if (!model->GetArray(op->inputs[0]).has_shape() ||
980 !model->GetArray(op->inputs[1]).has_shape()) {
981 return;
982 }
983 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
984
985 const string& output_size_name = op->inputs[1];
986 const auto& output_size_array = model->GetArray(output_size_name);
987 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
988 CHECK(output_size_array.has_shape());
989 const auto& output_size_shape = output_size_array.shape();
990 CHECK_EQ(output_size_shape.dimensions_count(), 1);
991 CHECK_EQ(output_size_shape.dims(0), 2);
992 if (!output_size_array.buffer) {
993 return;
994 }
995 std::vector<int32> output_shape =
996 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
997 model->GetArray(op->outputs[0])
998 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
999 output_shape[1], input_data_shape.dims(3)}));
1000 }
1001
ProcessLstmCellOperator(Model * model,LstmCellOperator * op)1002 void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
1003 // Only required for compact LstmCell with default NUM_INPUTS of inputs.
1004 if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
1005
1006 const auto& input_array =
1007 model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
1008 // Yield until all input dims have been resolved.
1009 if (!input_array.has_shape()) {
1010 return;
1011 }
1012 const auto& input_shape = input_array.shape();
1013 CHECK_GE(input_shape.dimensions_count(), 2);
1014
1015 const auto& prev_activ_array =
1016 model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
1017 // Yield until all input dims have been resolved.
1018 if (!prev_activ_array.has_shape()) {
1019 return;
1020 }
1021 const auto& prev_activ_shape = prev_activ_array.shape();
1022 CHECK_GE(prev_activ_shape.dimensions_count(), 2);
1023
1024 const auto& weights_array =
1025 model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
1026 // Yield until weights dims have been resolved.
1027 if (!weights_array.has_shape()) {
1028 return;
1029 }
1030 const auto& weights_shape = weights_array.shape();
1031 CHECK_EQ(weights_shape.dimensions_count(), 2);
1032
1033 const auto& bias_array =
1034 model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
1035 // Yield until bias dims have been resolved.
1036 if (!bias_array.has_shape()) {
1037 return;
1038 }
1039 const auto& bias_shape = bias_array.shape();
1040 CHECK_GE(bias_shape.dimensions_count(), 1);
1041
1042 const auto& prev_state_array =
1043 model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
1044 // Yield until all input dims have been resolved.
1045 if (!prev_state_array.has_shape()) {
1046 return;
1047 }
1048 const auto& prev_state_shape = prev_state_array.shape();
1049 CHECK_GE(prev_state_shape.dimensions_count(), 2);
1050
1051 const int fc_output_depth = weights_shape.dims(0);
1052 CHECK_EQ(fc_output_depth, bias_shape.dims(0));
1053 CHECK_EQ(fc_output_depth % 4, 0);
1054 const int depth = fc_output_depth / 4;
1055
1056 const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
1057 const int fc_input_depth = weights_shape.dims(1);
1058 CHECK_EQ(input_depth + depth, fc_input_depth);
1059 Shape output_shape(input_shape);
1060 (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
1061
1062 // Set output dimensions
1063 model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
1064 .copy_shape(output_shape);
1065 model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
1066 .copy_shape(output_shape);
1067
1068 Shape concat_temp_shape(input_shape);
1069 (*concat_temp_shape
1070 .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
1071 fc_input_depth;
1072 model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
1073 .copy_shape(concat_temp_shape);
1074
1075 Shape activ_temp_shape(input_shape);
1076 (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
1077 fc_output_depth;
1078 model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
1079 .copy_shape(activ_temp_shape);
1080 }
1081
ProcessUnidirectionalSequenceLstmOperator(Model * model,UnidirectionalSequenceLstmOperator * op)1082 void ProcessUnidirectionalSequenceLstmOperator(
1083 Model* model, UnidirectionalSequenceLstmOperator* op) {
1084 auto& output_array = model->GetArray(op->outputs[0]);
1085 if (output_array.has_shape()) {
1086 // Shape already propagated
1087 return;
1088 }
1089
1090 if (output_array.data_type == ArrayDataType::kNone) {
1091 // Yield until the output type has been set by PropagateArrayDataTypes
1092 return;
1093 }
1094
1095 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1096 const auto& input_array = model->GetArray(op->inputs[0]);
1097
1098 constexpr int kInputActivationStateTensor = 18;
1099 constexpr int kInputCellStateTensor = 19;
1100
1101 // TFlite intepreter does not support array which is variable and contains a
1102 // buffer (see b/115961645 for more discussion).
1103 // The follow block remove buffer from the array to work around the
1104 // restriction, as a consequence, downstream applications should not
1105 // read lstm state as input to other operations.
1106 model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
1107 model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
1108
1109 // Yield until input dims have been resolved.
1110 if (!input_array.has_shape()) {
1111 return;
1112 }
1113 const auto& input_shape = input_array.shape();
1114 const int batch_size = input_shape.dims(1);
1115 const int timestamp = input_shape.dims(0);
1116
1117 const auto& recurrent_to_output_weights_array =
1118 model->GetArray(op->inputs[8]);
1119 // Yield until input dims have been resolved.
1120 if (!recurrent_to_output_weights_array.has_shape()) {
1121 return;
1122 }
1123
1124 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1125 const int output_size = output_weights_shape.dims(1);
1126
1127 Shape* output_shape = output_array.mutable_shape();
1128 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1129 }
1130
ProcessUnidirectionalSequenceRnnOperator(Model * model,UnidirectionalSequenceRnnOperator * op)1131 void ProcessUnidirectionalSequenceRnnOperator(
1132 Model* model, UnidirectionalSequenceRnnOperator* op) {
1133 auto& output_array = model->GetArray(op->outputs[0]);
1134 if (output_array.has_shape()) {
1135 // Shape already propagated.
1136 return;
1137 }
1138
1139 if (output_array.data_type == ArrayDataType::kNone) {
1140 // Yield until the output type has been set by PropagateArrayDataTypes
1141 return;
1142 }
1143
1144 constexpr int kHiddenStateTensor = 4;
1145 // TFlite intepreter does not support array which is variable and contains a
1146 // buffer (see b/115961645 for more discussion).
1147 // The follow block remove buffer from the array to work around the
1148 // restriction, as a consequence, downstream applications should not
1149 // read lstm state as input to other operations.
1150 model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();
1151
1152 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1153 const auto& input_array = model->GetArray(op->inputs[0]);
1154 // Yield until input dims have been resolved.
1155 if (!input_array.has_shape()) {
1156 return;
1157 }
1158 const auto& input_shape = input_array.shape();
1159 const int batch_size = input_shape.dims(1);
1160 const int timestamp = input_shape.dims(0);
1161
1162 const auto& bias_array = model->GetArray(op->inputs[3]);
1163 // Yield until input dims have been resolved.
1164 if (!bias_array.has_shape()) {
1165 return;
1166 }
1167
1168 const auto& bias_shape = bias_array.shape();
1169 const int output_size = bias_shape.dims(0);
1170
1171 Shape* output_shape = output_array.mutable_shape();
1172 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1173 }
1174
ProcessBidirectionalSequenceLstmOperator(Model * model,BidirectionalSequenceLstmOperator * op)1175 void ProcessBidirectionalSequenceLstmOperator(
1176 Model* model, BidirectionalSequenceLstmOperator* op) {
1177 // We assume time major.
1178 auto& fw_output_array = model->GetArray(op->outputs[0]);
1179 auto& bw_output_array = model->GetArray(op->outputs[1]);
1180 if (fw_output_array.has_shape()) {
1181 // Shape already propagated
1182 return;
1183 }
1184
1185 if (fw_output_array.data_type == ArrayDataType::kNone) {
1186 // Yield until the output type has been set by PropagateArrayDataTypes
1187 return;
1188 }
1189
1190 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1191 const auto& input_array = model->GetArray(op->inputs[0]);
1192 // Yield until input dims have been resolved.
1193 if (!input_array.has_shape()) {
1194 return;
1195 }
1196 const auto& input_shape = input_array.shape();
1197 const int batch_size = input_shape.dims(1);
1198 const int timestamp = input_shape.dims(0);
1199
1200 constexpr int kBwRecurrentToOutputWeightsTensor = 25;
1201 const auto& recurrent_to_output_weights_array =
1202 model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
1203 // Yield until input dims have been resolved.
1204 if (!recurrent_to_output_weights_array.has_shape()) {
1205 return;
1206 }
1207
1208 constexpr int kFwInputActivationStateTensor = 35;
1209 constexpr int kFwInputCellStateTensor = 36;
1210 constexpr int kBwInputActivationStateTensor = 37;
1211 constexpr int kBwInputCellStateTensor = 38;
1212 // b(115961645): This is a hack to work around.
1213 model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
1214 model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
1215 model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
1216 model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();
1217
1218 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1219 const int output_size = output_weights_shape.dims(1);
1220
1221 Shape* fw_output_shape = fw_output_array.mutable_shape();
1222 if (op->merge_outputs) {
1223 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1224 } else {
1225 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1226 Shape* bw_output_shape = bw_output_array.mutable_shape();
1227 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1228 }
1229 }
1230
ProcessBidirectionalSequenceRnnOperator(Model * model,BidirectionalSequenceRnnOperator * op)1231 void ProcessBidirectionalSequenceRnnOperator(
1232 Model* model, BidirectionalSequenceRnnOperator* op) {
1233 // We assume time major.
1234 auto& fw_output_array = model->GetArray(op->outputs[0]);
1235 auto& bw_output_array = model->GetArray(op->outputs[1]);
1236 if (fw_output_array.has_shape()) {
1237 // Shape already propagated
1238 return;
1239 }
1240
1241 if (fw_output_array.data_type == ArrayDataType::kNone) {
1242 // Yield until the output type has been set by PropagateArrayDataTypes
1243 return;
1244 }
1245
1246 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1247 const auto& input_array = model->GetArray(op->inputs[0]);
1248 // Yield until input dims have been resolved.
1249 if (!input_array.has_shape()) {
1250 return;
1251 }
1252 const auto& input_shape = input_array.shape();
1253 const int batch_size = input_shape.dims(1);
1254 const int timestamp = input_shape.dims(0);
1255
1256 constexpr int kFwWeightsTensor = 1;
1257 const auto& forward_weights_array =
1258 model->GetArray(op->inputs[kFwWeightsTensor]);
1259 // Yield until input dims have been resolved.
1260 if (!forward_weights_array.has_shape()) {
1261 return;
1262 }
1263
1264 constexpr int kFwHiddenStateTensor = 4;
1265 constexpr int kBwHiddenStateTensor = 8;
1266 // b(115961645): This is a hack to work around.
1267 model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
1268 model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();
1269
1270 const auto& output_weights_shape = forward_weights_array.shape();
1271 const int output_size = output_weights_shape.dims(0);
1272
1273 Shape* fw_output_shape = fw_output_array.mutable_shape();
1274 if (op->merge_outputs) {
1275 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1276 } else {
1277 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1278 Shape* bw_output_shape = bw_output_array.mutable_shape();
1279 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1280 }
1281 }
1282
ProcessSpaceToBatchNDOperator(Model * model,SpaceToBatchNDOperator * op)1283 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
1284 const auto& input_array = model->GetArray(op->inputs[0]);
1285 // Yield until input dims have been resolved.
1286 if (!input_array.has_shape()) {
1287 return;
1288 }
1289 const auto& input_shape = input_array.shape();
1290 // This method only handles input dimensions of 4.
1291 if (input_shape.dimensions_count() != 4) {
1292 return;
1293 }
1294 const auto input_height = input_shape.dims(1);
1295 const auto input_width = input_shape.dims(2);
1296
1297 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1298 const auto& paddings_array = model->GetArray(op->inputs[2]);
1299 const auto& block_shape_array_shape = block_shape_array.shape();
1300 const auto& paddings_array_shape = paddings_array.shape();
1301 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1302 QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
1303
1304 // We only support two dimensions.
1305 QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1306 if (!block_shape_array.buffer) {
1307 return;
1308 }
1309 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1310 const auto& block_shape_data =
1311 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1312 auto block_height = block_shape_data[0];
1313 auto block_width = block_shape_data[1];
1314
1315 QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions
1316 QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
1317 if (!paddings_array.buffer) {
1318 return;
1319 }
1320 QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
1321 const auto& paddings_data =
1322 paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
1323 int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
1324 int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
1325 QCHECK_EQ(height_with_paddings % block_height, 0);
1326 QCHECK_EQ(width_with_paddings % block_width, 0);
1327 int output_height = height_with_paddings / block_height;
1328 int output_width = width_with_paddings / block_width;
1329
1330 model->GetArray(op->outputs[0])
1331 .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
1332 output_height, output_width, input_shape.dims(3)}));
1333 }
1334
ProcessBatchToSpaceNDOperator(Model * model,BatchToSpaceNDOperator * op)1335 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
1336 const auto& input_array = model->GetArray(op->inputs[0]);
1337 // Yield until input dims have been resolved.
1338 if (!input_array.has_shape()) {
1339 return;
1340 }
1341 const auto& input_shape = input_array.shape();
1342 CHECK_EQ(input_shape.dimensions_count(), 4);
1343 const auto input_height = input_shape.dims(1);
1344 const auto input_width = input_shape.dims(2);
1345
1346 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1347 const auto& crops_array = model->GetArray(op->inputs[2]);
1348 const auto& block_shape_array_shape = block_shape_array.shape();
1349 const auto& crops_array_shape = crops_array.shape();
1350 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1351 QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
1352
1353 // We only support two dimensions.
1354 QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1355 if (!block_shape_array.buffer) {
1356 return;
1357 }
1358 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1359 const auto& block_shape_data =
1360 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1361 auto block_height = block_shape_data[0];
1362 auto block_width = block_shape_data[1];
1363
1364 QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions
1365 QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
1366 if (!crops_array.buffer) {
1367 return;
1368 }
1369 QCHECK(crops_array.data_type == ArrayDataType::kInt32);
1370 const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
1371 const int crops_top = crops_data[0];
1372 const int crops_bottom = crops_data[1];
1373 const int crops_left = crops_data[2];
1374 const int crops_right = crops_data[3];
1375 const int output_height =
1376 input_height * block_height - crops_top - crops_bottom;
1377 const int output_width = input_width * block_width - crops_left - crops_right;
1378 QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
1379
1380 model->GetArray(op->outputs[0])
1381 .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
1382 output_height, output_width, input_shape.dims(3)}));
1383 }
1384
ProcessGatherOperator(Model * model,GatherOperator * op)1385 void ProcessGatherOperator(Model* model, GatherOperator* op) {
1386 const auto& input_array = model->GetArray(op->inputs[0]);
1387 const auto& indices_array = model->GetArray(op->inputs[1]);
1388 auto& output_array = model->GetArray(op->outputs[0]);
1389
1390 // Bail if we already know the output shape.
1391 if (output_array.has_shape()) {
1392 return;
1393 }
1394
1395 // Yield until input dims have been resolved.
1396 if (!input_array.has_shape() || !indices_array.has_shape()) {
1397 return;
1398 }
1399
1400 // Yield until the axis has been resolved.
1401 if (!op->axis) {
1402 return;
1403 }
1404 int axis = op->axis.value();
1405
1406 const auto& input_shape = input_array.shape();
1407 const auto& indices_shape = indices_array.shape();
1408 QCHECK_GE(input_shape.dimensions_count(), 1);
1409 op->input_rank = input_shape.dimensions_count();
1410 QCHECK_LT(axis, op->input_rank);
1411
1412 // Copy the input dimensions to the output except for the axis dimensions
1413 // where the dimension of indices_shape is used.
1414 auto output_dims = output_array.mutable_shape()->mutable_dims();
1415 for (int dim = 0; dim < axis; ++dim) {
1416 output_dims->push_back(input_shape.dims(dim));
1417 }
1418 for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
1419 output_dims->push_back(indices_shape.dims(dim));
1420 }
1421 for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
1422 output_dims->push_back(input_shape.dims(dim));
1423 }
1424 }
1425
ProcessGatherNdOperator(Model * model,GatherNdOperator * op)1426 void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
1427 const auto& input_array = model->GetArray(op->inputs[0]);
1428 const auto& indices_array = model->GetArray(op->inputs[1]);
1429 auto& output_array = model->GetArray(op->outputs[0]);
1430
1431 // Bail if we already know the output shape.
1432 if (output_array.has_shape()) {
1433 return;
1434 }
1435
1436 // Yield until input dims have been resolved.
1437 if (!input_array.has_shape() || !indices_array.has_shape()) {
1438 return;
1439 }
1440
1441 const auto& input_shape = input_array.shape();
1442 const auto& indices_shape = indices_array.shape();
1443 QCHECK_GE(input_shape.dimensions_count(), 1);
1444 QCHECK_GE(indices_shape.dimensions_count(), 1);
1445 const int indices_nd =
1446 indices_shape.dims(indices_shape.dimensions_count() - 1);
1447 QCHECK_LE(indices_nd, input_shape.dimensions_count());
1448
1449 auto output_dims = output_array.mutable_shape()->mutable_dims();
1450 for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
1451 output_dims->push_back(indices_shape.dims(dim));
1452 }
1453 for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
1454 output_dims->push_back(input_shape.dims(dim));
1455 }
1456 }
1457
ProcessTopkV2Operator(Model * model,TopKV2Operator * op)1458 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
1459 const auto& input_values = model->GetArray(op->inputs[0]);
1460 const auto& input_k = model->GetArray(op->inputs[1]);
1461 auto& output_values = model->GetArray(op->outputs[0]);
1462 auto& output_indexes = model->GetArray(op->outputs[1]);
1463
1464 // Bail if we already know the output shape.
1465 if (output_indexes.has_shape()) {
1466 QCHECK(output_values.has_shape());
1467 return;
1468 }
1469
1470 // Yield until input dims have been resolved.
1471 if (!input_values.has_shape() || !input_k.has_shape()) {
1472 return;
1473 }
1474
1475 // If the value is initialized, we can specify the last dimension, otherwise
1476 // unknown.
1477 if (input_k.buffer) {
1478 const auto& input_values_shape = input_values.shape();
1479 auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1480 auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1481 for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1482 output_indexes_dims->push_back(input_values_shape.dims(dim));
1483 output_values_dims->push_back(input_values_shape.dims(dim));
1484 }
1485 const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1486 output_indexes_dims->push_back(k_value);
1487 output_values_dims->push_back(k_value);
1488 }
1489 }
1490
ProcessPadOperator(Model * model,PadOperator * op)1491 void ProcessPadOperator(Model* model, PadOperator* op) {
1492 CHECK_EQ(op->inputs.size(), 2);
1493 CHECK_EQ(op->outputs.size(), 1);
1494
1495 const auto& input_array = model->GetArray(op->inputs[0]);
1496
1497 // Yield until input dims have been resolved.
1498 if (!input_array.has_shape()) return;
1499
1500 if (op->left_padding.empty()) return;
1501 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1502
1503 auto& output_array = model->GetArray(op->outputs[0]);
1504 if (output_array.has_shape()) return;
1505
1506 Shape output_shape = input_array.shape();
1507 std::vector<int>& dims = *output_shape.mutable_dims();
1508 CHECK_EQ(op->left_padding.size(), dims.size());
1509
1510 for (int i = 0; i < op->left_padding.size(); ++i) {
1511 dims[i] += op->left_padding[i] + op->right_padding[i];
1512 }
1513
1514 output_array.copy_shape(output_shape);
1515 }
1516
ProcessPadV2Operator(Model * model,PadV2Operator * op)1517 void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
1518 CHECK_EQ(op->inputs.size(), 3);
1519 CHECK_EQ(op->outputs.size(), 1);
1520
1521 const auto& input_array = model->GetArray(op->inputs[0]);
1522
1523 // Yield until input dims have been resolved.
1524 if (!input_array.has_shape()) return;
1525
1526 if (op->left_padding.empty()) return;
1527 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1528
1529 auto& output_array = model->GetArray(op->outputs[0]);
1530 if (output_array.has_shape()) return;
1531
1532 Shape output_shape = input_array.shape();
1533 std::vector<int>& dims = *output_shape.mutable_dims();
1534 CHECK_EQ(op->left_padding.size(), dims.size());
1535
1536 for (int i = 0; i < op->left_padding.size(); ++i) {
1537 dims[i] += op->left_padding[i] + op->right_padding[i];
1538 }
1539
1540 output_array.copy_shape(output_shape);
1541 }
1542
ProcessRankOperator(Model * model,TensorFlowRankOperator * op)1543 void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
1544 CHECK_GE(op->inputs.size(), 1);
1545 CHECK_EQ(op->outputs.size(), 1);
1546 auto& output_array = model->GetArray(op->outputs[0]);
1547 if (output_array.has_shape()) {
1548 // Shape already propagated
1549 return;
1550 }
1551
1552 if (output_array.data_type == ArrayDataType::kNone) {
1553 // Yield until the output type has been set by PropagateArrayDataTypes
1554 return;
1555 }
1556
1557 const auto& input_array = model->GetArray(op->inputs[0]);
1558 if (!input_array.has_shape()) {
1559 // Yield until input dims have been resolved.
1560 return;
1561 }
1562
1563 // Only set the output shape. Array contents are set by
1564 // ResolveConstantShapeOrRank.
1565 Shape* output_shape = output_array.mutable_shape();
1566 output_shape->ReplaceDims({});
1567 }
1568
ProcessShapeOperator(Model * model,TensorFlowShapeOperator * op)1569 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1570 CHECK_GE(op->inputs.size(), 1);
1571 CHECK_EQ(op->outputs.size(), 1);
1572 auto& output_array = model->GetArray(op->outputs[0]);
1573 if (output_array.has_shape()) {
1574 // Shape already propagated
1575 return;
1576 }
1577
1578 if (output_array.data_type == ArrayDataType::kNone) {
1579 // Yield until the output type has been set by PropagateArrayDataTypes
1580 return;
1581 }
1582
1583 const auto& input_array = model->GetArray(op->inputs[0]);
1584 if (!input_array.has_shape()) {
1585 // Yield until input dims have been resolved.
1586 return;
1587 }
1588
1589 // Only set the output shape. Array contents are set by
1590 // ResolveConstantShapeOrRank.
1591 Shape* output_shape = output_array.mutable_shape();
1592 output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1593 }
1594
ProcessPackOperator(Model * model,PackOperator * op)1595 void ProcessPackOperator(Model* model, PackOperator* op) {
1596 CHECK_GE(op->inputs.size(), 1);
1597 CHECK_EQ(op->outputs.size(), 1);
1598 auto& output_array = model->GetArray(op->outputs[0]);
1599 if (output_array.has_shape()) {
1600 // Shape already propagated
1601 return;
1602 }
1603
1604 std::unique_ptr<Shape> packed_shape;
1605 for (const auto& input : op->inputs) {
1606 const auto& input_array = model->GetArray(input);
1607 if (!input_array.has_shape()) {
1608 // Yield until all input dims have been resolved.
1609 return;
1610 }
1611
1612 Shape shape = input_array.shape();
1613 if (!packed_shape) {
1614 packed_shape.reset(new Shape(shape));
1615 } else {
1616 CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
1617 "must have the same shape. Input \""
1618 << input << "\" is different.";
1619 }
1620 }
1621
1622 int axis = op->axis;
1623 if (axis < 0) {
1624 // Handle negative axis
1625 axis += packed_shape->dims().size() + 1;
1626 }
1627 packed_shape->mutable_dims()->insert(
1628 packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
1629 output_array.copy_shape(*packed_shape);
1630 }
1631
ProcessStridedSliceOperator(Model * model,StridedSliceOperator * op)1632 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1633 CHECK_GE(op->inputs.size(), 1);
1634 CHECK_EQ(op->outputs.size(), 1);
1635 auto& output_array = model->GetArray(op->outputs[0]);
1636 if (output_array.has_shape()) {
1637 // Shape already propagated
1638 return;
1639 }
1640
1641 if (op->start_indices.empty() || op->stop_indices.empty() ||
1642 op->strides.empty()) {
1643 // ResolveStridedSliceAttributes has not run yet.
1644 return;
1645 }
1646
1647 const auto& input_array = model->GetArray(op->inputs[0]);
1648 if (!input_array.has_shape()) {
1649 // Yield until input dims have been resolved.
1650 return;
1651 }
1652
1653 if (op->ellipsis_mask != 0) {
1654 // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce
1655 // log noise. However, the TensorFlow logging library does not appear to
1656 // support this.
1657 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1658 << "\". ellipsis_mask is not supported (mask="
1659 << op->ellipsis_mask << ")";
1660 return;
1661 }
1662 if (op->new_axis_mask != 0) {
1663 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1664 << "\". new_axis_mask is not supported (mask="
1665 << op->new_axis_mask << ")";
1666 return;
1667 }
1668
1669 int num_input_axes = input_array.shape().dimensions_count();
1670 CHECK_LE(op->start_indices.size(), num_input_axes)
1671 << "StridedSlice op with output \"" << op->outputs[0]
1672 << "\", requires no more than " << num_input_axes << " start indices";
1673 CHECK_LE(op->stop_indices.size(), num_input_axes)
1674 << "StridedSlice op with output \"" << op->outputs[0]
1675 << "\", requires no more than " << num_input_axes << " stop indices";
1676 CHECK_LE(op->strides.size(), num_input_axes)
1677 << "StridedSlice op with output \"" << op->outputs[0]
1678 << "\", requires no more than " << num_input_axes << " strides";
1679 for (int i = 0; i < op->strides.size(); i++) {
1680 CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
1681 << " has stride=" << op->strides[i] << ".";
1682 }
1683
1684 // Create output shape
1685 std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1686
1687 // Compute output shape
1688 for (int axis = 0; axis < num_input_axes; ++axis) {
1689 const auto strided_slice_params =
1690 tflite::strided_slice::BuildStridedSliceParams(
1691 op->begin_mask, op->end_mask, op->shrink_axis_mask,
1692 op->start_indices, op->stop_indices, op->strides);
1693 int start_index = tflite::strided_slice::StartForAxis(
1694 strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
1695 int stop_index = tflite::strided_slice::StopForAxis(
1696 strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
1697 start_index);
1698
1699 int dim_size = std::ceil(static_cast<float>(stop_index - start_index) /
1700 op->strides[axis]);
1701
1702 CHECK_GT(dim_size, 0)
1703 << "Output size for an axis must be greater than 0. Axis " << axis
1704 << " computes to size " << dim_size
1705 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1706 if (op->shrink_axis_mask & (1 << axis)) {
1707 CHECK_EQ(dim_size, 1)
1708 << "Output size for an axis must compute to 1 when shrinking an "
1709 "axis. Axis "
1710 << axis << " computes to size " << dim_size
1711 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1712 } else {
1713 dims->push_back(dim_size);
1714 }
1715 }
1716 }
1717
ProcessSqueezeOperator(Model * model,SqueezeOperator * op)1718 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1719 CHECK_EQ(op->inputs.size(), 1);
1720 CHECK_EQ(op->outputs.size(), 1);
1721
1722 const auto& input_array = model->GetArray(op->inputs[0]);
1723
1724 // Yield until input dims have been resolved.
1725 if (!input_array.has_shape()) return;
1726
1727 auto& output_array = model->GetArray(op->outputs[0]);
1728 if (output_array.has_shape()) return;
1729
1730 const std::vector<int>& input_dims = input_array.shape().dims();
1731 std::vector<int> output_dims;
1732
1733 std::vector<int> squeeze_dims;
1734 const int input_num_dims = input_dims.size();
1735 for (int i : op->squeeze_dims) {
1736 squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
1737 }
1738 for (int i = 0; i < input_num_dims; ++i) {
1739 if (input_dims[i] != 1 ||
1740 (!squeeze_dims.empty() &&
1741 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
1742 squeeze_dims.end())) {
1743 output_dims.push_back(input_dims[i]);
1744 }
1745 }
1746 *output_array.mutable_shape()->mutable_dims() = output_dims;
1747 }
1748
ProcessSvdfOperator(Model * model,SvdfOperator * op)1749 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1750 CHECK(op->inputs.size() == 4 || op->inputs.size() == 5);
1751 const auto& input_array = model->GetArray(op->inputs[0]);
1752 if (!input_array.has_shape()) return;
1753
1754 auto& weights_feature_array = model->GetArray(op->inputs[1]);
1755 if (!weights_feature_array.has_shape()) return;
1756
1757 const auto& weights_time_array = model->GetArray(op->inputs[2]);
1758 if (!weights_time_array.has_shape()) return;
1759
1760 const bool has_bias = (op->inputs.size() == 5);
1761 if (has_bias) {
1762 const auto& bias_array = model->GetArray(op->inputs[3]);
1763 if (!bias_array.has_shape()) return;
1764 }
1765
1766 const int batch_size = input_array.shape().dims()[0];
1767 const int num_units = weights_feature_array.shape().dims()[0];
1768 const int memory_size = weights_time_array.shape().dims()[1];
1769
1770 auto& state_array = model->GetArray(op->outputs[0]);
1771 state_array.mutable_shape()->ReplaceDims(
1772 {batch_size, memory_size * num_units});
1773
1774 auto& output_array = model->GetArray(op->outputs[1]);
1775 output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1776 }
1777
ProcessTransposeOperator(Model * model,TransposeOperator * op)1778 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1779 auto& output_array = model->GetArray(op->outputs[0]);
1780 if (output_array.has_shape()) {
1781 // We have already run
1782 return;
1783 }
1784
1785 const auto& input_array = model->GetArray(op->inputs[0]);
1786 if (!input_array.has_shape()) {
1787 // Yield until input dims have been resolved.
1788 return;
1789 }
1790 const auto& input_shape = input_array.shape();
1791
1792 auto& perm_array = model->GetArray(op->inputs[1]);
1793 if (!perm_array.has_shape()) {
1794 // Yield until permutation shape been resolved.
1795 return;
1796 }
1797 if (!perm_array.buffer) {
1798 // Yield until the permutation is constant
1799 return;
1800 }
1801 CHECK(perm_array.data_type == ArrayDataType::kInt32)
1802 << "Transpose permutation input must be int32";
1803
1804 std::vector<int32> const& perm =
1805 perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1806 CHECK_EQ(perm.size(), input_shape.dimensions_count())
1807 << "Transpose permutation input " << op->inputs[1]
1808 << " must be same length as input dimensions";
1809 std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1810 for (int i = 0; i < perm.size(); i++) {
1811 int axis = perm[i];
1812 CHECK_GE(axis, 0);
1813 CHECK_LT(axis, input_shape.dimensions_count());
1814 output_dims->push_back(input_shape.dims(axis));
1815 }
1816 }
1817
1818 template <typename Op>
ProcessArgMinMaxOperator(Model * model,Op * op)1819 void ProcessArgMinMaxOperator(Model* model, Op* op) {
1820 CHECK_EQ(op->inputs.size(), 2);
1821 const auto& input_array = model->GetArray(op->inputs[0]);
1822 // Yield until input dims have been resolved.
1823 if (!input_array.has_shape()) {
1824 return;
1825 }
1826
1827 const Array& axis_array = model->GetArray(op->inputs[1]);
1828 // Yield until input axis array shape has been resolved.
1829 if (!axis_array.has_shape()) {
1830 return;
1831 }
1832
1833 const std::vector<int>& input_dims = input_array.shape().dims();
1834
1835 CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
1836 axis_array.data_type == ArrayDataType::kInt64)
1837 << "axis_array must be int32, int64";
1838
1839 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
1840 << "Axis array must be scalar.";
1841
1842 int64 axis;
1843 if (axis_array.data_type == ArrayDataType::kInt32) {
1844 axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1845 } else {
1846 axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
1847 }
1848
1849 std::vector<int> output_dims;
1850
1851 output_dims.reserve(input_dims.size() - 1);
1852 for (int i = 0; i < input_dims.size(); ++i) {
1853 if (i != axis) {
1854 output_dims.push_back(input_dims[i]);
1855 }
1856 }
1857
1858 const string& output_name = op->outputs[0];
1859 auto& output_array = model->GetArray(output_name);
1860 if (output_array.has_shape()) {
1861 return;
1862 }
1863 *output_array.mutable_shape()->mutable_dims() = output_dims;
1864 }
1865
ProcessSparseToDenseOperator(Model * model,SparseToDenseOperator * op)1866 void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
1867 CHECK_EQ(op->inputs.size(), 4);
1868
1869 const Array& output_shape_array = model->GetArray(op->inputs[1]);
1870 if (!output_shape_array.has_shape()) return;
1871 CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
1872
1873 // Output should not go over four dimensions.
1874 CHECK_LE(output_shape_array.shape().dims(0), 4);
1875
1876 const string& output_name = op->outputs[0];
1877 Array& output_array = model->GetArray(output_name);
1878 if (output_array.has_shape()) return;
1879
1880 CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
1881 output_shape_array.data_type == ArrayDataType::kInt64);
1882 if (output_shape_array.data_type == ArrayDataType::kInt32) {
1883 *output_array.mutable_shape()->mutable_dims() =
1884 output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1885 } else {
1886 const std::vector<int64>& output_shape_data =
1887 output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
1888 std::copy(
1889 output_shape_data.begin(), output_shape_data.end(),
1890 std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
1891 }
1892 }
1893
ProcessTileOperator(Model * model,TensorFlowTileOperator * op)1894 void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
1895 CHECK_EQ(op->inputs.size(), 2);
1896 CHECK_EQ(op->outputs.size(), 1);
1897
1898 auto& output_array = model->GetArray(op->outputs[0]);
1899 if (output_array.has_shape()) {
1900 // We have already run.
1901 return;
1902 }
1903
1904 const auto& input_array = model->GetArray(op->inputs[0]);
1905 if (!input_array.has_shape()) {
1906 // Yield until input dims have been resolved.
1907 return;
1908 }
1909 const auto& input_shape = input_array.shape();
1910
1911 auto& multiples_array = model->GetArray(op->inputs[1]);
1912 if (!multiples_array.has_shape()) {
1913 // Yield until multiples shape been resolved.
1914 return;
1915 }
1916 if (!multiples_array.buffer) {
1917 // Yield until the multiples is constant.
1918 return;
1919 }
1920 CHECK(multiples_array.data_type == ArrayDataType::kInt32)
1921 << "Tile multiples input must be int32";
1922
1923 std::vector<int32> const& multiples =
1924 multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
1925 CHECK_EQ(multiples.size(), input_shape.dimensions_count())
1926 << "Tile multiples input " << op->inputs[1]
1927 << " must be same length as input dimensions";
1928
1929 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1930 mutable_dims->resize(multiples.size());
1931 for (int i = 0; i < mutable_dims->size(); ++i) {
1932 (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
1933 }
1934 }
1935
ProcessOneHotOperator(Model * model,OneHotOperator * op)1936 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
1937 CHECK_EQ(op->inputs.size(), 4);
1938 CHECK_EQ(op->outputs.size(), 1);
1939 auto& output_array = model->GetArray(op->outputs[0]);
1940 if (output_array.has_shape()) {
1941 // Shape already propagated
1942 return;
1943 }
1944
1945 // Yield until indices dims have been resolved.
1946 const auto& indices_array =
1947 model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
1948 if (!indices_array.has_shape()) {
1949 return;
1950 }
1951
1952 // Yield until depth is constant and dims have been resolved.
1953 if (!IsConstantParameterArray(*model,
1954 op->inputs[OneHotOperator::DEPTH_INPUT])) {
1955 return;
1956 }
1957 const auto& depth_array =
1958 model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
1959 if (!depth_array.has_shape()) {
1960 return;
1961 }
1962
1963 CHECK(depth_array.data_type == ArrayDataType::kInt32)
1964 << "Depth array must be int32.";
1965 CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
1966 << "Depth array must be scalar.";
1967
1968 const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1969 CHECK_GE(depth, 0) << "Depth must be non-negative.";
1970
1971 const int indices_dims = indices_array.shape().dimensions_count();
1972 const int output_dims = indices_dims + 1;
1973 const int axis = op->axis == -1 ? indices_dims : op->axis;
1974 CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
1975
1976 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1977 mutable_dims->resize(output_dims);
1978 for (int i = 0; i < output_dims; ++i) {
1979 int dim = 0;
1980 if (i < axis) {
1981 dim = indices_array.shape().dims(i);
1982 } else if (i == axis) {
1983 dim = depth;
1984 } else {
1985 dim = indices_array.shape().dims(i - 1);
1986 }
1987 (*mutable_dims)[i] = dim;
1988 }
1989 }
1990
ProcessUnpackOperator(Model * model,UnpackOperator * op)1991 void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
1992 CHECK_EQ(op->inputs.size(), 1);
1993 const auto& input_array = model->GetArray(op->inputs[0]);
1994 // Yield until input dims have been resolved.
1995 if (!input_array.has_shape()) {
1996 return;
1997 }
1998
1999 const std::vector<int>& input_dims = input_array.shape().dims();
2000 std::vector<int> output_dims;
2001
2002 output_dims.reserve(input_dims.size() - 1);
2003 for (int i = 0; i < input_dims.size(); ++i) {
2004 if (i != op->axis) {
2005 output_dims.push_back(input_dims[i]);
2006 }
2007 }
2008 for (const string& output_name : op->outputs) {
2009 auto& output_array = model->GetArray(output_name);
2010 if (output_array.has_shape()) {
2011 return;
2012 }
2013 *output_array.mutable_shape()->mutable_dims() = output_dims;
2014 }
2015 }
2016
ProcessMirrorPadOperator(Model * model,MirrorPadOperator * op)2017 void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
2018 CHECK_EQ(op->inputs.size(), 2);
2019 const auto& input_array = model->GetArray(op->inputs[0]);
2020 const auto& padding_matrix = model->GetArray(op->inputs[1]);
2021
2022 // Yield until input dims have been resolved.
2023 if (!input_array.has_shape()) {
2024 return;
2025 }
2026
2027 auto& output_array = model->GetArray(op->outputs[0]);
2028 // If output already computed or padding matrix is non
2029 // const then return.
2030 if (output_array.has_shape() ||
2031 !IsConstantParameterArray(*model, op->inputs[1])) {
2032 return;
2033 }
2034 Shape output_shape = input_array.shape();
2035 std::vector<int>& dims = *output_shape.mutable_dims();
2036
2037 std::vector<int64_t> padding;
2038 if (padding_matrix.data_type == ArrayDataType::kInt32) {
2039 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
2040 for (auto elem : data) {
2041 padding.push_back(static_cast<int64_t>(elem));
2042 }
2043 } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
2044 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
2045 for (auto elem : data) {
2046 padding.push_back(elem);
2047 }
2048 } else {
2049 CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
2050 padding_matrix.data_type == ArrayDataType::kInt32);
2051 }
2052 CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
2053 CHECK_EQ(input_array.shape().dimensions_count(),
2054 padding_matrix.shape().dims(0));
2055 for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
2056 dims[i] += padding[i * 2] + padding[i * 2 + 1];
2057 }
2058
2059 output_array.copy_shape(output_shape);
2060 }
2061
ProcessUniqueOperator(Model * model,UniqueOperator * op)2062 void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
2063 const auto& input_array = model->GetArray(op->inputs[0]);
2064 // We have 2 outputs, the shape of the index tensor, is the same size
2065 // as the input array. The unique values tensor, is unknown until runtime.
2066 CHECK_EQ(op->outputs.size(), 2);
2067 auto& idx_output_array = model->GetArray(op->outputs[1]);
2068
2069 // Yield until input dims have been resolved, or output already computed
2070 if (!input_array.has_shape() || idx_output_array.has_shape()) {
2071 return;
2072 }
2073 idx_output_array.copy_shape(input_array.shape());
2074 }
2075
ProcessMatrixDiagOperator(Model * model,MatrixDiagOperator * op)2076 void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
2077 CHECK_EQ(op->inputs.size(), 1);
2078 CHECK_EQ(op->outputs.size(), 1);
2079 auto& input_array = model->GetArray(op->inputs[0]);
2080 auto& output_array = model->GetArray(op->outputs[0]);
2081 // The input array must have a shape in order to proceed. Also,
2082 // bail out if the output shape has already been calculated.
2083 if (!input_array.has_shape() || output_array.has_shape()) {
2084 // We have already run
2085 return;
2086 }
2087 // Get the input_shape
2088 Shape* mutable_shape = input_array.mutable_shape();
2089 std::vector<int>* dims = mutable_shape->mutable_dims();
2090 int dims_size = dims->size();
2091 // Scalars are not allowed.
2092 CHECK_GT(dims_size, 0);
2093 int last_dim = (*dims)[dims_size - 1];
2094 dims->push_back(last_dim);
2095 output_array.copy_shape(*mutable_shape);
2096 }
2097
ProcessMatrixSetDiagOperator(Model * model,MatrixSetDiagOperator * op)2098 void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
2099 CHECK_EQ(op->inputs.size(), 2);
2100 CHECK_EQ(op->outputs.size(), 1);
2101 auto& input_array = model->GetArray(op->inputs[0]);
2102 auto& output_array = model->GetArray(op->outputs[0]);
2103 // The shape of the input array must be known because that will
2104 // be the shape of the output array.
2105 if (!input_array.has_shape() || !output_array.has_shape()) {
2106 // We have already run
2107 return;
2108 }
2109
2110 output_array.copy_shape(input_array.shape());
2111 }
2112
2113 } // namespace
2114
Run(Model * model,std::size_t op_index,bool * modified)2115 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
2116 std::size_t op_index,
2117 bool* modified) {
2118 *modified = false;
2119 auto it = model->operators.begin() + op_index;
2120 auto* op = it->get();
2121 std::unordered_map<string, std::vector<int>> old_output_dims;
2122 for (const auto& output : op->outputs) {
2123 if (model->GetArray(output).has_shape()) {
2124 old_output_dims[output] = model->GetArray(output).shape().dims();
2125 }
2126 }
2127
2128 switch (op->type) {
2129 case OperatorType::kAbs:
2130 case OperatorType::kBatchNormalization:
2131 case OperatorType::kL2Normalization:
2132 case OperatorType::kDequantize:
2133 case OperatorType::kElu:
2134 case OperatorType::kHardSwish:
2135 case OperatorType::kRelu:
2136 case OperatorType::kRelu1:
2137 case OperatorType::kRelu6:
2138 case OperatorType::kPRelu:
2139 case OperatorType::kLeakyRelu:
2140 case OperatorType::kSoftmax:
2141 case OperatorType::kLogSoftmax:
2142 case OperatorType::kLog:
2143 case OperatorType::kLogistic:
2144 case OperatorType::kTanh:
2145 case OperatorType::kLocalResponseNormalization:
2146 case OperatorType::kIdentity:
2147 case OperatorType::kFakeQuant:
2148 case OperatorType::kNeg:
2149 case OperatorType::kRsqrt:
2150 case OperatorType::kSqrt:
2151 case OperatorType::kSquare:
2152 case OperatorType::kAll:
2153 case OperatorType::kAssert:
2154 case OperatorType::kCast:
2155 case OperatorType::kFloor:
2156 case OperatorType::kCeil:
2157 case OperatorType::kRound:
2158 case OperatorType::kExp:
2159 case OperatorType::kSin:
2160 case OperatorType::kCos:
2161 case OperatorType::kLogicalAnd:
2162 case OperatorType::kLogicalNot:
2163 case OperatorType::kLogicalOr:
2164 case OperatorType::kZerosLike:
2165 case OperatorType::kReverseV2:
2166 case OperatorType::kReverseSequence:
2167 ProcessSimpleOperator(model, op, 0);
2168 break;
2169 case OperatorType::kGather:
2170 ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
2171 break;
2172 case OperatorType::kGatherNd:
2173 ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
2174 break;
2175 case OperatorType::kTopK_V2:
2176 ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
2177 break;
2178 case OperatorType::kAdd:
2179 case OperatorType::kSub:
2180 case OperatorType::kMul:
2181 case OperatorType::kDiv:
2182 case OperatorType::kFloorDiv:
2183 case OperatorType::kFloorMod:
2184 case OperatorType::kLess:
2185 case OperatorType::kLessEqual:
2186 case OperatorType::kGreater:
2187 case OperatorType::kMaximum: // Element-wise Maximum
2188 case OperatorType::kMinimum: // Element-wise Minimum
2189 case OperatorType::kGreaterEqual:
2190 case OperatorType::kEqual:
2191 case OperatorType::kNotEqual:
2192 case OperatorType::kPow:
2193 case OperatorType::kSquaredDifference:
2194 ProcessSimpleBinaryOperator(model, op);
2195 break;
2196 case OperatorType::kAddN:
2197 ProcessAddNOperator(model, op);
2198 break;
2199 case OperatorType::kConv:
2200 ProcessConvOperator(model, static_cast<ConvOperator*>(op));
2201 break;
2202 case OperatorType::kTransposeConv:
2203 ProcessTransposeConvOperator(model,
2204 static_cast<TransposeConvOperator*>(op));
2205 break;
2206 case OperatorType::kDepthwiseConv:
2207 ProcessDepthwiseConvOperator(model,
2208 static_cast<DepthwiseConvOperator*>(op));
2209 break;
2210 case OperatorType::kDepthToSpace:
2211 ProcessDepthToSpaceOperator(model,
2212 static_cast<DepthToSpaceOperator*>(op));
2213 break;
2214 case OperatorType::kSpaceToDepth:
2215 ProcessSpaceToDepthOperator(model,
2216 static_cast<SpaceToDepthOperator*>(op));
2217 break;
2218 case OperatorType::kFill:
2219 CHECK_EQ(op->inputs.size(), 2);
2220 ProcessOpWithShapeInput(model, op);
2221 break;
2222 case OperatorType::kFullyConnected:
2223 ProcessFullyConnectedOperator(model,
2224 static_cast<FullyConnectedOperator*>(op));
2225 break;
2226 case OperatorType::kReshape:
2227 ProcessTensorFlowReshapeOperator(
2228 model, static_cast<TensorFlowReshapeOperator*>(op));
2229 break;
2230 case OperatorType::kAveragePool:
2231 ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
2232 break;
2233 case OperatorType::kMaxPool:
2234 ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
2235 break;
2236 case OperatorType::kL2Pool:
2237 ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
2238 break;
2239 case OperatorType::kReduceMin: // Reduction Min
2240 case OperatorType::kReduceMax: // Reduction Max
2241 case OperatorType::kSum:
2242 case OperatorType::kReduceProd:
2243 case OperatorType::kMean:
2244 case OperatorType::kAny:
2245 ProcessTensorFlowReductionOperator(model, op);
2246 break;
2247 case OperatorType::kSelect:
2248 ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
2249 break;
2250 case OperatorType::kSlice:
2251 ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
2252 break;
2253
2254 case OperatorType::kSwitch:
2255 // We can't know the sizes of the outputs until we have resolved the
2256 // predicate, and once we have resolved the predicate, the whole
2257 // Switch node will get resolved away.
2258 // See ResolveTensorFlowSwitch.
2259 break;
2260 case OperatorType::kMerge:
2261 // No need to bother resolving TensorFlow Merge ops: other graph
2262 // transformations will remove them anyway.
2263 // See ResolveTensorFlowMerge.
2264 break;
2265 case OperatorType::kSplit:
2266 ProcessTensorFlowSplitOperator(model,
2267 static_cast<TensorFlowSplitOperator*>(op));
2268 break;
2269 case OperatorType::kSplitV:
2270 ProcessTensorFlowSplitVOperator(
2271 model, static_cast<TensorFlowSplitVOperator*>(op));
2272 break;
2273 case OperatorType::kSqueeze:
2274 ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
2275 break;
2276 case OperatorType::kConcat:
2277 case OperatorType::kConcatV2:
2278 // Unimplemented, hopefully another graph transformation will
2279 // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
2280 // will resolve this node to a DepthConcatenation, or else we have
2281 // a more general non-depth concatenation that will hopefully be dropped,
2282 // or else at the moment we will abort.
2283 break;
2284 case OperatorType::kExpandDims:
2285 // Yield until ExpandDims is converted to Reshape
2286 break;
2287 case OperatorType::kRange:
2288 ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
2289 break;
2290 case OperatorType::kRank:
2291 ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
2292 break;
2293 case OperatorType::kShape:
2294 ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
2295 break;
2296 case OperatorType::kPack:
2297 ProcessPackOperator(model, static_cast<PackOperator*>(op));
2298 break;
2299 case OperatorType::kReorderAxes:
2300 ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
2301 break;
2302 case OperatorType::kConcatenation:
2303 ProcessConcatenationOperator(model,
2304 static_cast<ConcatenationOperator*>(op));
2305 break;
2306 case OperatorType::kResizeBilinear:
2307 ProcessResizeBilinearOperator(model,
2308 static_cast<ResizeBilinearOperator*>(op));
2309 break;
2310 case OperatorType::kResizeNearestNeighbor:
2311 ProcessResizeNearestNeighborOperator(
2312 model, static_cast<ResizeNearestNeighborOperator*>(op));
2313 break;
2314 case OperatorType::kUnidirectionalSequenceLstm:
2315 ProcessUnidirectionalSequenceLstmOperator(
2316 model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
2317 break;
2318 case OperatorType::kUnidirectionalSequenceRnn:
2319 ProcessUnidirectionalSequenceRnnOperator(
2320 model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
2321 break;
2322 case OperatorType::kBidirectionalSequenceLstm:
2323 ProcessBidirectionalSequenceLstmOperator(
2324 model, static_cast<BidirectionalSequenceLstmOperator*>(op));
2325 break;
2326 case OperatorType::kBidirectionalSequenceRnn:
2327 ProcessBidirectionalSequenceRnnOperator(
2328 model, static_cast<BidirectionalSequenceRnnOperator*>(op));
2329 break;
2330 case OperatorType::kLstmCell:
2331 ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
2332 break;
2333 case OperatorType::kBatchMatMul:
2334 case OperatorType::kMatMul:
2335 // MatMul operators are converted to FullyConnected, after which their
2336 // shapes are propagated.
2337 break;
2338 case OperatorType::kSpaceToBatchND:
2339 ProcessSpaceToBatchNDOperator(model,
2340 static_cast<SpaceToBatchNDOperator*>(op));
2341 break;
2342 case OperatorType::kBatchToSpaceND:
2343 ProcessBatchToSpaceNDOperator(model,
2344 static_cast<BatchToSpaceNDOperator*>(op));
2345 break;
2346 case OperatorType::kPad:
2347 ProcessPadOperator(model, static_cast<PadOperator*>(op));
2348 break;
2349 case OperatorType::kPadV2:
2350 ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
2351 break;
2352 case OperatorType::kStridedSlice:
2353 ProcessStridedSliceOperator(model,
2354 static_cast<StridedSliceOperator*>(op));
2355 break;
2356 case OperatorType::kArgMax:
2357 ProcessArgMinMaxOperator<ArgMaxOperator>(
2358 model, static_cast<ArgMaxOperator*>(op));
2359 break;
2360 case OperatorType::kArgMin:
2361 ProcessArgMinMaxOperator<ArgMinOperator>(
2362 model, static_cast<ArgMinOperator*>(op));
2363 break;
2364 case OperatorType::kUnsupported: {
2365 const auto* unsupported_op =
2366 static_cast<TensorFlowUnsupportedOperator*>(op);
2367 // Attribute can be not specified, ignore it.
2368 if (unsupported_op->output_shapes.size() < op->outputs.size()) {
2369 return ::tensorflow::Status::OK();
2370 }
2371 for (int i = 0; i < op->outputs.size(); ++i) {
2372 const string& output = op->outputs[i];
2373 model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
2374 }
2375 break;
2376 }
2377 case OperatorType::kSvdf:
2378 ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
2379 break;
2380 case OperatorType::kTranspose:
2381 ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
2382 break;
2383 case OperatorType::kDynamicPartition:
2384 case OperatorType::kDynamicStitch:
2385 // DynamicPartition/DynamicStitch are currently only supported for
2386 // transforms that remove them, so we avoid propagating shapes through
2387 // them and let things settle once they've been removed.
2388 break;
2389 case OperatorType::kRandomUniform:
2390 CHECK_EQ(op->inputs.size(), 1);
2391 ProcessOpWithShapeInput(model, op);
2392 break;
2393 case OperatorType::kSparseToDense:
2394 ProcessSparseToDenseOperator(model,
2395 static_cast<SparseToDenseOperator*>(op));
2396 break;
2397 case OperatorType::kTile:
2398 ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
2399 break;
2400 break;
2401 case OperatorType::kOneHot:
2402 ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
2403 break;
2404 case OperatorType::kUnpack:
2405 ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
2406 break;
2407 case OperatorType::kMirrorPad:
2408 ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
2409 break;
2410 case OperatorType::kUnique:
2411 ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
2412 break;
2413 case OperatorType::kWhere:
2414 // The size of the output can only be known after evaluating the cond
2415 // tensor. Ignore shape propagation here and defer that to the
2416 // interpreter.
2417 break;
2418 case OperatorType::kMatrixDiag:
2419 ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
2420 break;
2421 case OperatorType::kMatrixSetDiag:
2422 ProcessMatrixSetDiagOperator(model,
2423 static_cast<MatrixSetDiagOperator*>(op));
2424 break;
2425 case OperatorType::kCTCBeamSearchDecoder:
2426 // The sizes of the outputs are only known in runtime based on the input.
2427 // Ignore shape progapation here and defer that to the interpreter.
2428 break;
2429 case OperatorType::kMatrixSetDiagV2:
2430 // MatrixSetDiagV2 operators are converted to MatrixSetDiag,
2431 // after which their shapes are propagated.
2432 break;
2433 case OperatorType::kMatrixDiagV2:
2434 // MatrixDiagV2 operators are converted to MatrixDiag, after which their
2435 // shapes are propagated.
2436 break;
2437 case OperatorType::kMatrixDiagV3:
2438 // MatrixDiagV3 operators are converted to MatrixDiag, after which their
2439 // shapes are propagated.
2440 break;
2441 case OperatorType::kMatrixSetDiagV3:
2442 // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
2443 // their shapes are propagated.
2444 break;
2445 case OperatorType::kSegmentSum:
2446 break;
2447 default:
2448 // Unimplemented, another graph transformation should drop it.
2449 LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
2450 }
2451
2452 // Return true if any output dim changed, false if none changed.
2453 // Assumption: no transformation clears an output shape, they only add shapes.
2454 for (const auto& output : op->outputs) {
2455 if (model->GetArray(output).has_shape() &&
2456 (old_output_dims[output] != model->GetArray(output).shape().dims())) {
2457 AddMessageF("Set shape of %s to [%s]", output,
2458 absl::StrJoin(model->GetArray(output).shape().dims(), ","));
2459 *modified = true;
2460 return ::tensorflow::Status::OK();
2461 }
2462 }
2463 return ::tensorflow::Status::OK();
2464 }
2465
2466 } // namespace toco
2467