• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/flex/test_util.h"
17 
18 #include "absl/memory/memory.h"
19 #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
20 #include "tensorflow/lite/string.h"
21 
22 namespace tflite {
23 namespace flex {
24 namespace testing {
25 
Invoke()26 bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
27 
SetStringValues(int tensor_index,const std::vector<string> & values)28 void FlexModelTest::SetStringValues(int tensor_index,
29                                     const std::vector<string>& values) {
30   DynamicBuffer dynamic_buffer;
31   for (const string& s : values) {
32     dynamic_buffer.AddString(s.data(), s.size());
33   }
34   dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
35                                /*new_shape=*/nullptr);
36 }
37 
GetStringValues(int tensor_index) const38 std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
39   std::vector<string> result;
40 
41   TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
42   auto num_strings = GetStringCount(tensor->data.raw);
43   for (size_t i = 0; i < num_strings; ++i) {
44     auto ref = GetString(tensor->data.raw, i);
45     result.push_back(string(ref.str, ref.len));
46   }
47 
48   return result;
49 }
50 
SetShape(int tensor_index,const std::vector<int> & values)51 void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
52   ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
53   ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54 }
55 
GetShape(int tensor_index)56 std::vector<int> FlexModelTest::GetShape(int tensor_index) {
57   std::vector<int> result;
58   auto* dims = interpreter_->tensor(tensor_index)->dims;
59   result.reserve(dims->size);
60   for (int i = 0; i < dims->size; ++i) {
61     result.push_back(dims->data[i]);
62   }
63   return result;
64 }
65 
GetType(int tensor_index)66 TfLiteType FlexModelTest::GetType(int tensor_index) {
67   return interpreter_->tensor(tensor_index)->type;
68 }
69 
AddTensors(int num_tensors,const std::vector<int> & inputs,const std::vector<int> & outputs,TfLiteType type,const std::vector<int> & dims)70 void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
71                                const std::vector<int>& outputs, TfLiteType type,
72                                const std::vector<int>& dims) {
73   interpreter_->AddTensors(num_tensors);
74   for (int i = 0; i < num_tensors; ++i) {
75     TfLiteQuantizationParams quant;
76     // Suppress explicit output type specification to ensure type inference
77     // works properly.
78     if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
79       type = kTfLiteFloat32;
80     }
81     CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
82                                                         /*name=*/"",
83                                                         /*dims=*/dims, quant),
84              kTfLiteOk);
85   }
86 
87   CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
88   CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
89 }
90 
AddTfLiteMulOp(const std::vector<int> & inputs,const std::vector<int> & outputs)91 void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
92                                    const std::vector<int>& outputs) {
93   ++next_op_index_;
94 
95   static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
96   reg.builtin_code = BuiltinOperator_MUL;
97   reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
98     auto* i0 = &context->tensors[node->inputs->data[0]];
99     auto* o = &context->tensors[node->outputs->data[0]];
100     return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
101   };
102   reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
103     auto* i0 = &context->tensors[node->inputs->data[0]];
104     auto* i1 = &context->tensors[node->inputs->data[1]];
105     auto* o = &context->tensors[node->outputs->data[0]];
106     for (int i = 0; i < o->bytes / sizeof(float); ++i) {
107       o->data.f[i] = i0->data.f[i] * i1->data.f[i];
108     }
109     return kTfLiteOk;
110   };
111 
112   CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
113                                                nullptr, &reg),
114            kTfLiteOk);
115 }
116 
AddTfOp(TfOpType op,const std::vector<int> & inputs,const std::vector<int> & outputs)117 void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
118                             const std::vector<int>& outputs) {
119   tf_ops_.push_back(next_op_index_);
120   ++next_op_index_;
121 
122   auto attr = [](const string& key, const string& value) {
123     return " attr{ key: '" + key + "' value {" + value + "}}";
124   };
125 
126   string type_attribute;
127   switch (interpreter_->tensor(inputs[0])->type) {
128     case kTfLiteInt32:
129       type_attribute = attr("T", "type: DT_INT32");
130       break;
131     case kTfLiteFloat32:
132       type_attribute = attr("T", "type: DT_FLOAT");
133       break;
134     case kTfLiteString:
135       type_attribute = attr("T", "type: DT_STRING");
136       break;
137     default:
138       // TODO(b/113613439): Use nodedef string utilities to properly handle all
139       // types.
140       LOG(FATAL) << "Type not supported";
141       break;
142   }
143 
144   if (op == kUnpack) {
145     string attributes =
146         type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
147     AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
148   } else if (op == kIdentity) {
149     string attributes = type_attribute;
150     AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
151   } else if (op == kAdd) {
152     string attributes = type_attribute;
153     AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
154   } else if (op == kMul) {
155     string attributes = type_attribute;
156     AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
157   } else if (op == kNonExistent) {
158     AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
159   } else if (op == kIncompatibleNodeDef) {
160     // "Cast" op is created without attributes - making it incompatible.
161     AddTfOp("FlexCast", "Cast", "", inputs, outputs);
162   }
163 }
164 
AddTfOp(const char * tflite_name,const string & tf_name,const string & nodedef_str,const std::vector<int> & inputs,const std::vector<int> & outputs)165 void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
166                             const string& nodedef_str,
167                             const std::vector<int>& inputs,
168                             const std::vector<int>& outputs) {
169   static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
170   reg.builtin_code = BuiltinOperator_CUSTOM;
171   reg.custom_name = tflite_name;
172 
173   tensorflow::NodeDef nodedef;
174   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
175       nodedef_str + " op: '" + tf_name + "'", &nodedef));
176   string serialized_nodedef;
177   CHECK(nodedef.SerializeToString(&serialized_nodedef));
178   flexbuffers::Builder fbb;
179   fbb.Vector([&]() {
180     fbb.String(nodedef.op());
181     fbb.String(serialized_nodedef);
182   });
183   fbb.Finish();
184 
185   flexbuffers_.push_back(fbb.GetBuffer());
186   auto& buffer = flexbuffers_.back();
187   CHECK_EQ(interpreter_->AddNodeWithParameters(
188                inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
189                buffer.size(), nullptr, &reg),
190            kTfLiteOk);
191 }
192 
193 }  // namespace testing
194 }  // namespace flex
195 }  // namespace tflite
196