1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/flex/test_util.h"
17
18 #include "absl/memory/memory.h"
19 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
20 #include "tensorflow/lite/string_type.h"
21
22 namespace tflite {
23 namespace flex {
24 namespace testing {
25
Invoke()26 bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
27
SetStringValues(int tensor_index,const std::vector<string> & values)28 void FlexModelTest::SetStringValues(int tensor_index,
29 const std::vector<string>& values) {
30 DynamicBuffer dynamic_buffer;
31 for (const string& s : values) {
32 dynamic_buffer.AddString(s.data(), s.size());
33 }
34 dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
35 /*new_shape=*/nullptr);
36 }
37
GetStringValues(int tensor_index) const38 std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
39 std::vector<string> result;
40
41 TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
42 auto num_strings = GetStringCount(tensor);
43 for (size_t i = 0; i < num_strings; ++i) {
44 auto ref = GetString(tensor, i);
45 result.push_back(string(ref.str, ref.len));
46 }
47
48 return result;
49 }
50
SetShape(int tensor_index,const std::vector<int> & values)51 void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
52 ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
53 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54 }
55
GetShape(int tensor_index)56 std::vector<int> FlexModelTest::GetShape(int tensor_index) {
57 std::vector<int> result;
58 auto* dims = interpreter_->tensor(tensor_index)->dims;
59 result.reserve(dims->size);
60 for (int i = 0; i < dims->size; ++i) {
61 result.push_back(dims->data[i]);
62 }
63 return result;
64 }
65
GetType(int tensor_index)66 TfLiteType FlexModelTest::GetType(int tensor_index) {
67 return interpreter_->tensor(tensor_index)->type;
68 }
69
IsDynamicTensor(int tensor_index)70 bool FlexModelTest::IsDynamicTensor(int tensor_index) {
71 return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
72 }
73
AddTensors(int num_tensors,const std::vector<int> & inputs,const std::vector<int> & outputs,TfLiteType type,const std::vector<int> & dims)74 void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
75 const std::vector<int>& outputs, TfLiteType type,
76 const std::vector<int>& dims) {
77 interpreter_->AddTensors(num_tensors);
78 for (int i = 0; i < num_tensors; ++i) {
79 TfLiteQuantizationParams quant;
80 // Suppress explicit output type specification to ensure type inference
81 // works properly.
82 if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
83 type = kTfLiteFloat32;
84 }
85 CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
86 /*name=*/"",
87 /*dims=*/dims, quant),
88 kTfLiteOk);
89 }
90
91 CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
92 CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
93 }
94
SetConstTensor(int tensor_index,const std::vector<int> & values,TfLiteType type,const char * buffer,size_t bytes)95 void FlexModelTest::SetConstTensor(int tensor_index,
96 const std::vector<int>& values,
97 TfLiteType type, const char* buffer,
98 size_t bytes) {
99 TfLiteQuantizationParams quant;
100 CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
101 /*name=*/"",
102 /*dims=*/values, quant,
103 buffer, bytes),
104 kTfLiteOk);
105 }
106
AddTfLiteMulOp(const std::vector<int> & inputs,const std::vector<int> & outputs)107 void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
108 const std::vector<int>& outputs) {
109 ++next_op_index_;
110
111 static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
112 reg.builtin_code = BuiltinOperator_MUL;
113 reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
114 auto* i0 = &context->tensors[node->inputs->data[0]];
115 auto* o = &context->tensors[node->outputs->data[0]];
116 return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
117 };
118 reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
119 auto* i0 = &context->tensors[node->inputs->data[0]];
120 auto* i1 = &context->tensors[node->inputs->data[1]];
121 auto* o = &context->tensors[node->outputs->data[0]];
122 for (int i = 0; i < o->bytes / sizeof(float); ++i) {
123 o->data.f[i] = i0->data.f[i] * i1->data.f[i];
124 }
125 return kTfLiteOk;
126 };
127
128 CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
129 nullptr, ®),
130 kTfLiteOk);
131 }
132
AddTfOp(TfOpType op,const std::vector<int> & inputs,const std::vector<int> & outputs)133 void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
134 const std::vector<int>& outputs) {
135 tf_ops_.push_back(next_op_index_);
136 ++next_op_index_;
137
138 auto attr = [](const string& key, const string& value) {
139 return " attr{ key: '" + key + "' value {" + value + "}}";
140 };
141
142 string type_attribute;
143 switch (interpreter_->tensor(inputs[0])->type) {
144 case kTfLiteInt32:
145 type_attribute = attr("T", "type: DT_INT32");
146 break;
147 case kTfLiteFloat32:
148 type_attribute = attr("T", "type: DT_FLOAT");
149 break;
150 case kTfLiteString:
151 type_attribute = attr("T", "type: DT_STRING");
152 break;
153 default:
154 // TODO(b/113613439): Use nodedef string utilities to properly handle all
155 // types.
156 LOG(FATAL) << "Type not supported";
157 break;
158 }
159
160 if (op == kUnpack) {
161 string attributes =
162 type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
163 AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
164 } else if (op == kIdentity) {
165 string attributes = type_attribute;
166 AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
167 } else if (op == kAdd) {
168 string attributes = type_attribute;
169 AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
170 } else if (op == kMul) {
171 string attributes = type_attribute;
172 AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
173 } else if (op == kRfft) {
174 AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
175 } else if (op == kImag) {
176 AddTfOp("FlexImag", "Imag", "", inputs, outputs);
177 } else if (op == kNonExistent) {
178 AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
179 } else if (op == kIncompatibleNodeDef) {
180 // "Cast" op is created without attributes - making it incompatible.
181 AddTfOp("FlexCast", "Cast", "", inputs, outputs);
182 }
183 }
184
AddTfOp(const char * tflite_name,const string & tf_name,const string & nodedef_str,const std::vector<int> & inputs,const std::vector<int> & outputs)185 void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
186 const string& nodedef_str,
187 const std::vector<int>& inputs,
188 const std::vector<int>& outputs) {
189 static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
190 reg.builtin_code = BuiltinOperator_CUSTOM;
191 reg.custom_name = tflite_name;
192
193 tensorflow::NodeDef nodedef;
194 CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
195 nodedef_str + " op: '" + tf_name + "'", &nodedef));
196 string serialized_nodedef;
197 CHECK(nodedef.SerializeToString(&serialized_nodedef));
198 flexbuffers::Builder fbb;
199 fbb.Vector([&]() {
200 fbb.String(nodedef.op());
201 fbb.String(serialized_nodedef);
202 });
203 fbb.Finish();
204
205 flexbuffers_.push_back(fbb.GetBuffer());
206 auto& buffer = flexbuffers_.back();
207 CHECK_EQ(interpreter_->AddNodeWithParameters(
208 inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
209 buffer.size(), nullptr, ®),
210 kTfLiteOk);
211 }
212
213 } // namespace testing
214 } // namespace flex
215 } // namespace tflite
216