• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <cstring>
18 #include <memory>
19 
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace reshape {
29 
30 constexpr int kInputTensor = 0;  // Node input index of the tensor whose data is reshaped.
31 constexpr int kShapeTensor = 1;  // Node input index of the shape tensor; only present when the node has two inputs.
32 constexpr int kOutputTensor = 0;  // Node output index of the reshaped tensor.
33 
34 TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*);  // Forward declaration: defined further down, but called by ResizeOutput() first.
35 
ResizeOutput(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
37   TfLiteIntArray* output_shape = GetOutputShape(context, node);
38   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
39       scoped_output_shape(output_shape, TfLiteIntArrayFree);
40 
41   const TfLiteTensor* input;
42   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
43   TfLiteTensor* output;
44   TF_LITE_ENSURE_OK(context,
45                     GetOutputSafe(context, node, kOutputTensor, &output));
46 
47   // Tensorflow's Reshape allows one of the shape components to have the
48   // special -1 value, meaning it will be calculated automatically based on the
49   // input. Here we calculate what that dimension should be so that the number
50   // of output elements is the same as the number of input elements.
51   int64_t non_zero_num_input_elements = 1, num_input_elements = 1;
52   const RuntimeShape& input_shape = GetTensorShape(input);
53   for (int i = 0; i < input_shape.DimensionsCount(); ++i) {
54     const int value = input_shape.Dims(i);
55     num_input_elements *= value;
56     if (value != 0) {
57       non_zero_num_input_elements *= value;
58     }
59   }
60 
61   int64_t non_zero_num_output_elements = 1, num_output_elements = 1;
62   int stretch_dim = -1;
63   for (int i = 0; i < output_shape->size; ++i) {
64     const int value = output_shape->data[i];
65     if (value == -1) {
66       TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
67       stretch_dim = i;
68       continue;
69     } else if (value != 0) {
70       non_zero_num_output_elements *= value;
71     }
72     num_output_elements *= value;
73   }
74 
75   if (stretch_dim != -1) {
76     if (num_input_elements == 0 && num_output_elements != 0) {
77       output_shape->data[stretch_dim] = 0;
78     } else {
79       output_shape->data[stretch_dim] =
80           non_zero_num_input_elements / non_zero_num_output_elements;
81     }
82     num_output_elements *= output_shape->data[stretch_dim];
83   }
84 
85   TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
86   return context->ResizeTensor(context, output, scoped_output_shape.release());
87 }
88 
GetOutputShapeFromTensor(TfLiteContext * context,TfLiteNode * node)89 inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
90                                                 TfLiteNode* node) {
91   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
92   if (shape == nullptr) return nullptr;
93 
94   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
95   for (int i = 0; i < output_shape->size; ++i) {
96     output_shape->data[i] = shape->data.i32[i];
97   }
98 
99   return output_shape;
100 }
101 
GetOutputShapeFromParam(TfLiteContext * context,TfLiteNode * node)102 inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
103                                                TfLiteNode* node) {
104   auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
105 
106   // The function is returned above this line if the shape tensor is usable.
107   // Now fallback to the shape parameter in `TfLiteReshapeParams`.
108   int num_dimensions = params->num_dimensions;
109   if (num_dimensions == 1 && params->shape[0] == 0) {
110     // Legacy tflite models use a shape parameter of [0] to indicate scalars,
111     // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
112     // toco conversion.
113     num_dimensions = 0;
114   }
115   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
116   for (int i = 0; i < num_dimensions; ++i) {
117     output_shape->data[i] = params->shape[i];
118   }
119 
120   return output_shape;
121 }
122 
123 // Check if the shape tensor is valid. Shapes should be int32 vectors.
ShapeIsVector(TfLiteContext * context,TfLiteNode * node)124 inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
125   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
126   return (shape != nullptr && shape->dims->size == 1 &&
127           shape->type == kTfLiteInt32);
128 }
129 
GetOutputShape(TfLiteContext * context,TfLiteNode * node)130 TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) {
131   if (NumInputs(node) == 2 && ShapeIsVector(context, node)) {
132     return GetOutputShapeFromTensor(context, node);
133   } else {
134     return GetOutputShapeFromParam(context, node);
135   }
136 }
137 
Prepare(TfLiteContext * context,TfLiteNode * node)138 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
139   TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
140   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
141 
142   // Always postpone sizing string tensors, even if we could in principle
143   // calculate their shapes now. String tensors don't benefit from having their
144   // shapes precalculated because the actual memory can only be allocated after
145   // we know all the content.
146   TfLiteTensor* output;
147   TF_LITE_ENSURE_OK(context,
148                     GetOutputSafe(context, node, kOutputTensor, &output));
149   if (output->type != kTfLiteString) {
150     if (NumInputs(node) == 1 ||
151         IsConstantTensor(GetInput(context, node, kShapeTensor))) {
152       TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
153     } else {
154       SetTensorToDynamic(output);
155     }
156   }
157   return kTfLiteOk;
158 }
159 
Eval(TfLiteContext * context,TfLiteNode * node)160 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
161   const TfLiteTensor* input;
162   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
163   TfLiteTensor* output;
164   TF_LITE_ENSURE_OK(context,
165                     GetOutputSafe(context, node, kOutputTensor, &output));
166 
167   // There are two ways in which the 'output' can be made dynamic: it could be
168   // a string tensor, or its shape cannot be calculated during Prepare(). In
169   // either case, we now have all the information to calculate its shape.
170   if (IsDynamicTensor(output)) {
171     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
172   }
173 
174   // Note that string tensors are always "dynamic" in the sense that their size
175   // is not known until we have all the content. This applies even when their
176   // shape is known ahead of time. As a result, a string tensor is never given
177   // any memory by ResizeOutput(), and we need to do it manually here. Since
178   // reshape doesn't change the data, the output tensor needs exactly as many
179   // bytes as the input tensor.
180   if (output->type == kTfLiteString) {
181     auto bytes_required = input->bytes;
182     TfLiteTensorRealloc(bytes_required, output);
183     output->bytes = bytes_required;
184   }
185 
186   memcpy(output->data.raw, input->data.raw, input->bytes);
187 
188   return kTfLiteOk;
189 }
190 
191 }  // namespace reshape
192 
Register_RESHAPE()193 TfLiteRegistration* Register_RESHAPE() {
194   static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
195                                  reshape::Eval};
196   return &r;
197 }
198 
199 }  // namespace builtin
200 }  // namespace ops
201 }  // namespace tflite
202