/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstring>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace while_kernel {

namespace {

// Propagate tensor shapes and types from tensors at `src_tensor_indices` in
// `src_subgraph` to tensors at `dst_tensor_indices` in `dst_subgraph`.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
                                     Subgraph* src_subgraph,
                                     const SrcVector& src_tensor_indices,
                                     Subgraph* dst_subgraph,
                                     const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    std::vector<int> dims(src_tensor->dims->data,
                          src_tensor->dims->data + src_tensor->dims->size);
    TF_LITE_ENSURE_OK(context, dst_subgraph->ResizeInputTensor(
                                   dst_tensor_indices[i], dims));
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    dst_tensor->type = src_tensor->type;
  }
  return kTfLiteOk;
}

// Copy the tensor data from tensors at `src_tensor_indices` in `src_subgraph`
// to tensors at `dst_tensor_indices` in `dst_subgraph`.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph,
                             const SrcVector& src_tensor_indices,
                             Subgraph* dst_subgraph,
                             const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
    memcpy(dst_tensor->data.raw, src_tensor->data.raw, src_tensor->bytes);
  }
  return kTfLiteOk;
}

TfLiteStatus CheckCondOutput(TfLiteContext* context,
                             const TfLiteTensor* cond_output) {
  // The condition output must be a single boolean value.
  TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool);
  if (cond_output->dims->size == 0) {
    // It's okay if it's a 0-D scalar.
    return kTfLiteOk;
  }
  // Otherwise it must be 1-D with shape [1].
  TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
  return kTfLiteOk;
}
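
// For illustration, the contract enforced by CheckCondOutput above (a sketch;
// dims arrays are written inline for brevity):
//
//   kTfLiteBool tensor with dims {}    -> OK (0-D scalar)
//   kTfLiteBool tensor with dims {1}   -> OK (1-D, shape [1])
//   kTfLiteBool tensor with dims {2}   -> error (more than one element)
//   kTfLiteFloat32 tensor, any shape   -> error (not boolean)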

}  // namespace

struct OpData {
  int cond_subgraph_index;
  int body_subgraph_index;
  bool cond_has_dynamic_output_tensors;
  bool body_has_dynamic_output_tensors;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
  op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
  op_data->cond_has_dynamic_output_tensors = false;
  op_data->body_has_dynamic_output_tensors = false;
  return op_data;
}
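
// For reference, a sketch of how the custom options consumed by Init could be
// built with the FlexBuffers API (in practice this buffer is produced by the
// TensorFlow Lite converter; the subgraph indices 1 and 2 below are purely
// illustrative):
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("cond_subgraph_index", 1);  // illustrative index
//     fbb.Int("body_subgraph_index", 2);  // illustrative index
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() yields the bytes passed to Init as `buffer`/`length`.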

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as the number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);

  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());

  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Check input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);

  // Check input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);

  // Prepare and check the condition subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(context, this_subgraph,
                                       TfLiteIntArrayView(node->inputs),
                                       cond_subgraph, cond_subgraph->inputs()));
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // TODO(ycling): Handle the case where the cond subgraph has dynamic tensor
  // outputs. This should rarely happen. In most cases the output is static
  // with shape [1]. However, theoretically intermediate tensors in the cond
  // subgraph can be dynamic.
  if (IsDynamicTensor(cond_output)) {
    op_data->cond_has_dynamic_output_tensors = true;
  } else {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Prepare and check the body subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(context, this_subgraph,
                                       TfLiteIntArrayView(node->inputs),
                                       body_subgraph, body_subgraph->inputs()));
  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  if (body_subgraph->HasDynamicTensors()) {
    op_data->body_has_dynamic_output_tensors = true;
  } else {
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type);

      // TODO(ycling): Support dynamic sized body subgraph.
      TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
      if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) {
        // If the output shape of the body subgraph is static w.r.t. a fixed
        // input shape but differs from that input shape, the output is still
        // considered dynamic. For example: if a subgraph keeps padding its
        // input with a fixed padding, the output shape is static w.r.t. the
        // input shape and padding, but running it in a loop will keep growing
        // the tensor.
        op_data->body_has_dynamic_output_tensors = true;
        break;
      }
    }
  }
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* output = GetOutput(context, node, i);
    if (op_data->body_has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  //   This Subgraph         Cond Subgraph         Body Subgraph
  //   +-----------+   (1)   +------------+   (3)  +------------+
  //   |   WHILE   |-------->|  SUBGRAPH  |------->|  SUBGRAPH  |
  //   |   INPUT   |      /--|   INPUT    |<----\  |   INPUT    |
  //   +-----------+     /   +------------+      \ +------------+
  //                    /          |              \       |
  //               (6) /           | (2)       (5) \      | (4)
  //                  /            v                \     v
  //   +-----------+ /       +------------+        +------------+
  //   |   WHILE   |<        |  SUBGRAPH  |        |  SUBGRAPH  |
  //   |   OUTPUT  |         |   OUTPUT   |        |   OUTPUT   |
  //   +-----------+         +------------+        +------------+
  //
  // (1) Copy the inputs of the WHILE op to the inputs of the condition
  //     subgraph.
  // (2) Invoke the condition subgraph.
  //     Jump to step 6 if the result is false.
  // (3) Copy the inputs of the condition subgraph to the inputs of the body
  //     subgraph.
  // (4) Invoke the body subgraph.
  // (5) Copy the outputs of the body subgraph to the inputs of the condition
  //     subgraph. Jump back to step 2!
  // (6) Copy the inputs of the condition subgraph to the outputs of the WHILE
  //     op.
  //
  // If the body subgraph has dynamically sized outputs, the tensors need to be
  // resized before the copies in steps 1, 3, 5 and 6 (the resize for step 1 is
  // done in Prepare).
  //
  // Note the flow is carefully designed to handle the dynamically sized output
  // case. The loop invariant is: the newest values are in the inputs of the
  // condition subgraph. This is always true before step 2.
  //
  // This is the best we can do without sharing tensor buffers across subgraph
  // boundaries. Currently we copy the inputs / outputs between the subgraphs.
  // This isn't optimized yet and a lot of redundant copies are made.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
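  //
  // Restated as pseudocode (a summary of the code below, not additional
  // behavior; `copy` stands for CopyTensorsData plus, in the dynamic case,
  // CopyTensorsShapeAndType and AllocateTensors):
  //
  //   copy(WHILE inputs -> cond inputs);      // step 1 (resized in Prepare)
  //   while (invoke(cond) yields true) {      // step 2
  //     copy(cond inputs -> body inputs);     // step 3
  //     invoke(body);                         // step 4
  //     copy(body outputs -> cond inputs);    // step 5
  //   }
  //   copy(cond inputs -> WHILE outputs);     // step 6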
  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                      cond_subgraph, cond_subgraph->inputs()));

  while (true) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
    int cond_subgraph_output_index = cond_subgraph->outputs()[0];
    cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index);
    TfLiteTensor* cond_output =
        cond_subgraph->tensor(cond_subgraph_output_index);
    if (op_data->cond_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
    }

    if (!cond_output->data.b[0]) {
      break;
    }
    if (op_data->body_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_OK(context,
                        CopyTensorsShapeAndType(
                            context, cond_subgraph, cond_subgraph->inputs(),
                            body_subgraph, body_subgraph->inputs()));
      TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
    }

    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(),
                        body_subgraph, body_subgraph->inputs()));

    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());

    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    if (op_data->body_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_OK(context,
                        CopyTensorsShapeAndType(
                            context, body_subgraph, body_subgraph->outputs(),
                            cond_subgraph, cond_subgraph->inputs()));
      TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
    }

    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        cond_subgraph, cond_subgraph->inputs()));
  }

  // Note that copying directly from the body subgraph's outputs would fail if
  // the body is never invoked, so the WHILE outputs are copied from the
  // condition subgraph's inputs instead (which, by the loop invariant above,
  // hold the newest values).
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  if (op_data->body_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_OK(
        context, CopyTensorsShapeAndType(context, cond_subgraph,
                                         cond_subgraph->inputs(), this_subgraph,
                                         TfLiteIntArrayView(node->outputs)));
  }

  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(),
                      this_subgraph, TfLiteIntArrayView(node->outputs)));
  return kTfLiteOk;
}

}  // namespace while_kernel

TfLiteRegistration* Register_WHILE() {
  static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
                                 while_kernel::Prepare, while_kernel::Eval};
  return &r;
}
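
// A minimal usage sketch (an assumption for illustration, not part of this
// file: in the TensorFlow Lite runtime this kernel is typically wired to the
// builtin WHILE operator by the op resolver):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("WHILE", tflite::ops::custom::Register_WHILE());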

}  // namespace custom
}  // namespace ops
}  // namespace tflite