1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/delegates/flex/test_util.h"

#include <algorithm>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/string_type.h"
21
22 namespace tflite {
23 namespace flex {
24 namespace testing {
25
Invoke()26 bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
27
SetStringValues(int tensor_index,const std::vector<string> & values)28 void FlexModelTest::SetStringValues(int tensor_index,
29 const std::vector<string>& values) {
30 DynamicBuffer dynamic_buffer;
31 for (const string& s : values) {
32 dynamic_buffer.AddString(s.data(), s.size());
33 }
34 dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
35 /*new_shape=*/nullptr);
36 }
37
GetStringValues(int tensor_index) const38 std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
39 std::vector<string> result;
40
41 TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
42 auto num_strings = GetStringCount(tensor);
43 for (size_t i = 0; i < num_strings; ++i) {
44 auto ref = GetString(tensor, i);
45 result.push_back(string(ref.str, ref.len));
46 }
47
48 return result;
49 }
50
SetShape(int tensor_index,const std::vector<int> & values)51 void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
52 ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
53 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54 }
55
GetShape(int tensor_index)56 std::vector<int> FlexModelTest::GetShape(int tensor_index) {
57 std::vector<int> result;
58 auto* dims = interpreter_->tensor(tensor_index)->dims;
59 result.reserve(dims->size);
60 for (int i = 0; i < dims->size; ++i) {
61 result.push_back(dims->data[i]);
62 }
63 return result;
64 }
65
GetType(int tensor_index)66 TfLiteType FlexModelTest::GetType(int tensor_index) {
67 return interpreter_->tensor(tensor_index)->type;
68 }
69
IsDynamicTensor(int tensor_index)70 bool FlexModelTest::IsDynamicTensor(int tensor_index) {
71 return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
72 }
73
AddTensors(int num_tensors,const std::vector<int> & inputs,const std::vector<int> & outputs,TfLiteType type,const std::vector<int> & dims)74 void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
75 const std::vector<int>& outputs, TfLiteType type,
76 const std::vector<int>& dims) {
77 interpreter_->AddTensors(num_tensors);
78 for (int i = 0; i < num_tensors; ++i) {
79 TfLiteQuantizationParams quant;
80 // Suppress explicit output type specification to ensure type inference
81 // works properly.
82 if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
83 type = kTfLiteFloat32;
84 }
85 CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
86 /*name=*/"",
87 /*dims=*/dims, quant),
88 kTfLiteOk);
89 }
90
91 CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
92 CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
93 }
94
SetConstTensor(int tensor_index,const std::vector<int> & values,TfLiteType type,const char * buffer,size_t bytes)95 void FlexModelTest::SetConstTensor(int tensor_index,
96 const std::vector<int>& values,
97 TfLiteType type, const char* buffer,
98 size_t bytes) {
99 TfLiteQuantizationParams quant;
100 CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
101 /*name=*/"",
102 /*dims=*/values, quant,
103 buffer, bytes),
104 kTfLiteOk);
105 }
106
AddTfLiteMulOp(const std::vector<int> & inputs,const std::vector<int> & outputs)107 void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
108 const std::vector<int>& outputs) {
109 ++next_op_index_;
110
111 static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
112 reg.builtin_code = BuiltinOperator_MUL;
113 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
114 auto* i0 = &context->tensors[node->inputs->data[0]];
115 auto* o = &context->tensors[node->outputs->data[0]];
116 return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
117 };
118 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
119 auto* i0 = &context->tensors[node->inputs->data[0]];
120 auto* i1 = &context->tensors[node->inputs->data[1]];
121 auto* o = &context->tensors[node->outputs->data[0]];
122 for (int i = 0; i < o->bytes / sizeof(float); ++i) {
123 o->data.f[i] = i0->data.f[i] * i1->data.f[i];
124 }
125 return kTfLiteOk;
126 };
127
128 CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
129 nullptr, ®),
130 kTfLiteOk);
131 }
132
AddTfOp(TfOpType op,const std::vector<int> & inputs,const std::vector<int> & outputs)133 void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
134 const std::vector<int>& outputs) {
135 tf_ops_.push_back(next_op_index_);
136 ++next_op_index_;
137
138 auto attr = [](const string& key, const string& value) {
139 return " attr{ key: '" + key + "' value {" + value + "}}";
140 };
141
142 string type_attribute;
143 switch (interpreter_->tensor(inputs[0])->type) {
144 case kTfLiteInt32:
145 type_attribute = attr("T", "type: DT_INT32");
146 break;
147 case kTfLiteFloat32:
148 type_attribute = attr("T", "type: DT_FLOAT");
149 break;
150 case kTfLiteString:
151 type_attribute = attr("T", "type: DT_STRING");
152 break;
153 case kTfLiteBool:
154 type_attribute = attr("T", "type: DT_BOOL");
155 break;
156 default:
157 // TODO(b/113613439): Use nodedef string utilities to properly handle all
158 // types.
159 LOG(FATAL) << "Type not supported";
160 break;
161 }
162
163 if (op == kUnpack) {
164 string attributes =
165 type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
166 AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
167 } else if (op == kIdentity) {
168 string attributes = type_attribute;
169 AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
170 } else if (op == kAdd) {
171 string attributes = type_attribute;
172 AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
173 } else if (op == kMul) {
174 string attributes = type_attribute;
175 AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
176 } else if (op == kRfft) {
177 AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
178 } else if (op == kImag) {
179 AddTfOp("FlexImag", "Imag", "", inputs, outputs);
180 } else if (op == kLoopCond) {
181 string attributes = type_attribute;
182 AddTfOp("FlexLoopCond", "LoopCond", attributes, inputs, outputs);
183 } else if (op == kNonExistent) {
184 AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
185 } else if (op == kIncompatibleNodeDef) {
186 // "Cast" op is created without attributes - making it incompatible.
187 AddTfOp("FlexCast", "Cast", "", inputs, outputs);
188 }
189 }
190
AddTfOp(const char * tflite_name,const string & tf_name,const string & nodedef_str,const std::vector<int> & inputs,const std::vector<int> & outputs)191 void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
192 const string& nodedef_str,
193 const std::vector<int>& inputs,
194 const std::vector<int>& outputs) {
195 static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
196 reg.builtin_code = BuiltinOperator_CUSTOM;
197 reg.custom_name = tflite_name;
198
199 tensorflow::NodeDef nodedef;
200 CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
201 nodedef_str + " op: '" + tf_name + "'", &nodedef));
202 string serialized_nodedef;
203 CHECK(nodedef.SerializeToString(&serialized_nodedef));
204 flexbuffers::Builder fbb;
205 fbb.Vector([&]() {
206 fbb.String(nodedef.op());
207 fbb.String(serialized_nodedef);
208 });
209 fbb.Finish();
210
211 flexbuffers_.push_back(fbb.GetBuffer());
212 auto& buffer = flexbuffers_.back();
213 CHECK_EQ(interpreter_->AddNodeWithParameters(
214 inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
215 buffer.size(), nullptr, ®),
216 kTfLiteOk);
217 }
218
219 } // namespace testing
220 } // namespace flex
221 } // namespace tflite
222