1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace fill {
29
30 namespace {
31
// Input slots of the FILL node: the 1-D shape tensor, then the scalar value.
constexpr int kDimsTensor{0};
constexpr int kValueTensor{1};
// Output slot of the FILL node: the filled tensor.
constexpr int kOutputTensor{0};
35
36 template <typename T>
ResizeOutputImpl(TfLiteContext * context,const TfLiteTensor * dims,TfLiteTensor * output)37 TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims,
38 TfLiteTensor* output) {
39 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]);
40 for (int i = 0; i < output_shape->size; ++i) {
41 T data = GetTensorData<T>(dims)[i];
42 if (data < 0) {
43 TfLiteIntArrayFree(output_shape);
44 TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0", dims->type);
45 return kTfLiteError;
46 }
47 output_shape->data[i] = data;
48 }
49 return context->ResizeTensor(context, output, output_shape);
50 }
51
ResizeOutput(TfLiteContext * context,const TfLiteTensor * dims,TfLiteTensor * output)52 TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims,
53 TfLiteTensor* output) {
54 switch (dims->type) {
55 case kTfLiteInt32:
56 return ResizeOutputImpl<int32_t>(context, dims, output);
57 case kTfLiteInt64:
58 return ResizeOutputImpl<int64_t>(context, dims, output);
59 default:
60 TF_LITE_KERNEL_LOG(
61 context,
62 "Fill only currently supports int32, int64 for input 0, "
63 "got %d.",
64 dims->type);
65 return kTfLiteError;
66 }
67 }
68
69 } // namespace
70
Prepare(TfLiteContext * context,TfLiteNode * node)71 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
72 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
73 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
74
75 const TfLiteTensor* dims;
76 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
77 const TfLiteTensor* value;
78 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
79
80 // Make sure the 1st input tensor is 1-D.
81 TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);
82
83 // Make sure the 1st input tensor is int32 or int64.
84 const auto dtype = dims->type;
85 TF_LITE_ENSURE(context, dtype == kTfLiteInt32 || dtype == kTfLiteInt64);
86
87 // Make sure the 2nd input tensor is a scalar.
88 TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);
89
90 TfLiteTensor* output;
91 TF_LITE_ENSURE_OK(context,
92 GetOutputSafe(context, node, kOutputTensor, &output));
93 output->type = value->type;
94
95 TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale);
96 TF_LITE_ENSURE_EQ(context, output->params.zero_point,
97 value->params.zero_point);
98
99 if (value->type == kTfLiteInt16) {
100 TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0);
101 }
102
103 if (IsConstantTensor(dims)) {
104 TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
105 } else {
106 SetTensorToDynamic(output);
107 }
108 return kTfLiteOk;
109 }
110
FillString(const TfLiteTensor * value,TfLiteTensor * output)111 TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
112 DynamicBuffer buffer;
113 const auto string_ref = GetString(value, 0);
114 int n = 1;
115 for (int i = 0; i < output->dims->size; ++i) {
116 n *= output->dims->data[i];
117 }
118 for (int i = 0; i < n; ++i) {
119 buffer.AddString(string_ref.str, string_ref.len);
120 }
121 buffer.WriteToTensor(output, /*new_shape=*/nullptr);
122 return kTfLiteOk;
123 }
124
Eval(TfLiteContext * context,TfLiteNode * node)125 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
126 const TfLiteTensor* value;
127 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
128
129 TfLiteTensor* output;
130 TF_LITE_ENSURE_OK(context,
131 GetOutputSafe(context, node, kOutputTensor, &output));
132
133 if (IsDynamicTensor(output)) {
134 const TfLiteTensor* dims;
135 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
136 TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
137 }
138 #define TF_LITE_FILL(data_type) \
139 reference_ops::Fill(GetTensorShape(value), GetTensorData<data_type>(value), \
140 GetTensorShape(output), \
141 GetTensorData<data_type>(output))
142 switch (output->type) {
143 case kTfLiteInt8:
144 TF_LITE_FILL(int8_t);
145 break;
146 case kTfLiteInt16:
147 TF_LITE_FILL(int16_t);
148 break;
149 case kTfLiteInt32:
150 TF_LITE_FILL(int32_t);
151 break;
152 case kTfLiteInt64:
153 TF_LITE_FILL(int64_t);
154 break;
155 case kTfLiteFloat32:
156 TF_LITE_FILL(float);
157 break;
158 case kTfLiteBool:
159 TF_LITE_FILL(bool);
160 break;
161 case kTfLiteString:
162 FillString(value, output);
163 break;
164 default:
165 TF_LITE_KERNEL_LOG(
166 context,
167 "Fill only currently supports int8, int16, int32, int64, float32, "
168 "bool, string for input 1, got %d.",
169 value->type);
170 return kTfLiteError;
171 }
172 #undef TF_LITE_FILL
173 return kTfLiteOk;
174 }
175
176 } // namespace fill
177
Register_FILL()178 TfLiteRegistration* Register_FILL() {
179 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
180 fill::Prepare, fill::Eval};
181 return &r;
182 }
183
184 } // namespace builtin
185 } // namespace ops
186 } // namespace tflite
187