• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "TestUtils.hpp"
9 
10 #include <armnn_delegate.hpp>
11 
12 #include <flatbuffers/flatbuffers.h>
13 #include <tensorflow/lite/interpreter.h>
14 #include <tensorflow/lite/kernels/register.h>
15 #include <tensorflow/lite/model.h>
16 #include <tensorflow/lite/schema/schema_generated.h>
17 #include <tensorflow/lite/version.h>
18 
19 #include <doctest/doctest.h>
20 
21 namespace
22 {
23 
CreateComparisonTfLiteModel(tflite::BuiltinOperator comparisonOperatorCode,tflite::TensorType tensorType,const std::vector<int32_t> & input0TensorShape,const std::vector<int32_t> & input1TensorShape,const std::vector<int32_t> & outputTensorShape,float quantScale=1.0f,int quantOffset=0)24 std::vector<char> CreateComparisonTfLiteModel(tflite::BuiltinOperator comparisonOperatorCode,
25                                               tflite::TensorType tensorType,
26                                               const std::vector <int32_t>& input0TensorShape,
27                                               const std::vector <int32_t>& input1TensorShape,
28                                               const std::vector <int32_t>& outputTensorShape,
29                                               float quantScale = 1.0f,
30                                               int quantOffset  = 0)
31 {
32     using namespace tflite;
33     flatbuffers::FlatBufferBuilder flatBufferBuilder;
34 
35     std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
36     buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
37 
38     auto quantizationParameters =
39         CreateQuantizationParameters(flatBufferBuilder,
40                                      0,
41                                      0,
42                                      flatBufferBuilder.CreateVector<float>({ quantScale }),
43                                      flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
44 
45     std::array<flatbuffers::Offset<Tensor>, 3> tensors;
46     tensors[0] = CreateTensor(flatBufferBuilder,
47                               flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
48                                                                       input0TensorShape.size()),
49                               tensorType,
50                               0,
51                               flatBufferBuilder.CreateString("input_0"),
52                               quantizationParameters);
53     tensors[1] = CreateTensor(flatBufferBuilder,
54                               flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
55                                                                       input1TensorShape.size()),
56                               tensorType,
57                               0,
58                               flatBufferBuilder.CreateString("input_1"),
59                               quantizationParameters);
60     tensors[2] = CreateTensor(flatBufferBuilder,
61                               flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
62                                                                       outputTensorShape.size()),
63                               ::tflite::TensorType_BOOL,
64                               0);
65 
66     // create operator
67     tflite::BuiltinOptions operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;;
68     flatbuffers::Offset<void> operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
69     switch (comparisonOperatorCode)
70     {
71         case BuiltinOperator_EQUAL:
72         {
73             operatorBuiltinOptionsType = BuiltinOptions_EqualOptions;
74             operatorBuiltinOptions = CreateEqualOptions(flatBufferBuilder).Union();
75             break;
76         }
77         case BuiltinOperator_NOT_EQUAL:
78         {
79             operatorBuiltinOptionsType = BuiltinOptions_NotEqualOptions;
80             operatorBuiltinOptions = CreateNotEqualOptions(flatBufferBuilder).Union();
81             break;
82         }
83         case BuiltinOperator_GREATER:
84         {
85             operatorBuiltinOptionsType = BuiltinOptions_GreaterOptions;
86             operatorBuiltinOptions = CreateGreaterOptions(flatBufferBuilder).Union();
87             break;
88         }
89         case BuiltinOperator_GREATER_EQUAL:
90         {
91             operatorBuiltinOptionsType = BuiltinOptions_GreaterEqualOptions;
92             operatorBuiltinOptions = CreateGreaterEqualOptions(flatBufferBuilder).Union();
93             break;
94         }
95         case BuiltinOperator_LESS:
96         {
97             operatorBuiltinOptionsType = BuiltinOptions_LessOptions;
98             operatorBuiltinOptions = CreateLessOptions(flatBufferBuilder).Union();
99             break;
100         }
101         case BuiltinOperator_LESS_EQUAL:
102         {
103             operatorBuiltinOptionsType = BuiltinOptions_LessEqualOptions;
104             operatorBuiltinOptions = CreateLessEqualOptions(flatBufferBuilder).Union();
105             break;
106         }
107         default:
108             break;
109     }
110     const std::vector<int32_t> operatorInputs{ {0, 1} };
111     const std::vector<int32_t> operatorOutputs{{2}};
112     flatbuffers::Offset <Operator> comparisonOperator =
113         CreateOperator(flatBufferBuilder,
114                        0,
115                        flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
116                        flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
117                        operatorBuiltinOptionsType,
118                        operatorBuiltinOptions);
119 
120     const std::vector<int> subgraphInputs{ {0, 1} };
121     const std::vector<int> subgraphOutputs{{2}};
122     flatbuffers::Offset <SubGraph> subgraph =
123         CreateSubGraph(flatBufferBuilder,
124                        flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
125                        flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
126                        flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
127                        flatBufferBuilder.CreateVector(&comparisonOperator, 1));
128 
129     flatbuffers::Offset <flatbuffers::String> modelDescription =
130         flatBufferBuilder.CreateString("ArmnnDelegate: Comparison Operator Model");
131     flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, comparisonOperatorCode);
132 
133     flatbuffers::Offset <Model> flatbufferModel =
134         CreateModel(flatBufferBuilder,
135                     TFLITE_SCHEMA_VERSION,
136                     flatBufferBuilder.CreateVector(&operatorCode, 1),
137                     flatBufferBuilder.CreateVector(&subgraph, 1),
138                     modelDescription,
139                     flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
140 
141     flatBufferBuilder.Finish(flatbufferModel);
142 
143     return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
144                              flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
145 }
146 
147 template <typename T>
ComparisonTest(tflite::BuiltinOperator comparisonOperatorCode,tflite::TensorType tensorType,std::vector<armnn::BackendId> & backends,std::vector<int32_t> & input0Shape,std::vector<int32_t> & input1Shape,std::vector<int32_t> & outputShape,std::vector<T> & input0Values,std::vector<T> & input1Values,std::vector<bool> & expectedOutputValues,float quantScale=1.0f,int quantOffset=0)148 void ComparisonTest(tflite::BuiltinOperator comparisonOperatorCode,
149                     tflite::TensorType tensorType,
150                     std::vector<armnn::BackendId>& backends,
151                     std::vector<int32_t>& input0Shape,
152                     std::vector<int32_t>& input1Shape,
153                     std::vector<int32_t>& outputShape,
154                     std::vector<T>& input0Values,
155                     std::vector<T>& input1Values,
156                     std::vector<bool>& expectedOutputValues,
157                     float quantScale = 1.0f,
158                     int quantOffset  = 0)
159 {
160     using namespace tflite;
161     std::vector<char> modelBuffer = CreateComparisonTfLiteModel(comparisonOperatorCode,
162                                                                 tensorType,
163                                                                 input0Shape,
164                                                                 input1Shape,
165                                                                 outputShape,
166                                                                 quantScale,
167                                                                 quantOffset);
168 
169     const Model* tfLiteModel = GetModel(modelBuffer.data());
170     // Create TfLite Interpreters
171     std::unique_ptr<Interpreter> armnnDelegateInterpreter;
172     CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
173               (&armnnDelegateInterpreter) == kTfLiteOk);
174     CHECK(armnnDelegateInterpreter != nullptr);
175     CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
176 
177     std::unique_ptr<Interpreter> tfLiteInterpreter;
178     CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
179               (&tfLiteInterpreter) == kTfLiteOk);
180     CHECK(tfLiteInterpreter != nullptr);
181     CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
182 
183     // Create the ArmNN Delegate
184     armnnDelegate::DelegateOptions delegateOptions(backends);
185     std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
186         theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
187                          armnnDelegate::TfLiteArmnnDelegateDelete);
188     CHECK(theArmnnDelegate != nullptr);
189     // Modify armnnDelegateInterpreter to use armnnDelegate
190     CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
191 
192     // Set input data
193     auto tfLiteDelegateInput0Id = tfLiteInterpreter->inputs()[0];
194     auto tfLiteDelageInput0Data = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInput0Id);
195     for (unsigned int i = 0; i < input0Values.size(); ++i)
196     {
197         tfLiteDelageInput0Data[i] = input0Values[i];
198     }
199 
200     auto tfLiteDelegateInput1Id = tfLiteInterpreter->inputs()[1];
201     auto tfLiteDelageInput1Data = tfLiteInterpreter->typed_tensor<T>(tfLiteDelegateInput1Id);
202     for (unsigned int i = 0; i < input1Values.size(); ++i)
203     {
204         tfLiteDelageInput1Data[i] = input1Values[i];
205     }
206 
207     auto armnnDelegateInput0Id = armnnDelegateInterpreter->inputs()[0];
208     auto armnnDelegateInput0Data = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInput0Id);
209     for (unsigned int i = 0; i < input0Values.size(); ++i)
210     {
211         armnnDelegateInput0Data[i] = input0Values[i];
212     }
213 
214     auto armnnDelegateInput1Id = armnnDelegateInterpreter->inputs()[1];
215     auto armnnDelegateInput1Data = armnnDelegateInterpreter->typed_tensor<T>(armnnDelegateInput1Id);
216     for (unsigned int i = 0; i < input1Values.size(); ++i)
217     {
218         armnnDelegateInput1Data[i] = input1Values[i];
219     }
220 
221     // Run EnqueWorkload
222     CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
223     CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
224     // Compare output data
225     auto tfLiteDelegateOutputId = tfLiteInterpreter->outputs()[0];
226     auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<bool>(tfLiteDelegateOutputId);
227     auto armnnDelegateOutputId = armnnDelegateInterpreter->outputs()[0];
228     auto armnnDelegateOutputData = armnnDelegateInterpreter->typed_tensor<bool>(armnnDelegateOutputId);
229 
230     armnnDelegate::CompareData(expectedOutputValues  , armnnDelegateOutputData, expectedOutputValues.size());
231     armnnDelegate::CompareData(expectedOutputValues  , tfLiteDelageOutputData , expectedOutputValues.size());
232     armnnDelegate::CompareData(tfLiteDelageOutputData, armnnDelegateOutputData, expectedOutputValues.size());
233 }
234 
235 } // anonymous namespace