/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fill {

namespace {

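// Tensor indices for the FILL op: input 0 ("dims") is a 1-D int32/int64
// tensor holding the requested output shape, input 1 ("value") is a scalar
// holding the fill value, and output 0 receives the filled result.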
constexpr int kDimsTensor = 0;
constexpr int kValueTensor = 1;
constexpr int kOutputTensor = 0;

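// Reads the requested extents from the 1-D dims tensor as type T, rejects any
// negative extent, and resizes the output accordingly. For example, a dims
// tensor containing [2, 3] resizes the output to shape 2x3.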
template <typename T>
TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims,
                              TfLiteTensor* output) {
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]);
  for (int i = 0; i < output_shape->size; ++i) {
    T data = GetTensorData<T>(dims)[i];
    if (data < 0) {
      TfLiteIntArrayFree(output_shape);
      TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0.");
      return kTfLiteError;
    }
    output_shape->data[i] = data;
  }
  return context->ResizeTensor(context, output, output_shape);
}

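// Dispatches ResizeOutputImpl on the element type of the dims tensor; only
// int32 and int64 shape tensors are supported.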
TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims,
                          TfLiteTensor* output) {
  switch (dims->type) {
    case kTfLiteInt32:
      return ResizeOutputImpl<int32_t>(context, dims, output);
    case kTfLiteInt64:
      return ResizeOutputImpl<int64_t>(context, dims, output);
    default:
      TF_LITE_KERNEL_LOG(
          context,
          "Fill only currently supports int32, int64 for input 0, "
          "got %d.",
          dims->type);
      return kTfLiteError;
  }
}

}  // namespace

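// Validates that the node has two inputs (a 1-D int32/int64 dims tensor and a
// scalar value tensor) and one output, copies the value tensor's type to the
// output, checks that the output's quantization parameters match the value's
// (with a zero zero_point required for int16), and resizes the output now if
// the dims tensor is constant; otherwise the output is marked dynamic.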
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* dims;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));

  // Make sure the 1st input tensor is 1-D.
  TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);

  // Make sure the 1st input tensor is int32 or int64.
  const auto dtype = dims->type;
  TF_LITE_ENSURE(context, dtype == kTfLiteInt32 || dtype == kTfLiteInt64);

  // Make sure the 2nd input tensor is a scalar.
  TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = value->type;

  TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale);
  TF_LITE_ENSURE_EQ(context, output->params.zero_point,
                    value->params.zero_point);

  if (value->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0);
  }

  if (IsConstantTensor(dims)) {
    TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
  } else {
    SetTensorToDynamic(output);
  }
  return kTfLiteOk;
}

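// Fills a string output by copying the single input string into every element
// of the (already resized) output tensor via a DynamicBuffer.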
TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
  DynamicBuffer buffer;
  const auto string_ref = GetString(value, 0);
  int n = 1;
  for (int i = 0; i < output->dims->size; ++i) {
    n *= output->dims->data[i];
  }
  for (int i = 0; i < n; ++i) {
    buffer.AddString(string_ref.str, string_ref.len);
  }
  buffer.WriteToTensor(output, /*new_shape=*/nullptr);
  return kTfLiteOk;
}

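// Resizes the output from the dims tensor if it was left dynamic in Prepare,
// then fills it with the scalar value, dispatching on the output type. For
// example, dims = [2, 3] and value = 1.5f produce a 2x3 tensor of 1.5f.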
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (IsDynamicTensor(output)) {
    const TfLiteTensor* dims;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
    TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
  }
#define TF_LITE_FILL(data_type)                                               \
  reference_ops::Fill(GetTensorShape(value), GetTensorData<data_type>(value), \
                      GetTensorShape(output),                                 \
                      GetTensorData<data_type>(output))
  switch (output->type) {
    case kTfLiteInt8:
      TF_LITE_FILL(int8_t);
      break;
    case kTfLiteInt16:
      TF_LITE_FILL(int16_t);
      break;
    case kTfLiteInt32:
      TF_LITE_FILL(int32_t);
      break;
    case kTfLiteInt64:
      TF_LITE_FILL(int64_t);
      break;
    case kTfLiteFloat32:
      TF_LITE_FILL(float);
      break;
    case kTfLiteBool:
      TF_LITE_FILL(bool);
      break;
    case kTfLiteString:
      TF_LITE_ENSURE_OK(context, FillString(value, output));
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context,
          "Fill only currently supports int8, int16, int32, int64, float32, "
          "bool, string for input 1, got %d.",
          value->type);
      return kTfLiteError;
  }
#undef TF_LITE_FILL
  return kTfLiteOk;
}

}  // namespace fill

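// Returns the registration for the builtin FILL kernel; Prepare and Eval are
// defined above, and no per-node init/free state is needed.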
TfLiteRegistration* Register_FILL() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 fill::Prepare, fill::Eval};
  return &r;
}
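// Usage sketch (illustrative, not part of the original file): a client that
// builds its own resolver could wire this kernel in roughly as follows,
// assuming the standard MutableOpResolver::AddBuiltin API:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_FILL,
//                       tflite::ops::builtin::Register_FILL());
//
// The stock BuiltinOpResolver already registers FILL, so this is only needed
// for custom or trimmed-down resolvers.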

}  // namespace builtin
}  // namespace ops
}  // namespace tflite