/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace while_kernel {

namespace {

// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
// to `dst_tensor_indices` in `dst_subgraph`.
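// Note that this helper only resizes and retypes the destination tensors;
// callers are expected to call AllocateTensors() on `dst_subgraph` afterwards,
// as the call sites below do.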
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
                                     Subgraph* src_subgraph,
                                     const SrcVector& src_tensor_indices,
                                     Subgraph* dst_subgraph,
                                     const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    std::vector<int> dims(src_tensor->dims->data,
                          src_tensor->dims->data + src_tensor->dims->size);
    dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    dst_tensor->type = src_tensor->type;
  }
  return kTfLiteOk;
}

// Copy the tensor data from tensors `src_tensor_indices` in `src_subgraph`
// to `dst_tensor_indices` in `dst_subgraph`.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph,
                             const SrcVector& src_tensor_indices,
                             Subgraph* dst_subgraph,
                             const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
    memcpy(dst_tensor->data.raw, src_tensor->data.raw, src_tensor->bytes);
  }
  return kTfLiteOk;
}

TfLiteStatus CheckCondOutput(TfLiteContext* context,
                             const TfLiteTensor* cond_output) {
  // The condition output must be a single boolean value.
  TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool);
  if (cond_output->dims->size == 0) {
    // It's okay if it's a 0D scalar.
    return kTfLiteOk;
  }
  // Otherwise it must be 1D with shape [1].
  TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
  return kTfLiteOk;
}

}  // namespace

struct OpData {
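  // Indices of the condition and body subgraphs in the model's subgraph list.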
  int cond_subgraph_index;
  int body_subgraph_index;
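  // Set during Prepare() when output shapes can't be fully determined ahead
  // of time; Eval() then resizes and re-allocates tensors on each iteration
  // as needed.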
  bool cond_has_dynamic_output_tensors;
  bool body_has_dynamic_output_tensors;
};

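// Parses the op's custom options, which are expected to be a FlexBuffers map
// containing the two subgraph indices. A minimal sketch of how such a buffer
// could be built (the index values 1 and 2 are placeholders, not from this
// file):
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Int("cond_subgraph_index", 1);
//     fbb.Int("body_subgraph_index", 2);
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() is the byte buffer passed to Init() as `buffer`.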
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
  op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
  op_data->cond_has_dynamic_output_tensors = false;
  op_data->body_has_dynamic_output_tensors = false;
  return op_data;
}

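// Releases the OpData allocated in Init().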
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

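// Validates subgraph indices and input/output counts, propagates the WHILE
// op's input shapes and types into the condition and body subgraphs, and
// decides whether the op's outputs can be statically sized or must be marked
// dynamic.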
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as the number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);

  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());

  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Check input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);

  // Check input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);

  // Prepare and check the condition subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(context, this_subgraph,
                                       TfLiteIntArrayView(node->inputs),
                                       cond_subgraph, cond_subgraph->inputs()));
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // TODO(ycling): Handle the case where the cond subgraph has dynamic tensor
  // outputs. This should rarely happen. In most cases the output is static
  // with shape [1]. However, theoretically, intermediate tensors in the cond
  // subgraph can be dynamic.
  if (IsDynamicTensor(cond_output)) {
    op_data->cond_has_dynamic_output_tensors = true;
  } else {
    TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
  }

  // Prepare and check the body subgraph.
  TF_LITE_ENSURE_OK(
      context, CopyTensorsShapeAndType(context, this_subgraph,
                                       TfLiteIntArrayView(node->inputs),
                                       body_subgraph, body_subgraph->inputs()));
  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  if (body_subgraph->HasDynamicTensors()) {
    op_data->body_has_dynamic_output_tensors = true;
  } else {
    for (int i = 0; i < num_inputs; ++i) {
      TfLiteTensor* body_input =
          body_subgraph->tensor(body_subgraph->inputs()[i]);
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type);

      // TODO(ycling): Support dynamic sized body subgraph.
      TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
      if (!TfLiteIntArrayEqual(body_input->dims, body_output->dims)) {
        // If the output shape of the body subgraph is static w.r.t. a fixed
        // input shape, but differs from that input shape, the loop is still
        // considered dynamic. For example: if a subgraph keeps padding its
        // input with a fixed padding, the output shape is static w.r.t. the
        // input shape and padding, but running it in a loop will keep growing
        // the tensor.
        op_data->body_has_dynamic_output_tensors = true;
        break;
      }
    }
  }
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* output = GetOutput(context, node, i);
    if (op_data->body_has_dynamic_output_tensors) {
      SetTensorToDynamic(output);
    } else {
      TfLiteTensor* body_output =
          body_subgraph->tensor(body_subgraph->outputs()[i]);
      TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output, output_size));
    }
  }
  return kTfLiteOk;
}

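// Runs the loop: repeatedly invokes the condition subgraph and, while it
// returns true, the body subgraph, copying tensor data between the subgraphs
// at each step. See the step-by-step diagram below.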
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // The following graph illustrates the current implementation.
  //
  // This Subgraph          Cond Subgraph         Body Subgraph
  // +-----------+   (1)   +------------+   (3)   +------------+
  // |   WHILE   |-------->|  SUBGRAPH  |-------->|  SUBGRAPH  |
  // |   INPUT   |        /|   INPUT    |<-----   |   INPUT    |
  // +-----------+       / +------------+      \  +------------+
  //                    /        |              \       |
  //               (6) /         | (2)       (5) \      | (4)
  //                  /          v                \     v
  // +-----------+   /     +------------+         +------------+
  // |   WHILE   |<--      |  SUBGRAPH  |         |  SUBGRAPH  |
  // |   OUTPUT  |         |   OUTPUT   |         |   OUTPUT   |
  // +-----------+         +------------+         +------------+
  //
  // (1) Copy the inputs of the WHILE op to the inputs of the condition
  //     subgraph.
  // (2) Invoke the condition subgraph.
  //     Jump to step 6 if the result is false.
  // (3) Copy the inputs of the condition subgraph to the inputs of the body
  //     subgraph.
  // (4) Invoke the body subgraph.
  // (5) Copy the outputs of the body subgraph to the inputs of the condition
  //     subgraph.
  //     Jump back to step 2!
  // (6) Copy the inputs of the condition subgraph to the outputs of the WHILE
  //     op.
  //
  // If the body subgraph has dynamic sized outputs, the tensors must be
  // resized before copying in steps 1, 3, 5 and 6.
  //
  // Note the flow is carefully designed to handle the dynamic sized output
  // case. The loop invariant is: the newest values are in the inputs of the
  // condition subgraph. This is always true before step 2.
  //
  // This is the best we can do without sharing tensor buffers across subgraph
  // boundaries. Currently we copy the inputs / outputs between the subgraphs.
  // This isn't optimized yet and a lot of redundant copies are made.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                      cond_subgraph, cond_subgraph->inputs()));

  while (true) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
    int cond_subgraph_output_index = cond_subgraph->outputs()[0];
    cond_subgraph->EnsureTensorDataIsReadable(cond_subgraph_output_index);
    TfLiteTensor* cond_output =
        cond_subgraph->tensor(cond_subgraph_output_index);
    if (op_data->cond_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_STATUS(CheckCondOutput(context, cond_output));
    }

    if (!cond_output->data.b[0]) {
      break;
    }
    if (op_data->body_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_OK(context,
                        CopyTensorsShapeAndType(
                            context, cond_subgraph, cond_subgraph->inputs(),
                            body_subgraph, body_subgraph->inputs()));
      TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
    }

    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(),
                        body_subgraph, body_subgraph->inputs()));

    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());

    for (int tensor_index : body_subgraph->outputs()) {
      body_subgraph->EnsureTensorDataIsReadable(tensor_index);
    }

    if (op_data->body_has_dynamic_output_tensors) {
      TF_LITE_ENSURE_OK(context,
                        CopyTensorsShapeAndType(
                            context, body_subgraph, body_subgraph->outputs(),
                            cond_subgraph, cond_subgraph->inputs()));
      TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
    }

    TF_LITE_ENSURE_OK(
        context,
        CopyTensorsData(context, body_subgraph, body_subgraph->outputs(),
                        cond_subgraph, cond_subgraph->inputs()));
  }

  // Note that copying from the body's outputs here would fail if the body was
  // never invoked, so the newest values are taken from the inputs of the
  // condition subgraph instead.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  if (op_data->body_has_dynamic_output_tensors) {
    TF_LITE_ENSURE_OK(
        context, CopyTensorsShapeAndType(context, cond_subgraph,
                                         cond_subgraph->inputs(), this_subgraph,
                                         TfLiteIntArrayView(node->outputs)));
  }

  TF_LITE_ENSURE_OK(
      context,
      CopyTensorsData(context, cond_subgraph, cond_subgraph->inputs(),
                      this_subgraph, TfLiteIntArrayView(node->outputs)));
  return kTfLiteOk;
}

}  // namespace while_kernel

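// Returns the registration for the WHILE kernel. A sketch of how a client
// might register it with an op resolver (the custom op name here is an
// assumption, not taken from this file):
//
//   resolver.AddCustom("Experimental_While",
//                      tflite::ops::custom::Register_WHILE());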
TfLiteRegistration* Register_WHILE() {
  static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
                                 while_kernel::Prepare, while_kernel::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite