• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <memory>
18 
19 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
20 #include "tensorflow/core/util/ragged_to_dense_util_common.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/context.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/internal/types.h"
26 #include "tensorflow/lite/kernels/kernel_util.h"
27 #include "tensorflow/lite/model.h"
28 
29 namespace tflite {
30 namespace ops {
31 namespace custom {
32 namespace ragged {
33 namespace ragged_tensor_to_tensor {
34 namespace {
35 
// Positions of the fixed inputs. The row partition tensors occupy the inputs
// from kFirstPartitionInputIndex onward.
constexpr int kShapeInput{0};
constexpr int kValuesInput{1};
constexpr int kDefaultValueInput{2};
constexpr int kFirstPartitionInputIndex{3};

// Position of the single output tensor.
constexpr int kOutputTensor{0};

// Key of the flexbuffer attribute that lists the row partition types.
constexpr char kRowPartitionTypesAttr[] = "row_partition_types";
44 
45 struct ConversionAttributes {
46   std::vector<tensorflow::RowPartitionType> partition_types;
47   int ragged_rank = 0;
48 
GetRowPartitionTypeByDimensiontflite::ops::custom::ragged::ragged_tensor_to_tensor::__anon7cc6f1010111::ConversionAttributes49   tensorflow::RowPartitionType GetRowPartitionTypeByDimension(
50       int dimension) const {
51     if (partition_types.front() ==
52         tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
53       return partition_types[dimension + 1];
54     } else {
55       return partition_types[dimension];
56     }
57   }
58 };
59 template <typename INDEX_TYPE>
GetFirstDimensionSizeT(TfLiteContext * context,const TfLiteTensor & first_partition_input,const ConversionAttributes * attributes)60 int GetFirstDimensionSizeT(TfLiteContext* context,
61                            const TfLiteTensor& first_partition_input,
62                            const ConversionAttributes* attributes) {
63   const tensorflow::RowPartitionType first_partition_type =
64       attributes->partition_types.front();
65   switch (first_partition_type) {
66     case tensorflow::RowPartitionType::FIRST_DIM_SIZE:
67       return *GetTensorData<INDEX_TYPE>(&first_partition_input);
68     case tensorflow::RowPartitionType::VALUE_ROWIDS:
69       context->ReportError(context,
70                            "Cannot handle VALUE_ROWIDS in first dimension.");
71       return -1;
72     case tensorflow::RowPartitionType::ROW_SPLITS: {
73       const auto shape = GetTensorShape(&first_partition_input);
74       return shape.Dims(0) - 1;
75     }
76 
77     default:
78       context->ReportError(
79           context, "Cannot handle type ",
80           RowPartitionTypeToString(first_partition_type).c_str());
81       return -1;
82   }
83 }
84 
GetFirstDimensionSize(TfLiteContext * context,const TfLiteTensor & first_partition_input,const ConversionAttributes * attributes)85 int GetFirstDimensionSize(TfLiteContext* context,
86                           const TfLiteTensor& first_partition_input,
87                           const ConversionAttributes* attributes) {
88   switch (first_partition_input.type) {
89     case kTfLiteInt32:
90       return GetFirstDimensionSizeT<int32_t>(context, first_partition_input,
91                                              attributes);
92     case kTfLiteInt64:
93       return GetFirstDimensionSizeT<int64_t>(context, first_partition_input,
94                                              attributes);
95     default:
96       context->ReportError(context,
97                            "Not supported row partitioning tensor type");
98       return -1;
99   }
100 }
101 
ValidateDefaultValueShape(TfLiteContext * context,const RuntimeShape & default_value_shape,const RuntimeShape &)102 bool ValidateDefaultValueShape(TfLiteContext* context,
103                                const RuntimeShape& default_value_shape,
104                                const RuntimeShape& /*value_shape*/) {
105   // TF implementation also checks that shapes are not defined, not needed in
106   // TFLite.
107   // TODO(mgubin): Only scalar default value sizes are supported.
108   if (default_value_shape.FlatSize() != 1) {
109     context->ReportError(context, "Only scalar default value is supported");
110     return false;
111   }
112   return true;
113 }
114 
TensorShapeFromTensor(const TfLiteTensor & tensor)115 RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) {
116   // TODO(mgubin): No checks, see
117   // third_party/tensorflow/core/kernels/list_kernels.cc
118   const RuntimeShape tensor_shape(tensor.dims->size, tensor.dims->data);
119   if (0 == tensor.dims->size) {
120     // If the input tensor is scalar then the shape is empty (also scalar).
121     return RuntimeShape{};
122   }
123   RuntimeShape result(tensor_shape.FlatSize());
124   switch (tensor.type) {
125     case kTfLiteInt32: {
126       for (int i = 0; i < tensor_shape.FlatSize(); ++i) {
127         result.SetDim(i, GetTensorData<int32_t>(&tensor)[i]);
128       }
129     } break;
130     case kTfLiteInt64: {
131       for (int i = 0; i < tensor_shape.FlatSize(); ++i) {
132         result.SetDim(i, GetTensorData<int64_t>(&tensor)[i]);
133       }
134     } break;
135     default: {
136       // Checked in Prepare.
137     }
138   }
139   return result;
140 }
141 
GetRowPartitionTensor(const ConversionAttributes & conversion_attributes,TfLiteContext * context,TfLiteNode * node,int dimension)142 const TfLiteTensor* GetRowPartitionTensor(
143     const ConversionAttributes& conversion_attributes, TfLiteContext* context,
144     TfLiteNode* node, int dimension) {
145   if (conversion_attributes.partition_types.front() ==
146       tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
147     return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 +
148                                                 dimension]];
149   } else {
150     return &context->tensors[node->inputs
151                                  ->data[kFirstPartitionInputIndex + dimension]];
152   }
153 }
154 
GetMaxWidthValueRowID(const TfLiteTensor * tensor)155 int GetMaxWidthValueRowID(const TfLiteTensor* tensor) {
156   const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data);
157   const int index_length = tensor_shape.FlatSize();
158   if (index_length == 0) {
159     return 0;
160   }
161   auto value_rowids = [tensor](int index) {
162     switch (tensor->type) {
163       case kTfLiteInt32:
164         return static_cast<int>(tensor->data.i32[index]);
165       case kTfLiteInt64:
166         return static_cast<int>(tensor->data.i64[index]);
167       default:
168         // TODO(mgubin): Add error checks.
169         return 0;
170     }
171   };
172   int first_equal_index = 0;
173   int first_equal_index_value = value_rowids(0);
174   int max_width = 0;
175   for (int i = 0; i < index_length; ++i) {
176     const int value = value_rowids(i);
177     if (value != first_equal_index_value) {
178       first_equal_index_value = value;
179       max_width = std::max(i - first_equal_index, max_width);
180       first_equal_index = i;
181     }
182   }
183   return std::max(index_length - first_equal_index, max_width);
184 }
185 
GetMaxWidthRowSplit(const TfLiteTensor * tensor)186 int GetMaxWidthRowSplit(const TfLiteTensor* tensor) {
187   const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data);
188   const int tensor_length = tensor_shape.FlatSize();
189   if (tensor_length == 0 || tensor_length == 1) {
190     return 0;
191   }
192   auto value_rowsplit = [tensor](int index) {
193     switch (tensor->type) {
194       case kTfLiteInt32:
195         return static_cast<int>(tensor->data.i32[index]);
196       case kTfLiteInt64:
197         return static_cast<int>(tensor->data.i64[index]);
198       default:
199         // TODO(mgubin): Add error checks.
200         return 0;
201     }
202   };
203   int max_width = 1;
204   int prev_split = value_rowsplit(0);
205   for (int i = 1; i < tensor_length; ++i) {
206     const int split = value_rowsplit(i);
207     max_width = std::max(max_width, split - prev_split);
208     prev_split = split;
209   }
210   return max_width;
211 }
212 
GetMaxWidth(const ConversionAttributes & conversion_attributes,TfLiteContext * context,TfLiteNode * node,int dimension)213 int GetMaxWidth(const ConversionAttributes& conversion_attributes,
214                 TfLiteContext* context, TfLiteNode* node, int dimension) {
215   const TfLiteTensor* tensor = GetRowPartitionTensor(
216       conversion_attributes, context, node, dimension - 1);
217   switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) {
218     case tensorflow::RowPartitionType::VALUE_ROWIDS:
219       return GetMaxWidthValueRowID(tensor);
220     case tensorflow::RowPartitionType::ROW_SPLITS:
221       return GetMaxWidthRowSplit(tensor);
222     default:
223       context->ReportError(context, "Cannot handle partition type");
224       return -1;
225   }
226 }
227 
// Combines the requested output shape with the shape of the values tensor.
// If `output_shape` is empty (unknown), a shape of rank
// ragged_rank + rank(values) filled with -1 is used instead. The trailing
// dimensions of the result are then overwritten with the values' dimensions
// (excluding the values' outermost dimension).
// TODO(mgubin): No checks, see
// third_party/tensorflow/core/ops/ragged_to_dense_util.cc
RuntimeShape CombineRaggedTensorToTensorShapes(
    int ragged_rank, const RuntimeShape& output_shape,
    const RuntimeShape& value_shape) {
  RuntimeShape result(output_shape);
  if (output_shape.DimensionsCount() == 0) {
    const int output_shape_rank = ragged_rank + value_shape.DimensionsCount();
    result.Resize(output_shape_rank);
    for (int i = 0; i < output_shape_rank; ++i) {
      result.SetDim(i, -1);
    }
  }
  // Align the trailing dimensions of `result` with those of `value_shape`.
  // This must use the rank of `result` (which may have just been resized), not
  // that of `output_shape`: for an unknown output shape the latter is 0 and
  // produced negative dimension indices.
  const int need_to_set =
      result.DimensionsCount() - value_shape.DimensionsCount();
  for (int i = 1; i < value_shape.DimensionsCount(); ++i) {
    const int dim = need_to_set + i;
    // With inconsistent ranks some target dimensions do not exist; skip them
    // instead of writing out of bounds. Callers should reject such shapes.
    if (dim < 0) {
      continue;
    }
    result.SetDim(dim, value_shape.Dims(i));
  }
  return result;
}
248 
CalculateOutputSize(const ConversionAttributes & conversion_attributes,TfLiteContext * context,TfLiteNode * node,int first_dimension,int ragged_rank,const TfLiteTensor & values,const TfLiteTensor & default_value,const TfLiteTensor & output_shape)249 RuntimeShape CalculateOutputSize(
250     const ConversionAttributes& conversion_attributes, TfLiteContext* context,
251     TfLiteNode* node, int first_dimension, int ragged_rank,
252     const TfLiteTensor& values, const TfLiteTensor& default_value,
253     const TfLiteTensor& output_shape) {
254   RuntimeShape values_shape(values.dims->size, values.dims->data);
255   RuntimeShape default_value_shape(default_value.dims->size,
256                                    default_value.dims->data);
257 
258   if (!ValidateDefaultValueShape(context, default_value_shape, values_shape)) {
259     return {};
260   }
261   RuntimeShape output_shape_shape = TensorShapeFromTensor(output_shape);
262 
263   RuntimeShape result_shape = CombineRaggedTensorToTensorShapes(
264       ragged_rank, output_shape_shape, values_shape);
265   if (result_shape.Dims(0) < 0) {
266     result_shape.SetDim(0, first_dimension);
267   }
268   for (int i = 1; i <= ragged_rank; ++i) {
269     if (result_shape.Dims(i) < 0) {
270       result_shape.SetDim(i,
271                           GetMaxWidth(conversion_attributes, context, node, i));
272     }
273   }
274   return result_shape;
275 }
276 
IntArrayFromShape(const RuntimeShape & shape)277 TfLiteIntArray* IntArrayFromShape(const RuntimeShape& shape) {
278   TfLiteIntArray* result = TfLiteIntArrayCreate(shape.DimensionsCount());
279   for (int i = 0; i < shape.DimensionsCount(); ++i) {
280     result->data[i] = shape.Dims(i);
281   }
282   return result;
283 }
284 
/**
 * Appends to `*result`, for each of the `first_dimension` outermost rows, the
 * index in the output tensor where that row's first element is written. Rows at
 * or beyond `first_dimension_output` do not fit in the output and get -1.
 * Example, given first_dimension = 10, first_dimension_output = 6,
 * and output_index_multiplier = 100:
 * result = [0 100 200 300 400 500 -1 -1 -1 -1]
 * If first_dimension_output = 11 instead, then:
 * result = [0 100 200 300 400 500 600 700 800 900]
 */
