/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/flex/test_util.h"

#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/string.h"

namespace tflite {
namespace flex {
namespace testing {

bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }

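// Packs |values| into the TF Lite string serialization format via
// DynamicBuffer and writes the result into the tensor at |tensor_index|,
// keeping the tensor's existing shape.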
void FlexModelTest::SetStringValues(int tensor_index,
                                    const std::vector<string>& values) {
  DynamicBuffer dynamic_buffer;
  for (const string& s : values) {
    dynamic_buffer.AddString(s.data(), s.size());
  }
  dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
                               /*new_shape=*/nullptr);
}

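// Reads every string stored in the tensor at |tensor_index| back into a
// std::vector.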
std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
  std::vector<string> result;

  TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
  auto num_strings = GetStringCount(tensor->data.raw);
  for (size_t i = 0; i < num_strings; ++i) {
    auto ref = GetString(tensor->data.raw, i);
    result.push_back(string(ref.str, ref.len));
  }

  return result;
}

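// Resizes the input tensor at |tensor_index| to |values| and reallocates all
// tensor buffers so the new shape takes effect before the next Invoke().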
void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
  ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}

std::vector<int> FlexModelTest::GetShape(int tensor_index) {
  std::vector<int> result;
  auto* dims = interpreter_->tensor(tensor_index)->dims;
  result.reserve(dims->size);
  for (int i = 0; i < dims->size; ++i) {
    result.push_back(dims->data[i]);
  }
  return result;
}

TfLiteType FlexModelTest::GetType(int tensor_index) {
  return interpreter_->tensor(tensor_index)->type;
}

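// Creates |num_tensors| read-write tensors and wires up the interpreter's
// inputs and outputs. Output tensors are forced to kTfLiteFloat32 so the test
// exercises type inference rather than relying on a declared output type.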
void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
                               const std::vector<int>& outputs, TfLiteType type,
                               const std::vector<int>& dims) {
  interpreter_->AddTensors(num_tensors);
  for (int i = 0; i < num_tensors; ++i) {
    TfLiteQuantizationParams quant;
    // Suppress explicit output type specification to ensure type inference
    // works properly.
    if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
      type = kTfLiteFloat32;
    }
    CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
                                                        /*name=*/"",
                                                        /*dims=*/dims, quant),
             kTfLiteOk);
  }

  CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
  CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}

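// Adds a builtin MUL node backed by a minimal inline kernel: prepare() copies
// the first input's shape to the output, and invoke() performs an element-wise
// float multiply, serving as a small stand-in for the real builtin kernel.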
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  ++next_op_index_;

  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_MUL;
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* o = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* i1 = &context->tensors[node->inputs->data[1]];
    auto* o = &context->tensors[node->outputs->data[0]];
    for (int i = 0; i < o->bytes / sizeof(float); ++i) {
      o->data.f[i] = i0->data.f[i] * i1->data.f[i];
    }
    return kTfLiteOk;
  };

  CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
                                               nullptr, &reg),
           kTfLiteOk);
}

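// Adds a Flex (TensorFlow) op node. The NodeDef attributes are assembled as a
// text proto, with the "T" type attribute derived from the first input tensor.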
void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  tf_ops_.push_back(next_op_index_);
  ++next_op_index_;

  auto attr = [](const string& key, const string& value) {
    return " attr{ key: '" + key + "' value {" + value + "}}";
  };

  string type_attribute;
  switch (interpreter_->tensor(inputs[0])->type) {
    case kTfLiteInt32:
      type_attribute = attr("T", "type: DT_INT32");
      break;
    case kTfLiteFloat32:
      type_attribute = attr("T", "type: DT_FLOAT");
      break;
    case kTfLiteString:
      type_attribute = attr("T", "type: DT_STRING");
      break;
    default:
      // TODO(b/113613439): Use nodedef string utilities to properly handle all
      // types.
      LOG(FATAL) << "Type not supported";
      break;
  }

  if (op == kUnpack) {
    string attributes =
        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
    AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
  } else if (op == kIdentity) {
    string attributes = type_attribute;
    AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
  } else if (op == kAdd) {
    string attributes = type_attribute;
    AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
    // The "Cast" op is created without attributes, which makes its NodeDef
    // incompatible.
    AddTfOp("FlexCast", "Cast", "", inputs, outputs);
  }
}

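// Low-level helper: registers |tflite_name| as a custom op whose options are
// a flexbuffer vector holding the TF op name and the serialized NodeDef, the
// format the flex delegate expects for custom op options.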
void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
                            const string& nodedef_str,
                            const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_CUSTOM;
  reg.custom_name = tflite_name;

  tensorflow::NodeDef nodedef;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      nodedef_str + " op: '" + tf_name + "'", &nodedef));
  string serialized_nodedef;
  CHECK(nodedef.SerializeToString(&serialized_nodedef));
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();

  flexbuffers_.push_back(fbb.GetBuffer());
  auto& buffer = flexbuffers_.back();
  CHECK_EQ(interpreter_->AddNodeWithParameters(
               inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
               buffer.size(), nullptr, &reg),
           kTfLiteOk);
}

}  // namespace testing
}  // namespace flex
}  // namespace tflite