/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/flex/test_util.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace flex {
namespace testing {

bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }

void FlexModelTest::SetStringValues(int tensor_index,
                                    const std::vector<string>& values) {
  DynamicBuffer dynamic_buffer;
  for (const string& s : values) {
    dynamic_buffer.AddString(s.data(), s.size());
  }
  dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
                               /*new_shape=*/nullptr);
}

std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
  std::vector<string> result;

  TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
  auto num_strings = GetStringCount(tensor);
  for (size_t i = 0; i < num_strings; ++i) {
    auto ref = GetString(tensor, i);
    result.push_back(string(ref.str, ref.len));
  }

  return result;
}
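// Round-tripping string values through the two helpers above (a sketch, not
// part of the upstream file; the tensor indices are hypothetical and depend
// on the model the concrete test builds):
//
//   SetStringValues(/*tensor_index=*/0, {"foo", "bar"});
//   ASSERT_TRUE(Invoke());
//   EXPECT_EQ(GetStringValues(/*tensor_index=*/1),
//             (std::vector<string>{"foo", "bar"}));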

void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
  ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}

std::vector<int> FlexModelTest::GetShape(int tensor_index) {
  std::vector<int> result;
  auto* dims = interpreter_->tensor(tensor_index)->dims;
  result.reserve(dims->size);
  for (int i = 0; i < dims->size; ++i) {
    result.push_back(dims->data[i]);
  }
  return result;
}
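// Resizing an input and reading its shape back (sketch; tensor 0 is a
// hypothetical input index):
//
//   SetShape(0, {2, 2});
//   EXPECT_EQ(GetShape(0), (std::vector<int>{2, 2}));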

TfLiteType FlexModelTest::GetType(int tensor_index) {
  return interpreter_->tensor(tensor_index)->type;
}

bool FlexModelTest::IsDynamicTensor(int tensor_index) {
  return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
}

void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
                               const std::vector<int>& outputs, TfLiteType type,
                               const std::vector<int>& dims) {
  interpreter_->AddTensors(num_tensors);
  for (int i = 0; i < num_tensors; ++i) {
    TfLiteQuantizationParams quant;
    // Output tensors are forced to float32, overriding any explicitly
    // requested type, so that these tests exercise type inference.
    if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
      type = kTfLiteFloat32;
    }
    CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
                                                        /*name=*/"",
                                                        /*dims=*/dims, quant),
             kTfLiteOk);
  }

  CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
  CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}
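// Typical use (sketch; indices are hypothetical): a three-tensor model where
// tensors 0 and 1 are inputs and tensor 2 is the output. Note the float32
// override described above:
//
//   AddTensors(3, /*inputs=*/{0, 1}, /*outputs=*/{2}, kTfLiteInt32, {8});
//   // Tensors 0 and 1 are int32; tensor 2 is float32 due to the override.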

void FlexModelTest::SetConstTensor(int tensor_index,
                                   const std::vector<int>& values,
                                   TfLiteType type, const char* buffer,
                                   size_t bytes) {
  TfLiteQuantizationParams quant;
  CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
                                                     /*name=*/"",
                                                     /*dims=*/values, quant,
                                                     buffer, bytes),
           kTfLiteOk);
}
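// Marking a tensor read-only (sketch; index and contents are hypothetical).
// SetTensorParametersReadOnly keeps a pointer to the buffer rather than
// copying it, so the data must outlive the interpreter:
//
//   static const float kWeights[] = {1.0f, 2.0f};
//   SetConstTensor(/*tensor_index=*/1, /*values=*/{2}, kTfLiteFloat32,
//                  reinterpret_cast<const char*>(kWeights), sizeof(kWeights));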

void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  ++next_op_index_;

  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_MUL;
  // Minimal float32-only MUL kernel: the output is resized to the shape of
  // the first input, then filled with the element-wise product.
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* o = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* i1 = &context->tensors[node->inputs->data[1]];
    auto* o = &context->tensors[node->outputs->data[0]];
    for (size_t i = 0; i < o->bytes / sizeof(float); ++i) {
      o->data.f[i] = i0->data.f[i] * i1->data.f[i];
    }
    return kTfLiteOk;
  };

  CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
                                               nullptr, &reg),
           kTfLiteOk);
}
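// Building a tiny graph around the stub MUL kernel above (sketch; tensor
// indices are hypothetical):
//
//   AddTensors(3, /*inputs=*/{0, 1}, /*outputs=*/{2}, kTfLiteFloat32, {8});
//   AddTfLiteMulOp(/*inputs=*/{0, 1}, /*outputs=*/{2});
//   SetShape(0, {8});  // SetShape also (re)allocates tensors.
//   SetShape(1, {8});
//   ASSERT_TRUE(Invoke());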

void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  tf_ops_.push_back(next_op_index_);
  ++next_op_index_;

  auto attr = [](const string& key, const string& value) {
    return " attr{ key: '" + key + "' value {" + value + "}}";
  };

  string type_attribute;
  switch (interpreter_->tensor(inputs[0])->type) {
    case kTfLiteInt32:
      type_attribute = attr("T", "type: DT_INT32");
      break;
    case kTfLiteFloat32:
      type_attribute = attr("T", "type: DT_FLOAT");
      break;
    case kTfLiteString:
      type_attribute = attr("T", "type: DT_STRING");
      break;
    case kTfLiteBool:
      type_attribute = attr("T", "type: DT_BOOL");
      break;
    default:
      // TODO(b/113613439): Use nodedef string utilities to properly handle all
      // types.
      LOG(FATAL) << "Type not supported";
      break;
  }

  if (op == kUnpack) {
    string attributes =
        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
    AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
  } else if (op == kIdentity) {
    string attributes = type_attribute;
    AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
  } else if (op == kAdd) {
    string attributes = type_attribute;
    AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kRfft) {
    AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
  } else if (op == kImag) {
    AddTfOp("FlexImag", "Imag", "", inputs, outputs);
  } else if (op == kLoopCond) {
    string attributes = type_attribute;
    AddTfOp("FlexLoopCond", "LoopCond", attributes, inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
    // The "Cast" op is created without any attributes, which makes its
    // NodeDef incompatible: Cast requires the SrcT and DstT type attributes.
    AddTfOp("FlexCast", "Cast", "", inputs, outputs);
  }
}
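// Adding a TF op to be claimed by the flex delegate (sketch; tensor indices
// are hypothetical):
//
//   AddTensors(2, /*inputs=*/{0}, /*outputs=*/{1}, kTfLiteFloat32, {8});
//   AddTfOp(kIdentity, /*inputs=*/{0}, /*outputs=*/{1});
//
// The node index is recorded in tf_ops_, which tests can use to check which
// nodes ended up delegated.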

void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
                            const string& nodedef_str,
                            const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_CUSTOM;
  reg.custom_name = tflite_name;

  tensorflow::NodeDef nodedef;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      nodedef_str + " op: '" + tf_name + "'", &nodedef));
  string serialized_nodedef;
  CHECK(nodedef.SerializeToString(&serialized_nodedef));
  // The custom op's initial data is a flexbuffer vector holding the op name
  // followed by the serialized NodeDef, which is the payload layout the flex
  // delegate parses.
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();

  flexbuffers_.push_back(fbb.GetBuffer());
  auto& buffer = flexbuffers_.back();
  CHECK_EQ(interpreter_->AddNodeWithParameters(
               inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
               buffer.size(), nullptr, &reg),
           kTfLiteOk);
}
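// For reference, the payload built above can be decoded with the flexbuffers
// reader (sketch; `buffer` stands in for one entry of flexbuffers_):
//
//   auto root = flexbuffers::GetRoot(buffer.data(), buffer.size()).AsVector();
//   string op_name = root[0].AsString().str();
//   string nodedef_bytes = root[1].AsString().str();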

}  // namespace testing
}  // namespace flex
}  // namespace tflite