void CalculateFirstParentOutputIndex(int first_dimension,
                                     int output_index_multiplier,
                                     int first_dimension_output,
                                     std::vector<int>* result) {
  const int rows_in_output = std::min(first_dimension, first_dimension_output);
  result->reserve(first_dimension);
  for (int row = 0; row < first_dimension; ++row) {
    result->push_back(row < rows_in_output ? row * output_index_multiplier
                                           : -1);
  }
}
310 // Calculate the output index of the first element of a list.
311 // The parent_output_index is the same computation for the previous list.
312 // -1 indicates an element or list that is out of range.
313 // The output_index_multiplier is the number of output indices one moves
314 // forward for each column.
315 // E.g., given:
316 // value_rowids:[0 1 2 2 2 3 5 5 6]
317 // parent_output_index:[1000 1100 2000 2100 -1 3000 4000]
318 // output_index_multiplier: 10
319 // output_size: 2
320 // You get:
321 // result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
322 // result[0] = parent_output_index[value_rowids[0]]
323 // result[1] = parent_output_index[value_rowids[1]]
324 // result[2] = parent_output_index[value_rowids[2]]
325 // result[3] = parent_output_index[value_rowids[2] + 10]
326 // result[4] = -1 because it is the third element the size is 2.
327 // result[5] = parent_output_index[value_rowids[3]]
328 // result[6] = -1 because parent_output_index[value_rowids[6]] == -1
329 // result[7] = -1 because parent_output_index[value_rowids[6]] == -1
330 // result[8] = parent_output_index[value_rowids[7]]
CalculateOutputIndexValueRowID(const TfLiteTensor & value_rowids,const std::vector<int> & parent_output_index,int output_index_multiplier,int output_size,std::vector<int> * result)331 void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
332                                     const std::vector<int>& parent_output_index,
333                                     int output_index_multiplier,
334                                     int output_size, std::vector<int>* result) {
335   const RuntimeShape tensor_shape(value_rowids.dims->size,
336                                   value_rowids.dims->data);
337   const int index_size = tensor_shape.FlatSize();
338   result->reserve(index_size);
339   if (index_size == 0) {
340     return;
341   }
342 
343   auto value_rowids_val = [value_rowids](int index) {
344     switch (value_rowids.type) {
345       case kTfLiteInt32:
346         return static_cast<int>(value_rowids.data.i32[index]);
347       case kTfLiteInt64:
348         return static_cast<int>(value_rowids.data.i64[index]);
349       default:
350         // TODO(mgubin): Add error checks.
351         return 0;
352     }
353   };
354   int current_output_column = 0;
355   int current_value_rowid = value_rowids_val(0);
356   // DCHECK_LT(current_value_rowid, parent_output_index.size());
357   int current_output_index = parent_output_index[current_value_rowid];
358   result->push_back(current_output_index);
359   for (int i = 1; i < index_size; ++i) {
360     int next_value_rowid = value_rowids_val(i);
361     if (next_value_rowid == current_value_rowid) {
362       if (current_output_index >= 0) {
363         ++current_output_column;
364         if (current_output_column < output_size) {
365           current_output_index += output_index_multiplier;
366         } else {
367           current_output_index = -1;
368         }
369       }
370     } else {
371       current_output_column = 0;
372       current_value_rowid = next_value_rowid;
373       // DCHECK_LT(next_value_rowid, parent_output_index.size());
374       current_output_index = parent_output_index[next_value_rowid];
375     }
376     result->push_back(current_output_index);
377   }
378   // DCHECK_EQ(result->size(), value_rowids.size());
379 }
380 
CalculateOutputIndexRowSplit(const TfLiteTensor & row_split,const std::vector<int> & parent_output_index,int output_index_multiplier,int output_size,std::vector<int> * result)381 void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
382                                   const std::vector<int>& parent_output_index,
383                                   int output_index_multiplier, int output_size,
384                                   std::vector<int>* result) {
385   const RuntimeShape row_split_shape(row_split.dims->size,
386                                      row_split.dims->data);
387   const int row_split_size = row_split_shape.FlatSize();
388   auto row_split_val = [row_split](int index) {
389     switch (row_split.type) {
390       case kTfLiteInt32:
391         return static_cast<int>(row_split.data.i32[index]);
392       case kTfLiteInt64:
393         return static_cast<int>(row_split.data.i64[index]);
394       default:
395         // TODO(mgubin): Add error checks.
396         return 0;
397     }
398   };
399   if (row_split_size > 0) {
400     result->reserve(row_split_val(row_split_size - 1));
401   }
402   for (int i = 0; i < row_split_size - 1; ++i) {
403     const int row_length = row_split_val(i + 1) - row_split_val(i);
404     int real_length = std::min(output_size, row_length);
405     int parent_output_index_current = parent_output_index[i];
406 
407     if (parent_output_index_current == -1) {
408       real_length = 0;
409     }
410     for (int j = 0; j < real_length; ++j) {
411       result->push_back(parent_output_index_current);
412       parent_output_index_current += output_index_multiplier;
413     }
414     for (int j = 0; j < row_length - real_length; ++j) {
415       result->push_back(-1);
416     }
417   }
418   // if (row_split_size > 0) {
419   //  DCHECK_EQ(result->size(), row_split(row_split_size - 1));
420   //}
421 }
422 
CalculateOutputIndex(const ConversionAttributes & conversion_attributes,TfLiteContext * context,TfLiteNode * node,int dimension,const std::vector<int> & parent_output_index,int output_index_multiplier,int output_size,std::vector<int> * result)423 TfLiteStatus CalculateOutputIndex(
424     const ConversionAttributes& conversion_attributes, TfLiteContext* context,
425     TfLiteNode* node, int dimension,
426     const std::vector<int>& parent_output_index, int output_index_multiplier,
427     int output_size, std::vector<int>* result) {
428   const TfLiteTensor* row_partition_tensor =
429       GetRowPartitionTensor(conversion_attributes, context, node, dimension);
430   auto partition_type =
431       conversion_attributes.GetRowPartitionTypeByDimension(dimension);
432   switch (partition_type) {
433     case tensorflow::RowPartitionType::VALUE_ROWIDS:
434       CalculateOutputIndexValueRowID(*row_partition_tensor, parent_output_index,
435                                      output_index_multiplier, output_size,
436                                      result);
437       return kTfLiteOk;
438     case tensorflow::RowPartitionType::ROW_SPLITS:
439       CalculateOutputIndexRowSplit(*row_partition_tensor, parent_output_index,
440                                    output_index_multiplier, output_size,
441                                    result);
442       return kTfLiteOk;
443     default:
444       context->ReportError(context, "Unsupported partition type");
445       return kTfLiteError;
446   }
447 }
448 
449 template <typename VALUE_TYPE>
SetOutputT(TfLiteContext * context,int ragged_rank,const std::vector<int> & output_index,const TfLiteTensor & values_tensor,const TfLiteTensor & default_value_tensor,TfLiteTensor * output_tensor)450 void SetOutputT(TfLiteContext* context, int ragged_rank,
451                 const std::vector<int>& output_index,
452                 const TfLiteTensor& values_tensor,
453                 const TfLiteTensor& default_value_tensor,
454                 TfLiteTensor* output_tensor) {
455   const VALUE_TYPE* values_base = GetTensorData<VALUE_TYPE>(&values_tensor);
456   VALUE_TYPE* output_base = GetTensorData<VALUE_TYPE>(output_tensor);
457   const VALUE_TYPE* default_value =
458       GetTensorData<VALUE_TYPE>(&default_value_tensor);
459 
460   RuntimeShape output_shape = GetTensorShape(output_tensor);
461   RuntimeShape element_shape =
462       RuntimeShape(output_shape.DimensionsCount() - ragged_rank - 1,
463                    output_shape.DimsData() + ragged_rank + 1);
464 
465   // element_shape.RemoveDimRange(0, ragged_rank + 1);
466   const int value_element_size = element_shape.FlatSize();
467   size_t output_index_size = output_index.size();
468 
469   // Loop through the output_index vector, finding contiguous regions that
470   // should be copied.  Once we find the end of a contiguous region, copy it
471   // and add any necessary padding (with default_value).
472   int src_start = 0;  // Start of contiguous region (in values)
473   int dst_start = 0;  // Destination for contiguous region (in output)
474   int dst_end = 0;    // Destination for contiguous region (in output)
475   for (int src_i = 0; src_i <= output_index_size; ++src_i) {
476     // dst_i is the destination where the value at src_i should be copied.
477     int dst_i = src_i < output_index_size ? output_index[src_i] : -1;
478 
479     // If we're still in a contiguous region, then update dst_end go to the
480     // next src_i.
481     if (dst_i == dst_end) {
482       ++dst_end;
483       continue;
484     }
485 
486     // We found the end of contiguous region.  This can be because we found
487     // a gap (dst_i > dst_end), or a source value that shouldn't be copied
488     // because it's out-of-bounds (dst_i == -1), or the end of the tensor
489     // (dst_i = -1).
490     if (dst_start < dst_end) {
491       // Copy the contiguous region.
492       const VALUE_TYPE* src = values_base + src_start * value_element_size;
493       VALUE_TYPE* dst = output_base + dst_start * value_element_size;
494       int nvals = (dst_end - dst_start) * value_element_size;
495       std::copy(src, src + nvals, dst);
496       // copy_array<VALUE_TYPE, int>(dst, src, nvals);
497     }
498 
499     // Add any necessary padding (w/ default_value).
500     if (src_i >= output_index_size) {
501       // We reached the end of values: pad to the end of output.
502       const int output_size = output_shape.FlatSize();
503       dst_i = output_size / value_element_size;
504     }
505     if (dst_i > dst_end) {
506       std::fill(output_base + dst_end * value_element_size,
507                 output_base + dst_i * value_element_size, *default_value);
508       dst_end = dst_i;
509     }
510 
511     // Update indices.
512     if (dst_i < 0) {
513       // src_i should be skipped -- leave it out of the contiguous region.
514       src_start = src_i + 1;
515       dst_start = dst_end;
516     } else {
517       // src_i should be copied -- include it in the contiguous region.
518       src_start = src_i;
519       dst_start = dst_end;
520       dst_end = dst_start + 1;
521     }
522   }
523 }
524 
SetOutput(TfLiteContext * context,int ragged_rank,const std::vector<int> & output_index,const TfLiteTensor & values_tensor,const TfLiteTensor & default_value_tensor,TfLiteTensor * output_tensor)525 void SetOutput(TfLiteContext* context, int ragged_rank,
526                const std::vector<int>& output_index,
527                const TfLiteTensor& values_tensor,
528                const TfLiteTensor& default_value_tensor,
529                TfLiteTensor* output_tensor) {
530   switch (output_tensor->type) {
531     case kTfLiteInt32:
532       SetOutputT<int32_t>(context, ragged_rank, output_index, values_tensor,
533                           default_value_tensor, output_tensor);
534       break;
535     case kTfLiteInt64:
536       SetOutputT<int64_t>(context, ragged_rank, output_index, values_tensor,
537                           default_value_tensor, output_tensor);
538       break;
539     case kTfLiteFloat32:
540       SetOutputT<float>(context, ragged_rank, output_index, values_tensor,
541                         default_value_tensor, output_tensor);
542       break;
543     default:
544       context->ReportError(context, "Not supported values type");
545   }
546 }
547 
548 }  // namespace
549 
Initialize(TfLiteContext * context,const char * buffer,size_t length)550 void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
551   auto attributes = std::make_unique<ConversionAttributes>();
552 
553   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
554 
555   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
556   // TODO (mgubin): Converting flat buffer to a vector of strings looks not very
557   // effective but simple. A cleaner way is needed.
558   const flexbuffers::TypedVector row_partition_types_attr =
559       m[kRowPartitionTypesAttr].AsTypedVector();
560   std::vector<std::string> row_partition_types_attr_strings;
561   row_partition_types_attr_strings.reserve(row_partition_types_attr.size());
562   for (int i = 0; i < row_partition_types_attr.size(); ++i) {
563     row_partition_types_attr_strings.emplace_back(
564         row_partition_types_attr[i].AsString().str());
565   }
566   attributes->partition_types =
567       tensorflow::GetRowPartitionTypesHelper(row_partition_types_attr_strings);
568   if (attributes->partition_types.size() !=
569       row_partition_types_attr_strings.size()) {
570     context->ReportError(context, "Can't parse partition type attribute");
571     return nullptr;
572   }
573   attributes->ragged_rank =
574       tensorflow::GetRaggedRank(attributes->partition_types);
575   return attributes.release();
576 }
Free(TfLiteContext *,void * buffer)577 void Free(TfLiteContext* /*context*/, void* buffer) {
578   ConversionAttributes* attributes =
579       reinterpret_cast<ConversionAttributes*>(buffer);
580   delete attributes;
581 }
582 
Prepare(TfLiteContext * context,TfLiteNode * node)583 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
584   const ConversionAttributes* attributes =
585       reinterpret_cast<ConversionAttributes*>(node->user_data);
586   if (attributes == nullptr) {
587     // Parsing attributes failed, can't prepare.
588     context->ReportError(context, "Attributes are not initialized");
589     return kTfLiteError;
590   }
591   // The output tensor need to be set to dynamic because it can have different
592   // size.
593   TfLiteTensor& output_tensor =
594       context->tensors[node->outputs->data[kOutputTensor]];
595   SetTensorToDynamic(&output_tensor);
596 
597   // Check that input shape tensor is int32 or int64
598   TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]];
599   if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64) {
600     context->ReportError(context,
601                          "Input form tensor could be only int32 or int64");
602     return kTfLiteError;
603   }
604   return kTfLiteOk;
605 }
606 
Eval(TfLiteContext * context,TfLiteNode * node)607 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
608   const ConversionAttributes* attributes =
609       reinterpret_cast<ConversionAttributes*>(node->user_data);
610   TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]];
611   TfLiteTensor& input_values =
612       context->tensors[node->inputs->data[kValuesInput]];
613   TfLiteTensor& default_value =
614       context->tensors[node->inputs->data[kDefaultValueInput]];
615   // TODO (mgubin): Only scallar default value is supported.
616   if (RuntimeShape(default_value.dims->size, default_value.dims->data)
617           .FlatSize() != 1) {
618     context->ReportError(context, "Only scallar default value is supported");
619     return kTfLiteError;
620   }
621   TfLiteTensor& first_partition_input =
622       context->tensors[node->inputs->data[kFirstPartitionInputIndex]];
623 
624   // Calculate dimensions.
625   const int first_dimension =
626       GetFirstDimensionSize(context, first_partition_input, attributes);
627   if (first_dimension < 0) {
628     return kTfLiteError;
629   }
630   RuntimeShape output_shape = CalculateOutputSize(
631       *attributes, context, node, first_dimension, attributes->ragged_rank,
632       input_values, default_value, input_shape);
633   if (output_shape.DimensionsCount() == 0) {
634     return kTfLiteError;
635   }
636 
637   std::vector<int> multiplier;
638   multiplier.resize(attributes->ragged_rank + 1);
639   multiplier.back() = 1;
640   for (int i = multiplier.size() - 2; i >= 0; --i) {
641     multiplier[i] = multiplier[i + 1] * output_shape.Dims(i + 1);
642   }
643 
644   // Allocate output tensor.
645   TfLiteTensor& output_tensor =
646       context->tensors[node->outputs->data[kOutputTensor]];
647 
648   TF_LITE_ENSURE_OK(context,
649                     context->ResizeTensor(context, &output_tensor,
650                                           IntArrayFromShape(output_shape)));
651 
652   // Copy data.
653   const int full_size = multiplier.front() * output_shape.Dims(0);
654   if (full_size > 0) {
655     std::vector<int> output_index, new_output_index;
656     int nvals = input_values.dims->data[0];
657     output_index.reserve(nvals);
658     new_output_index.reserve(nvals);
659 
660     CalculateFirstParentOutputIndex(first_dimension, multiplier[0],
661                                     output_shape.Dims(0), &output_index);
662     for (int i = 1; i <= attributes->ragged_rank; ++i) {
663       TF_LITE_ENSURE_OK(
664           context, CalculateOutputIndex(
665                        *attributes, context, node, i - 1, output_index,
666                        multiplier[i], output_shape.Dims(i), &new_output_index));
667       output_index.swap(new_output_index);
668       new_output_index.clear();
669     }
670 
671     SetOutput(context, attributes->ragged_rank, output_index, input_values,
672               default_value, &output_tensor);
673   }
674   return kTfLiteOk;
675 }
676 
677 }  // namespace ragged_tensor_to_tensor
678 }  // namespace ragged
679 
Register_RAGGED_TENSOR_TO_TENSOR()680 TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR() {
681   static TfLiteRegistration r = {ragged::ragged_tensor_to_tensor::Initialize,
682                                  ragged::ragged_tensor_to_tensor::Free,
683                                  ragged::ragged_tensor_to_tensor::Prepare,
684                                  ragged::ragged_tensor_to_tensor::Eval};
685   return &r;
686 }
687 
688 }  // namespace custom
689 }  // namespace ops
690 }  // namespace tflite
691