/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/flex/test_util.h"

#include <algorithm>  // for std::find, used in AddTensors().

#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace flex {
namespace testing {
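
// Runs the interpreter on the current inputs; returns true iff invocation
// returned kTfLiteOk.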
bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
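
// Serializes `values` into TF Lite's dynamic string format and writes the
// result into the tensor at `tensor_index`, keeping its current shape.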
void FlexModelTest::SetStringValues(int tensor_index,
                                    const std::vector<string>& values) {
  DynamicBuffer dynamic_buffer;
  for (const string& s : values) {
    dynamic_buffer.AddString(s.data(), s.size());
  }
  dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
                               /*new_shape=*/nullptr);
}
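
// Returns every string stored in the tensor at `tensor_index`, in order.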
std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
  std::vector<string> result;

  TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
  auto num_strings = GetStringCount(tensor);
  for (size_t i = 0; i < num_strings; ++i) {
    auto ref = GetString(tensor, i);
    result.push_back(string(ref.str, ref.len));
  }

  return result;
}
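
// Resizes the input tensor at `tensor_index` to `values` and reallocates all
// tensors; fails the test if either step does not return kTfLiteOk.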
void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
  ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}
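
// Returns the dimensions of the tensor at `tensor_index` as a vector.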
std::vector<int> FlexModelTest::GetShape(int tensor_index) {
  std::vector<int> result;
  auto* dims = interpreter_->tensor(tensor_index)->dims;
  result.reserve(dims->size);
  for (int i = 0; i < dims->size; ++i) {
    result.push_back(dims->data[i]);
  }
  return result;
}
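
// Returns the TfLiteType of the tensor at `tensor_index`.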
TfLiteType FlexModelTest::GetType(int tensor_index) {
  return interpreter_->tensor(tensor_index)->type;
}
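
// Returns true iff the tensor at `tensor_index` is dynamically allocated.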
bool FlexModelTest::IsDynamicTensor(int tensor_index) {
  return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
}
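
// Adds `num_tensors` read-write tensors of `type` with shape `dims` to the
// interpreter and registers `inputs`/`outputs` as the graph's I/O. Tensors
// listed in `outputs` are always registered as kTfLiteFloat32 (see below).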
void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
                               const std::vector<int>& outputs, TfLiteType type,
                               const std::vector<int>& dims) {
  interpreter_->AddTensors(num_tensors);
  for (int i = 0; i < num_tensors; ++i) {
    // Zero-initialized; these tests do not exercise quantization.
    TfLiteQuantizationParams quant = {};
    // Suppress explicit output type specification to ensure type inference
    // works properly.
    if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
      type = kTfLiteFloat32;
    }
    CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
                                                        /*name=*/"",
                                                        /*dims=*/dims, quant),
             kTfLiteOk);
  }

  CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
  CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}
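
// Registers the tensor at `tensor_index` as a read-only constant of `type`
// with shape `values`, backed by `buffer` of length `bytes`.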
void FlexModelTest::SetConstTensor(int tensor_index,
                                   const std::vector<int>& values,
                                   TfLiteType type, const char* buffer,
                                   size_t bytes) {
  // Zero-initialized; these tests do not exercise quantization.
  TfLiteQuantizationParams quant = {};
  CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
                                                     /*name=*/"",
                                                     /*dims=*/values, quant,
                                                     buffer, bytes),
           kTfLiteOk);
}
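
// Appends a builtin MUL node that multiplies two float tensors elementwise.
// The registration is function-local static so it outlives the interpreter,
// which only keeps a pointer to it.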
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  ++next_op_index_;

  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_MUL;
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    // The output takes the shape of the first input.
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* o = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* i1 = &context->tensors[node->inputs->data[1]];
    auto* o = &context->tensors[node->outputs->data[0]];
    for (size_t i = 0; i < o->bytes / sizeof(float); ++i) {
      o->data.f[i] = i0->data.f[i] * i1->data.f[i];
    }
    return kTfLiteOk;
  };

  CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
                                               nullptr, &reg),
           kTfLiteOk);
}
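
// Appends a Flex (TF) op of kind `op`, deriving a "T" type attribute from the
// first input tensor where the op needs one, and records the node's index in
// tf_ops_ so tests can verify which nodes were delegated.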
void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  tf_ops_.push_back(next_op_index_);
  ++next_op_index_;

  auto attr = [](const string& key, const string& value) {
    return " attr{ key: '" + key + "' value {" + value + "}}";
  };

  string type_attribute;
  switch (interpreter_->tensor(inputs[0])->type) {
    case kTfLiteInt32:
      type_attribute = attr("T", "type: DT_INT32");
      break;
    case kTfLiteFloat32:
      type_attribute = attr("T", "type: DT_FLOAT");
      break;
    case kTfLiteString:
      type_attribute = attr("T", "type: DT_STRING");
      break;
    default:
      // TODO(b/113613439): Use nodedef string utilities to properly handle all
      // types.
      LOG(FATAL) << "Type not supported";
      break;
  }

  if (op == kUnpack) {
    string attributes =
        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
    AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
  } else if (op == kIdentity) {
    string attributes = type_attribute;
    AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
  } else if (op == kAdd) {
    string attributes = type_attribute;
    AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kRfft) {
    AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
  } else if (op == kImag) {
    AddTfOp("FlexImag", "Imag", "", inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
    // The "Cast" op is created without attributes, making its NodeDef
    // incompatible.
    AddTfOp("FlexCast", "Cast", "", inputs, outputs);
  }
}
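
// Appends a custom node named `tflite_name` whose custom options carry a
// flexbuffer vector of [op name, serialized tensorflow::NodeDef], the layout
// the flex delegate parses. The NodeDef is built from the textual attributes
// in `nodedef_str` plus the `tf_name` op field.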
void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
                            const string& nodedef_str,
                            const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_CUSTOM;
  reg.custom_name = tflite_name;

  tensorflow::NodeDef nodedef;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      nodedef_str + " op: '" + tf_name + "'", &nodedef));
  string serialized_nodedef;
  CHECK(nodedef.SerializeToString(&serialized_nodedef));
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();

  // Store the buffer in a member so it stays valid for the lifetime of the
  // interpreter, which keeps a pointer into it.
  flexbuffers_.push_back(fbb.GetBuffer());
  auto& buffer = flexbuffers_.back();
  CHECK_EQ(interpreter_->AddNodeWithParameters(
               inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
               buffer.size(), nullptr, &reg),
           kTfLiteOk);
}

}  // namespace testing
}  // namespace flex
}  // namespace tflite