/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

#include <sys/mman.h>
#include <unistd.h>

#include <initializer_list>
#include <memory>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
28 
29 namespace tflite {
30 namespace {
31 
32 using ::testing::ElementsAre;
33 using ::testing::ElementsAreArray;
34 using ::testing::FloatNear;
35 using ::testing::Matcher;
36 
37 // TODO(b/110368244): figure out how to share the existing tests in kernels/ but
38 // with the delegation on. Also, add more unit tests to improve code coverage.
39 
40 // This matcher uses 1 as maximum tolerance.
41 MATCHER(QuantizedNear, "") {
42   const int diff = abs(std::get<0>(arg) - std::get<1>(arg));
43   if (diff > 1) {
44     *result_listener << "Quantized values can be at most off by one: " << diff;
45     return false;
46   }
47   return true;
48 }
49 
NnapiArrayFloatNear(const std::vector<float> & values,bool relaxed=false)50 auto NnapiArrayFloatNear(const std::vector<float>& values,
51                          bool relaxed = false) {
52   // Uses the same tolerance as NNAPI generated tests.
53   const float atol = relaxed ? 5 * 0.0009765625f : 1e-5f;
54   const float rtol = relaxed ? 5 * 0.0009765625f : 5 * 1.1920928955078125e-7f;
55 
56   std::vector<Matcher<float>> matchers;
57   matchers.reserve(values.size());
58   for (const float& v : values) {
59     const float tolerance = atol + rtol * std::abs(v);
60     matchers.emplace_back(FloatNear(v, tolerance));
61   }
62   return ElementsAreArray(matchers);
63 }
64 
65 class SingleOpModelWithNNAPI : public SingleOpModel {
66  public:
SingleOpModelWithNNAPI()67   SingleOpModelWithNNAPI() {
68     options_.disallow_nnapi_cpu = false;
69     stateful_delegate_.reset(new StatefulNnApiDelegate(options_));
70     SetDelegate(stateful_delegate_.get());
71   }
72 
SingleOpModelWithNNAPI(const StatefulNnApiDelegate::Options & options)73   explicit SingleOpModelWithNNAPI(
74       const StatefulNnApiDelegate::Options& options) {
75     options_ = options;
76     options_.disallow_nnapi_cpu = false;
77     stateful_delegate_.reset(new StatefulNnApiDelegate(options_));
78     SetDelegate(stateful_delegate_.get());
79   }
80 
ResizeInputTensor(int tensor_index,const std::vector<int> & dims)81   TfLiteStatus ResizeInputTensor(int tensor_index,
82                                  const std::vector<int>& dims) {
83     return interpreter_->ResizeInputTensor(tensor_index, dims);
84   }
85 
GetDelegate()86   StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }
87 
SetBufferHandle(int index,TfLiteBufferHandle handle)88   void SetBufferHandle(int index, TfLiteBufferHandle handle) {
89     interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
90   }
91 
MarkInputTensorDataStale(int index)92   void MarkInputTensorDataStale(int index) {
93     interpreter_->tensor(index)->data_is_stale = true;
94   }
95 
AllocateTensors()96   TfLiteStatus AllocateTensors() { return interpreter_->AllocateTensors(); }
97 
98  protected:
SetData(int index,TensorType type,const std::vector<float> & data)99   void SetData(int index, TensorType type, const std::vector<float>& data) {
100     switch (type) {
101       case TensorType_FLOAT32:
102         PopulateTensor(index, data);
103         break;
104       case TensorType_INT32:
105         QuantizeAndPopulate<int32_t>(index, data);
106         break;
107       case TensorType_UINT8:
108         QuantizeAndPopulate<uint8_t>(index, data);
109         break;
110       case TensorType_INT8:
111         QuantizeAndPopulate<int8_t>(index, data);
112         break;
113       default:
114         FAIL() << "Type not supported: " << type;
115         break;
116     }
117   }
118 
GetData(int index,TensorType type,std::vector<float> * output)119   void GetData(int index, TensorType type, std::vector<float>* output) {
120     switch (type) {
121       case TensorType_FLOAT32:
122         *output = ExtractVector<float>(index);
123         break;
124       case TensorType_UINT8:
125         *output = Dequantize<uint8_t>(ExtractVector<uint8_t>(index),
126                                       GetScale(index), GetZeroPoint(index));
127         break;
128       default:
129         FAIL() << "Type not supported: " << type;
130         break;
131     }
132   }
133 
BuildInterpreterWithNNAPI(std::vector<std::vector<int>> input_shapes,bool allow_fp32_relax_to_fp16=false)134   void BuildInterpreterWithNNAPI(std::vector<std::vector<int>> input_shapes,
135                                  bool allow_fp32_relax_to_fp16 = false) {
136     // We skip those TfLite delegates that are applied by default in TfLite
137     // runtime by setting 'apply_delegate' to false. Afterwards, we explicitly
138     // call ApplyDelegate to apply the NNAPI delegate to meet the testing
139     // purpose.
140     BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
141                      /*apply_delegate=*/false, /*allocate_and_delegate=*/true);
142     ApplyDelegate();
143   }
144 
145  private:
146   // Stateful NNAPI delegate. This is valid only if the state-ful constructor is
147   // used.
148   StatefulNnApiDelegate::Options options_;
149   std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
150 };
151 
152 class FloatAddOpModel : public SingleOpModelWithNNAPI {
153  public:
FloatAddOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type,bool allow_fp32_relax_to_fp16=false)154   FloatAddOpModel(const TensorData& input1, const TensorData& input2,
155                   const TensorData& output,
156                   ActivationFunctionType activation_type,
157                   bool allow_fp32_relax_to_fp16 = false) {
158     Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
159   }
160 
FloatAddOpModel(const StatefulNnApiDelegate::Options & options,const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type,bool allow_fp32_relax_to_fp16=false)161   FloatAddOpModel(const StatefulNnApiDelegate::Options& options,
162                   const TensorData& input1, const TensorData& input2,
163                   const TensorData& output,
164                   ActivationFunctionType activation_type,
165                   bool allow_fp32_relax_to_fp16 = false)
166       : SingleOpModelWithNNAPI(options) {
167     Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
168   }
169 
input1()170   int input1() { return input1_; }
input2()171   int input2() { return input2_; }
172 
GetOutput()173   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
174 
175  protected:
176   int input1_;
177   int input2_;
178   int output_;
179 
180  private:
181   // Performs initialization logic shared across all constructors.
Init(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type,bool allow_fp32_relax_to_fp16=false)182   void Init(const TensorData& input1, const TensorData& input2,
183             const TensorData& output, ActivationFunctionType activation_type,
184             bool allow_fp32_relax_to_fp16 = false) {
185     input1_ = AddInput(input1);
186     input2_ = AddInput(input2);
187     output_ = AddOutput(output);
188     SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
189                  CreateAddOptions(builder_, activation_type).Union());
190     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)},
191                               allow_fp32_relax_to_fp16);
192   }
193 };
194 
195 // Do a test with the NN API using no activation.
TEST(NNAPIDelegate,AddWithNoActivation)196 TEST(NNAPIDelegate, AddWithNoActivation) {
197   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
198                     {TensorType_FLOAT32, {1, 2, 2, 1}},
199                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
200   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
201   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
202   m.Invoke();
203   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
204 }
205 
206 // Do a test with scalar input using no activation.
TEST(NNAPIDelegate,AddScalarWithNoActivation)207 TEST(NNAPIDelegate, AddScalarWithNoActivation) {
208   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
209                     {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
210                     ActivationFunctionType_NONE);
211   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.7});
212   m.PopulateTensor<float>(m.input2(), {0.1});
213   m.Invoke();
214   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.3, 0.8, 0.8}));
215 }
216 
217 // Do a test with the NN API using no activation.
218 // The test allows computing FP32 with FP16 precision. In this particular case,
219 // calculating in FP32 or FP16 should produce the same results.
TEST(NNAPIDelegate,AddWithNoActivationRelaxed)220 TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
221   FloatAddOpModel m(
222       {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
223       {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
224   m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
225   m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
226   m.Invoke();
227   EXPECT_THAT(m.GetOutput(),
228               NnapiArrayFloatNear({-1.0, 1.0, 4.0, 6.0}, /*relaxed=*/true));
229 }
230 
231 // Do a test with the NN api with relu.
TEST(NNAPIDelegate,AddWithRelu)232 TEST(NNAPIDelegate, AddWithRelu) {
233   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
234                     {TensorType_FLOAT32, {1, 2, 2, 1}},
235                     {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
236   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
237   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
238   m.Invoke();
239   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({0.0, 0.4, 1.0, 1.3}));
240 }
241 
242 // Verify that resize attempts succeed.
TEST(NNAPIDelegate,ResizeInputTensorsWorks)243 TEST(NNAPIDelegate, ResizeInputTensorsWorks) {
244   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
245                     {TensorType_FLOAT32, {1, 2, 2, 1}},
246                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
247 
248   EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 2, 1}), kTfLiteOk);
249   EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 3, 2, 1}), kTfLiteOk);
250   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
251   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
252   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
253   m.Invoke();
254   EXPECT_THAT(m.GetOutput(),
255               NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));
256 
257   EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 2, 2, 1}), kTfLiteOk);
258   EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 2, 2, 1}), kTfLiteOk);
259   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
260   m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
261   m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
262   m.Invoke();
263   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1.0, 1.3, 1.1, 1.5}));
264 }
265 
// Verify that resizing along a dynamic dimension (shape_signature has -1)
// works when the delegate allows dynamic dimensions.
TEST(NNAPIDelegate, ResizeDynamicBatchInputTensorsWorks) {
  StatefulNnApiDelegate::Options options;
  options.allow_dynamic_dimensions = true;

  FloatAddOpModel m(options,
                    {TensorType_FLOAT32, /*shape=*/{1, 3, 2, 1}, /*min=*/0.0f,
                     /*max=*/0.0f, /*scale=*/0.0f,
                     /*zero_point=*/0, /*per_channel_quantization=*/false,
                     /*per_channel_quantization_scales=*/{},
                     /*per_channel_quantization_offsets=*/{},
                     /*channel_index=*/0, /*traversal_order=*/{},
                     /*format=*/{},
                     /*block_size=*/{}, /*block_map=*/{},
                     /*shape_signature=*/{1, -1, 2, 1}},
                    {TensorType_FLOAT32, /*shape=*/{1, 3, 2, 1}, /*min=*/0.0f,
                     /*max=*/0.0f, /*scale=*/0.0f,
                     /*zero_point=*/0, /*per_channel_quantization=*/false,
                     /*per_channel_quantization_scales=*/{},
                     /*per_channel_quantization_offsets=*/{},
                     /*channel_index=*/0, /*traversal_order=*/{},
                     /*format=*/{},
                     /*block_size=*/{}, /*block_map=*/{},
                     /*shape_signature=*/{1, -1, 2, 1}},
                    {TensorType_FLOAT32, /*shape=*/{}, /*min=*/0.0f,
                     /*max=*/0.0f, /*scale=*/0.0f,
                     /*zero_point=*/0, /*per_channel_quantization=*/false,
                     /*per_channel_quantization_scales=*/{},
                     /*per_channel_quantization_offsets=*/{},
                     /*channel_index=*/0, /*traversal_order=*/{},
                     /*format=*/{},
                     /*block_size=*/{}, /*block_map=*/{},
                     /*shape_signature=*/{1, -1, 2, 1}},
                    ActivationFunctionType_NONE);
  EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 2, 1}), kTfLiteOk);
  EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 3, 2, 1}), kTfLiteOk);
  EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
  m.Invoke();
  // Use the tolerance-based matcher (not exact ElementsAreArray equality),
  // consistent with the other float tests in this file: the output is float
  // while the literals are doubles, so exact comparison is wrong.
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));

  EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 2, 2, 1}), kTfLiteOk);
  EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 2, 2, 1}), kTfLiteOk);
  EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
  m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
  m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1.0, 1.3, 1.1, 1.5}));
}
315 
316 // Sanity check for the state-ful NNAPI delegate.
TEST(NNAPIDelegate,StatefulDelegate)317 TEST(NNAPIDelegate, StatefulDelegate) {
318   StatefulNnApiDelegate::Options options;
319   options.execution_preference =
320       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
321 
322   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
323                     {TensorType_FLOAT32, {1, 2, 2, 1}},
324                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
325   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
326   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
327   m.Invoke();
328   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
329 }
330 
331 // Sanity check for the state-ful NNAPI delegate with accelerator_name
332 // specified.
TEST(NNAPIDelegate,StatefulDelegateWithAcceleratorName)333 TEST(NNAPIDelegate, StatefulDelegateWithAcceleratorName) {
334   StatefulNnApiDelegate::Options options;
335   options.execution_preference =
336       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
337   options.accelerator_name = "nnapi-reference";
338 
339   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
340                     {TensorType_FLOAT32, {1, 2, 2, 1}},
341                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
342   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
343   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
344   m.Invoke();
345   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
346 }
347 
348 // Sanity check for the state-ful NNAPI delegate with invalid accelerator_name
349 // specified.
TEST(NNAPIDelegate,StatefulDelegateWithInvalidAcceleratorName)350 TEST(NNAPIDelegate, StatefulDelegateWithInvalidAcceleratorName) {
351   if (!NnApiImplementation()->ANeuralNetworksDevice_getName) {
352     GTEST_SKIP();
353   }
354   testing::internal::CaptureStderr();
355   StatefulNnApiDelegate::Options options;
356   options.execution_preference =
357       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
358   options.accelerator_name = "foo";
359 
360   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
361                     {TensorType_FLOAT32, {1, 2, 2, 1}},
362                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
363   //EXPECT_THAT(testing::internal::GetCapturedStderr(),
364   //            testing::HasSubstr(
365   //                "Could not find the specified NNAPI accelerator: foo"));
366 
367   // Execution should fall back to the default CPU path.
368   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
369   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
370   m.Invoke();
371   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
372 }
373 
374 // Sanity check for the state-ful NNAPI delegate with compilation caching
375 // enabled.
TEST(NNAPIDelegate,StatefulDelegateWithCompilationCaching)376 TEST(NNAPIDelegate, StatefulDelegateWithCompilationCaching) {
377   StatefulNnApiDelegate::Options options;
378   options.execution_preference =
379       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
380   options.cache_dir = "/data/local/tmp";
381   options.model_token = "NNAPIDelegate.StatefulDelegateWithCompilationCaching";
382 
383   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
384                     {TensorType_FLOAT32, {1, 2, 2, 1}},
385                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
386   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
387   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
388   m.Invoke();
389   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
390 }
391 
392 // Sanity check for the state-ful NNAPI delegate with QoS hints.
TEST(NNAPIDelegate,StatefulDelegateWithQoS)393 TEST(NNAPIDelegate, StatefulDelegateWithQoS) {
394   StatefulNnApiDelegate::Options options;
395   options.accelerator_name = "nnapi-reference";
396   options.execution_priority = ANEURALNETWORKS_PRIORITY_HIGH;
397   options.max_compilation_timeout_duration_ns = UINT64_MAX;
398   options.max_execution_timeout_duration_ns = UINT64_MAX;
399   options.max_execution_loop_timeout_duration_ns = UINT64_MAX;
400 
401   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
402                     {TensorType_FLOAT32, {1, 2, 2, 1}},
403                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
404   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
405   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
406   m.Invoke();
407   EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
408 }
409 
410 // Sanity check for the state-ful NNAPI delegate using TfLiteBufferHandle.
TEST(NNAPIDelegate,StatefulDelegateWithBufferHandles)411 TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
412   // Skip the test if Android specific functions could not be found.
413   if (!NnApiImplementation()->ASharedMemory_create ||
414       !NnApiImplementation()->ANeuralNetworksMemory_createFromFd) {
415     GTEST_SKIP();
416   }
417 
418   StatefulNnApiDelegate::Options options;
419   // Allow NNAPI CPU fallback path.
420   options.disallow_nnapi_cpu = false;
421   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
422                     {TensorType_FLOAT32, {1, 2, 2, 1}},
423                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
424   auto* delegate = m.GetDelegate();
425   // Create ASharedMemory and copy data into it.
426   constexpr auto kInput1ByteSize = 4 * sizeof(float);
427   ANeuralNetworksMemory* input1_memory = nullptr;
428   int fd =
429       NnApiImplementation()->ASharedMemory_create("input1", kInput1ByteSize);
430   EXPECT_GE(fd, 0);
431   void* input1_memory_data =
432       mmap(nullptr, kInput1ByteSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
433   EXPECT_TRUE(input1_memory_data != nullptr);
434   float input1_data[] = {-2.0, 0.2, 0.7, 0.8};
435   memcpy(input1_memory_data, input1_data, kInput1ByteSize);
436   int result = NnApiImplementation()->ANeuralNetworksMemory_createFromFd(
437       kInput1ByteSize, PROT_READ, fd, 0, &input1_memory);
438   EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
439   ASSERT_NE(input1_memory, nullptr);
440 
441   struct DummyMemoryContext {
442     ANeuralNetworksMemory* memory_handle;
443     void* memory_data;
444     size_t byte_size;
445   };
446   DummyMemoryContext memory_context = {input1_memory, input1_memory_data,
447                                        kInput1ByteSize};
448   static StatefulNnApiDelegate::CopyToHostTensorFnPtr memory_callback =
449       [](TfLiteTensor* tensor, ANeuralNetworksMemory* memory,
450          size_t memory_offset, size_t byte_size,
451          void* callback_context) -> TfLiteStatus {
452     auto memory_context =
453         reinterpret_cast<DummyMemoryContext*>(callback_context);
454     if (memory != memory_context->memory_handle ||
455         memory_offset + byte_size > memory_context->byte_size) {
456       return kTfLiteError;
457     }
458     memcpy(
459         tensor->data.raw,
460         reinterpret_cast<uint8_t*>(memory_context->memory_data) + memory_offset,
461         byte_size);
462     return kTfLiteOk;
463   };
464   auto input1_handle = delegate->RegisterNnapiMemory(
465       input1_memory, memory_callback, &memory_context);
466   m.SetBufferHandle(m.input1(), input1_handle);
467   m.MarkInputTensorDataStale(m.input1());
468   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
469   m.Invoke();
470   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
471   // Run the inference multiple times and each time register a buffer.
472   for (int i = 0; i < 10; i++) {
473     // Change the value a little bit.
474     input1_data[0] = -2.0 + i;
475     memcpy(input1_memory_data, input1_data, kInput1ByteSize);
476     auto input1_handle = delegate->RegisterNnapiMemory(
477         input1_memory, memory_callback, &memory_context);
478     m.SetBufferHandle(m.input1(), input1_handle);
479     m.MarkInputTensorDataStale(m.input1());
480     m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
481     m.Invoke();
482     EXPECT_THAT(m.GetOutput(),
483                 NnapiArrayFloatNear({-1.9f + i, 0.4f, 1.0f, 1.3f}));
484   }
485   m.SetBufferHandle(m.input1(), kTfLiteNullBufferHandle);
486 }
487 
488 class FloatMulOpModel : public SingleOpModelWithNNAPI {
489  public:
FloatMulOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)490   FloatMulOpModel(const TensorData& input1, const TensorData& input2,
491                   const TensorData& output,
492                   ActivationFunctionType activation_type) {
493     input1_ = AddInput(input1);
494     input2_ = AddInput(input2);
495     output_ = AddOutput(output);
496     SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
497                  CreateMulOptions(builder_, activation_type).Union());
498     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
499   }
500 
input1()501   int input1() { return input1_; }
input2()502   int input2() { return input2_; }
503 
GetOutput()504   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
505 
506  protected:
507   int input1_;
508   int input2_;
509   int output_;
510 };
511 
TEST(NNAPIDelegate, MulWithNoActivation) {
  FloatMulOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({-0.2, 0.04, 0.21, 0.4}));
}
521 
522 class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
523  public:
FloatPoolingOpModel(BuiltinOperator type,const TensorData & input,int filter_width,int filter_height,const TensorData & output)524   FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
525                       int filter_width, int filter_height,
526                       const TensorData& output) {
527     input_ = AddInput(input);
528     output_ = AddOutput(output);
529 
530     SetBuiltinOp(
531         type, BuiltinOptions_Pool2DOptions,
532         CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
533                             filter_height, ActivationFunctionType_NONE)
534             .Union());
535 
536     BuildInterpreterWithNNAPI({GetShape(input_)});
537   }
538 
SetInput(std::initializer_list<float> data)539   void SetInput(std::initializer_list<float> data) {
540     PopulateTensor(input_, data);
541   }
542 
GetOutput()543   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
544 
545  protected:
546   int input_;
547   int output_;
548 };
549 
TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_AVERAGE_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({2.75, 5.75}));
}
562 
TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_MAX_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({6, 10}));
}
575 
TEST(NNAPIDelegate, L2PoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_L2_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({3.5, 6.5}));
}
588 
589 class ConvolutionOpModel : public SingleOpModelWithNNAPI {
590  public:
ConvolutionOpModel(const TensorData & input,const TensorData & filter,const TensorData & output,int stride_width=2,int stride_height=2,enum Padding padding=Padding_VALID,enum ActivationFunctionType activation=ActivationFunctionType_NONE,int dilation_width_factor=1,int dilation_height_factor=1)591   ConvolutionOpModel(
592       const TensorData& input, const TensorData& filter,
593       const TensorData& output, int stride_width = 2, int stride_height = 2,
594       enum Padding padding = Padding_VALID,
595       enum ActivationFunctionType activation = ActivationFunctionType_NONE,
596       int dilation_width_factor = 1, int dilation_height_factor = 1)
597       : input_type_(input.type), filter_type_(filter.type) {
598     input_ = AddInput(input);
599     filter_ = AddInput(filter);
600 
601     int bias_size = GetShape(filter_)[0];
602     if (input.type == TensorType_FLOAT32) {
603       bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
604     } else {
605       // This is a quantized version. The scale of 'bias' depends on the scales
606       // of input and filter. Supposedly this is correctly set during quantized
607       // training.
608       auto bias_scale = GetScale(input_) * GetScale(filter_);
609       TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
610       bias_ = AddInput(bias);
611     }
612 
613     output_ = AddOutput(output);
614 
615     SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
616                  CreateConv2DOptions(
617                      builder_, padding, stride_width, stride_height, activation,
618                      dilation_width_factor, dilation_height_factor)
619                      .Union());
620 
621     BuildInterpreterWithNNAPI(
622         {GetShape(input_), GetShape(filter_), GetShape(bias_)});
623   }
624 
SetInput(std::initializer_list<float> data)625   void SetInput(std::initializer_list<float> data) {
626     SetData(input_, input_type_, data);
627   }
628 
SetFilter(std::initializer_list<float> data)629   void SetFilter(std::initializer_list<float> data) {
630     SetData(filter_, filter_type_, data);
631   }
632 
SetBias(std::initializer_list<float> data)633   void SetBias(std::initializer_list<float> data) {
634     const auto bias_type =
635         (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
636     SetData(bias_, bias_type, data);
637   }
638 
GetOutput()639   std::vector<float> GetOutput() {
640     if (input_type_ == TensorType_FLOAT32) {
641       return ExtractVector<float>(output_);
642     } else {
643       return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
644                                  GetScale(output_), GetZeroPoint(output_));
645     }
646   }
647 
GetQuantizedOutput()648   std::vector<uint8_t> GetQuantizedOutput() {
649     if (input_type_ == TensorType_FLOAT32) {
650       return {};  // Not supported.
651     } else {
652       return ExtractVector<uint8_t>(output_);
653     }
654   }
655 
656  protected:
657   int input_;
658   int filter_;
659   int bias_;
660   int output_;
661 
662   const TensorType input_type_;
663   const TensorType filter_type_;
664 };
665 
666 // In this tests we set the input and output scales so that the results
667 // match exactly the 'non-quantized' version.
// The input and output scales here are chosen so that the quantized results
// match the 'non-quantized' reference exactly.
TEST(ConvolutionOpTest, SimpleTestQuantized) {
  ConvolutionOpModel model({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
                           {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
                           {TensorType_UINT8, {}, -127, 128});
  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         18, 2, 5,  // first batch, left
                                         18, 2, 5,  // first batch, right
                                         17, 4, 3,  // second batch, left
                                         37, 4, 3,  // second batch, right
                                     },
                                     1e-5)));
  // Also check the raw quantized values directly.
  EXPECT_THAT(model.GetQuantizedOutput(), ElementsAreArray({
                                              145, 129, 132,  //
                                              145, 129, 132,  //
                                              144, 131, 130,  //
                                              164, 131, 130,  //
                                          }));
}
705 
// Hybrid case: float activations combined with quantized (uint8) weights.
TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
                           {TensorType_FLOAT32, {}});
  model.SetInput({
      // First batch
      1, 1, 1, 2,  // row = 1
      2, 2, 2, 1,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,  // first 2x2 filter
      0, 1, 0, 1,  // second 2x2 filter
      0, 0, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         18, 5, 7,    // first batch, left
                                         16, 5, 6,    // first batch, right
                                         17, 6, 6,    // second batch, left
                                         37, 10, 10,  // second batch, right
                                     },
                                     0.2)));
}
736 
// All-float convolution with no fused activation function.
TEST(ConvolutionOpTest, NoActivation) {
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_FLOAT32, {3, 2, 2, 1}},
                           {TensorType_FLOAT32, {}});

  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     18, 2, 5,  // first batch, left
                                     18, 2, 5,  // first batch, right
                                     17, 4, 3,  // second batch, left
                                     37, 4, 3,  // second batch, right
                                 }));
}
766 
// Runs the same data through a quantized model and a float reference model
// and checks that the quantized output stays within 1 of the reference.
TEST(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) {
  // output_multiplier = 1.0118
  ConvolutionOpModel quantized({TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128},
                               {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128},
                               {TensorType_UINT8, {}, -127, 128});
  ConvolutionOpModel reference({TensorType_FLOAT32, {2, 2, 4, 1}},
                               {TensorType_FLOAT32, {3, 2, 2, 1}},
                               {TensorType_FLOAT32, {}});
  std::initializer_list<float> input_data = {
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  };
  std::initializer_list<float> filter_data = {
      1,  2,  3,  4,  // first 2x2 filter
      -1, 1,  -1, 1,  // second 2x2 filter
      -1, -1, 1,  1,  // third 2x2 filter
  };
  std::initializer_list<float> bias_data = {1, 2, 3};

  quantized.SetInput(input_data);
  quantized.SetFilter(filter_data);
  quantized.SetBias(bias_data);
  quantized.Invoke();

  reference.SetInput(input_data);
  reference.SetFilter(filter_data);
  reference.SetBias(bias_data);
  reference.Invoke();

  EXPECT_THAT(quantized.GetOutput(),
              ElementsAreArray(ArrayFloatNear(reference.GetOutput(), 1)));
}
803 
// Float convolution with a dilation factor of 3 in both dimensions.
TEST(ConvolutionOpTest, SimpleTestFloatWithDilation) {
  const int depth = 1;
  const int image_width = 9;
  const int image_height = 9;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const int dilation_width_factor = 3;
  const int dilation_height_factor = 3;
  const Padding padding = Padding_VALID;
  ConvolutionOpModel model(
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
      ActivationFunctionType_NONE, dilation_width_factor,
      dilation_height_factor);

  // The 9x9 image is all zeros except for a 3x3 block of ones at the center:
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
  // clang-format off
  model.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0});
  // clang-format on
  // The 3x3 filter is:
  // | 1 | 2 | 3 |
  // | 4 | 5 | 6 |
  // | 7 | 8 | 9 |
  model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Zero bias for this test.
  model.SetBias({0});
  model.Invoke();

  // With dilation rate 3 the output shrinks to 3x3 of all 5s:
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  // | 5 | 5 | 5 |
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
861 
// ConvolutionOpModel variant whose setters always quantize float data to
// uint8 (int32 for bias) and which exposes both raw and dequantized outputs.
class QuantizedConvolutionOpModel : public ConvolutionOpModel {
 public:
  using ConvolutionOpModel::ConvolutionOpModel;

  // Quantizes `data` to uint8 and writes it into the input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(input_, data);
  }

  // Quantizes `data` to uint8 and writes it into the filter tensor.
  void SetFilter(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(filter_, data);
  }

  // Quantizes `data` to int32 and writes it into the bias tensor.
  void SetBias(std::initializer_list<float> data) {
    QuantizeAndPopulate<int32_t>(bias_, data);
  }

  // Returns the raw uint8 output values.
  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
  // Returns the output dequantized with the output tensor's scale/zero point.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }
};
884 
TEST(ConvolutionOpTest,SimpleTestQuantizedWithDilation)885 TEST(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
886   const int depth = 1;
887   const int image_width = 9;
888   const int image_height = 9;
889   const int image_batch_count = 1;
890   const int filter_size = 3;
891   const int filter_count = 1;
892   const int stride_width = 1;
893   const int stride_height = 1;
894   const int dilation_width_factor = 3;
895   const int dilation_height_factor = 3;
896   const Padding padding = Padding_VALID;
897   ConvolutionOpModel m({TensorType_UINT8,
898                         {image_batch_count, image_height, image_width, depth},
899                         0,
900                         127.5},
901                        {TensorType_UINT8,
902                         {depth, filter_size, filter_size, filter_count},
903                         0,
904                         127.5},
905                        {TensorType_UINT8, {}, 0, 255}, stride_width,
906                        stride_height, padding, ActivationFunctionType_NONE,
907                        dilation_width_factor, dilation_height_factor);
908 
909   // The image matrix is:
910   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
911   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
912   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
913   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
914   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
915   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
916   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
917   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
918   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
919   // clang-format off
920   m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
921               0, 0, 0, 0, 0, 0, 0, 0, 0,
922               0, 0, 0, 0, 0, 0, 0, 0, 0,
923               0, 0, 0, 1, 1, 1, 0, 0, 0,
924               0, 0, 0, 1, 1, 1, 0, 0, 0,
925               0, 0, 0, 1, 1, 1, 0, 0, 0,
926               0, 0, 0, 0, 0, 0, 0, 0, 0,
927               0, 0, 0, 0, 0, 0, 0, 0, 0,
928               0, 0, 0, 0, 0, 0, 0, 0, 0});
929   // clang-format on
930   // The filter matrix is:
931   // | 1 | 2 | 3 |
932   // | 4 | 5 | 6 |
933   // | 7 | 8 | 9 |
934   m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
935   // Zero bias for this test.
936   m.SetBias({0});
937   m.Invoke();
938 
939   // Since the dilation rate is 3 this will reduce the size of the output from
940   // 10x10 to 3x3 of all 5s. Specifically:
941   // | 5 | 5 | 5 |
942   // | 5 | 5 | 5 |
943   // | 5 | 5 | 5 |
944   EXPECT_THAT(m.GetQuantizedOutput(),
945               ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
946 }
947 
// Conv2D model with a per-channel-quantized constant filter and a constant
// per-channel bias, for exercising NNAPI per-channel quantization support.
class PerChannelQuantizedConvolutionWithConstantFilterOpModel
    : public SingleOpModelWithNNAPI {
 public:
  // `filter` must have per_channel_quantization set. The bias scales are
  // derived per channel as input.scale * filter_channel_scale, with zero
  // offsets, per the quantization contract between input, filter and bias.
  PerChannelQuantizedConvolutionWithConstantFilterOpModel(
      const TensorData& input, const TensorData& filter,
      std::initializer_list<int8_t> filter_data,
      std::initializer_list<int32_t> bias_data, const TensorData& output,
      int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    CHECK(filter.per_channel_quantization);
    input_ = AddInput(input);
    filter_ = AddConstInput(filter, filter_data);

    // Bias has one element per output channel (filter dimension 0).
    const int bias_size = GetShape(filter_)[0];
    const int num_channels = filter.per_channel_quantization_scales.size();
    const std::vector<int64_t> bias_offsets(num_channels, 0);
    std::vector<float> bias_scales(num_channels);
    for (int i = 0; i < num_channels; i++) {
      bias_scales[i] = input.scale * filter.per_channel_quantization_scales[i];
    }
    const TensorData bias{TensorType_INT32,
                          {bias_size},
                          /*min=*/0,
                          /*max=*/0,
                          /*scale=*/0,
                          /*zero_point=*/0,
                          /*per_channel_quantization=*/true,
                          /*per_channel_quantization_scales=*/bias_scales,
                          /*per_channel_quantization_offsets=*/bias_offsets,
                          /*channel_index=*/0};
    bias_ = AddConstInput(bias, bias_data);

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  // Quantizes `data` to int8 and writes it into the input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<int8_t>(input_, data);
  }

  // Returns the raw int8 output values.
  std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};
1010 
// Per-channel quantized conv with two output channels at different scales.
TEST(ConvolutionOpTest, SimplePerChannelTest) {
  PerChannelQuantizedConvolutionWithConstantFilterOpModel model(
      {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
      {TensorType_INT8,
       // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
       {2, 2, 2, 2},
       /*min=*/0,
       /*max=*/0,
       /*scale=*/0,
       /*zero_point=*/0,
       /*per_channel_quantization=*/true,
       /*per_channel_quantization_scales=*/{1, 2},
       /*per_channel_quantization_offsets=*/{0, 0},
       /*channel_index=*/0},
      /*filter_data=*/
      {
          // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
          1, 2,  // out channel = 0, y = 0, x = 0
          3, 4,  // out channel = 0, y = 0, x = 1
          3, 4,  // out channel = 0, y = 1, x = 0
          5, 6,  // out channel = 0, y = 1, x = 1
          4, 4,  // out channel = 1, y = 0, x = 0
          3, 3,  // out channel = 1, y = 0, x = 1
          2, 2,  // out channel = 1, y = 1, x = 0
          1, 1,  // out channel = 1, y = 1, x = 1
      },
      /*bias_data=*/{6, -2}, {TensorType_INT8, {}, -63.5, 64, 0.5, -1},
      /*stride_width=*/1, /*stride_height=*/1);
  model.SetInput({
      // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
      3, 2,    // batch = 0, y = 0, x = 0
      1, -1,   // batch = 0, y = 0, x = 1
      -2, -3,  // batch = 0, y = 0, x = 2
      4, 3,    // batch = 0, y = 1, x = 0
      2, -2,   // batch = 0, y = 1, x = 1
      -3, -4,  // batch = 0, y = 1, x = 2
  });

  // Invoke and verify output.
  // output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              testing::Pointwise(QuantizedNear(), {61, 127, -115, -93}));
}
1055 
// Wraps a single DEPTHWISE_CONV_2D op (VALID padding, stride 1) for NNAPI
// delegate testing, supporting both float and uint8-quantized models.
class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
                              const TensorData& output)
      : input_type_(input.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // Depthwise filters have one bias element per output channel (dim 3).
    int bias_size = GetShape(filter_)[3];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    // The depth multiplier is inferred from the input/filter channel counts.
    int input_depth = GetShape(input_)[3];
    int output_depth = GetShape(filter_)[3];
    int depth_mul = output_depth / input_depth;

    SetBuiltinOp(
        BuiltinOperator_DEPTHWISE_CONV_2D,
        BuiltinOptions_DepthwiseConv2DOptions,
        CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
                                     ActivationFunctionType_NONE)
            .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  // Writes `data` into the input tensor using the model's input type.
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Writes `data` into the filter tensor. NOTE(review): this reuses
  // input_type_ for the filter, so it assumes filter type == input type.
  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, input_type_, data);
  }

  // Writes `data` into the bias tensor: float for float models, INT32
  // otherwise.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats; quantized outputs are dequantized.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
};
1124 
// Float depthwise conv (depth multiplier 2), no fused activation.
TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
  DepthwiseConvolutionOpModel model({TensorType_FLOAT32, {1, 3, 2, 2}},
                                    {TensorType_FLOAT32, {1, 2, 2, 4}},
                                    {TensorType_FLOAT32, {}});

  model.SetInput({
      1, 2, 7, 8,    // column 1
      3, 4, 9, 10,   // column 2
      5, 6, 11, 12,  // column 3
  });
  model.SetFilter({
      1, 2, 3, 4,        //
      -9, 10, -11, 12,   //
      5, 6, 7, 8,        //
      13, -14, 15, -16,  //
  });
  model.SetBias({1, 2, 3, 4});

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     71, -34, 99, -20,  //
                                     91, -26, 127, -4,  //
                                 }));
}
1150 
// Compares a quantized depthwise conv against a float reference model on the
// same data; results must agree within 1.
TEST(QuantizedDepthwiseConv2DTest, FilterMultiplierGreaterThan1) {
  DepthwiseConvolutionOpModel quantized(
      {TensorType_UINT8, {1, 3, 2, 2}, -128.5, 128},
      {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
      {TensorType_UINT8, {}, -127, 128});
  DepthwiseConvolutionOpModel reference({TensorType_FLOAT32, {1, 3, 2, 2}},
                                        {TensorType_FLOAT32, {1, 2, 2, 4}},
                                        {TensorType_FLOAT32, {}});

  std::initializer_list<float> input_data = {
      1, 2, 7,  8,   // column 1
      3, 4, 9,  10,  // column 2
      5, 6, 11, 12,  // column 3
  };
  std::initializer_list<float> filter_data = {
      1,  2,   3,   4,    //
      -9, 10,  -11, 12,   //
      5,  6,   7,   8,    //
      13, -14, 15,  -16,  //
  };
  std::initializer_list<float> bias_data = {1, 2, 3, 4};

  quantized.SetInput(input_data);
  quantized.SetFilter(filter_data);
  quantized.SetBias(bias_data);
  quantized.Invoke();

  reference.SetInput(input_data);
  reference.SetFilter(filter_data);
  reference.SetBias(bias_data);
  reference.Invoke();

  EXPECT_THAT(quantized.GetOutput(),
              ElementsAreArray(ArrayFloatNear(reference.GetOutput(), 1)));
}
1186 
// Wraps a single FULLY_CONNECTED op for NNAPI delegate testing, supporting
// float and uint8-quantized inputs/weights.
class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
 public:
  FullyConnectedOpModel(
      const TensorData& input, const TensorData& weights,
      const TensorData& output,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE)
      : input_type_(input.type), weights_type_(weights.type) {
    input_ = AddInput(input);
    weights_ = AddInput(weights);

    // Bias has one element per output unit (weights dimension 0).
    const int units = weights.shape[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {units}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(weights_);
      TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
                 BuiltinOptions_FullyConnectedOptions,
                 CreateFullyConnectedOptions(builder_, activation).Union());
    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(weights_), GetShape(bias_)});
  }

  // Writes `data` into the input tensor using the model's input type.
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Writes `data` into the weights tensor using the model's weights type.
  void SetWeights(std::initializer_list<float> data) {
    SetData(weights_, weights_type_, data);
  }

  // Writes `data` into the bias tensor: float for float models, INT32
  // otherwise.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats; quantized outputs are dequantized.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int weights_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType weights_type_;
};
1250 
TEST(FullyConnectedOpTest, SimpleTest) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_FLOAT32, {3, 10}},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
1271 
// Hybrid case: float input with quantized (uint8) weights.
TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
  FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
                          /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
                          /*output=*/{TensorType_FLOAT32});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
}
1293 
// Quantized fully-connected where the output multiplier exceeds 1.
TEST(FullyConnectedOpTest, QuantizedOutputMultiplierGreaterThan1) {
  // real_multiplier = 2.
  FullyConnectedOpModel model(
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*weights=*/{TensorType_UINT8, {3, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                     24, 25, 26,  // first batch
                                     58, 59, 60,  // second batch
                                 })));
}
1320 
// Wraps a single SOFTMAX op (float) for NNAPI delegate testing.
class SoftmaxOpModel : public SingleOpModelWithNNAPI {
 public:
  // Builds a softmax with the given `beta`; output shares the input's
  // TensorData description.
  SoftmaxOpModel(const TensorData& input, float beta) {
    input_ = AddInput(input);
    output_ = AddOutput(input);
    SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
                 CreateSoftmaxOptions(builder_, beta).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Populates the whole input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Populates a slice of the input tensor starting at `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Returns the output tensor contents.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};
1345 
// 2D softmax with beta = 1.
TEST(SoftmaxOpTest, SimpleTest) {
  SoftmaxOpModel model({TensorType_FLOAT32, {2, 5}}, /*beta=*/1.0);
  model.SetInput({
      1.0, 2.0, 3.0, 4.0, 5.0,       // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
  });

  model.Invoke();

  EXPECT_THAT(
      model.GetOutput(),
      NnapiArrayFloatNear({0.011656231, 0.031684921, 0.086128544, 0.234121657,
                           0.636408647, 0.636408647, 0.234121657, 0.086128544,
                           0.031684921, 0.011656231}));
}
1361 
// Softmax with a non-default beta of 2.
TEST(SoftmaxOpTest, Beta2) {
  SoftmaxOpModel model({TensorType_FLOAT32, {1, 5}}, /*beta=*/2.0);
  model.SetInput({
      1.0, 2.0, 3.0, 4.0, 5.0,  // b = 0
  });

  model.Invoke();

  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0.000290076, 0.002143387, 0.015837606,
                                   0.117024957, 0.864703974}));
}
1374 
1375 TEST(SoftmaxOpTest, 3dInput) {
1376   SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 5}}, /*beta=*/1.0);
1377   m.SetInput({
1378       1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
1379       -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
1380       5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
1381       -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
1382   });
1383 
1384   m.Invoke();
1385 
1386   EXPECT_THAT(
1387       m.GetOutput(),
1388       NnapiArrayFloatNear(
1389           {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
1390            0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
1391            0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
1392            0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}));
1393 }
1394 
1395 TEST(SoftmaxOpTest, 4dInput) {
1396   SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 1, 5}}, /*beta=*/1.0);
1397   m.SetInput({
1398       1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
1399       -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
1400       5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
1401       -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
1402   });
1403 
1404   m.Invoke();
1405 
1406   EXPECT_THAT(
1407       m.GetOutput(),
1408       NnapiArrayFloatNear(
1409           {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
1410            0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
1411            0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
1412            0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}));
1413 }
1414 
// Wraps a single RESHAPE op for NNAPI delegate testing. The new shape is
// provided both as a constant input tensor and in the builtin options.
class ReshapeOpModel : public SingleOpModelWithNNAPI {
 public:
  ReshapeOpModel(std::initializer_list<int> input_shape,
                 std::initializer_list<int> new_shape) {
    input_ = AddInput(TensorType_FLOAT32);
    new_shape_ = AddConstInput<int>(TensorType_INT32, new_shape,
                                    {static_cast<int>(new_shape.size())});
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
        CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
            .Union());
    BuildInterpreterWithNNAPI(
        {input_shape, {static_cast<int>(new_shape.size())}});
  }

  // Populates the (float) input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }
  // Returns the output tensor contents.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  // Returns the output tensor shape.
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int new_shape_;
  int output_;
};
1442 
// Reshapes a [1, 2, 4, 1] tensor to [2, 2, 2]; data order is unchanged.
TEST(NNAPIDelegate, ReshapeSimpleTest) {
  ReshapeOpModel model({1, 2, 4, 1}, {2, 2, 2});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 2, 3, 4, 5, 6, 7, 8}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
1450 
1451 class SqueezeOpModel : public SingleOpModelWithNNAPI {
1452  public:
SqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)1453   SqueezeOpModel(const TensorData& input, const TensorData& output,
1454                  std::initializer_list<int> axis) {
1455     input_ = AddInput(input);
1456     output_ = AddOutput(output);
1457     SetBuiltinOp(
1458         BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
1459         CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
1460             .Union());
1461     BuildInterpreterWithNNAPI({GetShape(input_)});
1462   }
1463 
SetInput(std::initializer_list<float> data)1464   void SetInput(std::initializer_list<float> data) {
1465     PopulateTensor<float>(input_, data);
1466   }
GetOutput()1467   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1468   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1469 
1470  private:
1471   int input_;
1472   int new_shape_;
1473   int output_;
1474 };
1475 
1476 // TODO(b/215935381): Enable after resolving issues with flakiness.
TEST(NNAPIDelegate,DISABLED_SqueezeSimpleTest)1477 TEST(NNAPIDelegate, DISABLED_SqueezeSimpleTest) {
1478   std::initializer_list<float> data = {
1479       1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
1480       13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
1481   SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
1482                    {});
1483   m.SetInput(data);
1484   m.Invoke();
1485   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
1486   EXPECT_THAT(
1487       m.GetOutput(),
1488       NnapiArrayFloatNear({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
1489                            9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
1490                            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
1491 }
1492 
// Squeezing only axis 2 of a 1x24x1 tensor leaves the leading 1: -> 1x24.
TEST(NNAPIDelegate, SqueezeWithAxisTest) {
  SqueezeOpModel model({TensorType_FLOAT32, {1, 24, 1}},
                       {TensorType_FLOAT32, {24}}, /*axis=*/{2});
  model.SetInput({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                  9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                  17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 24}));
  EXPECT_THAT(
      model.GetOutput(),
      NnapiArrayFloatNear({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                           9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                           17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
1508 
1509 class L2NormOpModel : public SingleOpModelWithNNAPI {
1510  public:
L2NormOpModel(const TensorData & input,const TensorData & output,ActivationFunctionType activation_type)1511   L2NormOpModel(const TensorData& input, const TensorData& output,
1512                 ActivationFunctionType activation_type) {
1513     input_ = AddInput(input);
1514     output_ = AddOutput(output);
1515     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
1516                  CreateL2NormOptions(builder_, activation_type).Union());
1517     BuildInterpreterWithNNAPI({GetShape(input_)});
1518   }
1519 
SetInput(std::initializer_list<float> data)1520   void SetInput(std::initializer_list<float> data) {
1521     PopulateTensor<float>(input_, data);
1522   }
GetOutput()1523   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1524   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1525 
1526  private:
1527   int input_;
1528   int new_shape_;
1529   int output_;
1530 };
1531 
// L2 norm of the input vector is 2, so every element is halved.
TEST(NNAPIDelegate, L2NormSimpleTest) {
  L2NormOpModel model({TensorType_FLOAT32, {1, 1, 1, 6}},
                      {TensorType_FLOAT32, {1, 1, 1, 6}},
                      ActivationFunctionType_NONE);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
1543 
1544 class TransposeSimpleModel : public SingleOpModelWithNNAPI {
1545  public:
TransposeSimpleModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape,std::initializer_list<int> perm)1546   TransposeSimpleModel(std::initializer_list<int> input_shape,
1547                        std::initializer_list<int> perm_shape,
1548                        std::initializer_list<int> perm) {
1549     input_ = AddInput(TensorType_FLOAT32);
1550     perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
1551     output_ = AddOutput(TensorType_FLOAT32);
1552     SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
1553                  CreateTransposeOptions(builder_).Union());
1554     BuildInterpreterWithNNAPI({input_shape, perm_shape});
1555   }
1556 
SetInput(std::initializer_list<float> data)1557   void SetInput(std::initializer_list<float> data) {
1558     PopulateTensor<float>(input_, data);
1559   }
1560 
GetOutput()1561   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1562   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1563 
1564  private:
1565   int input_;
1566   int perm_;
1567   int output_;
1568 };
1569 
// Permutes a 2x3x4 tensor with perm {2, 0, 1} -> 4x2x3.
TEST(NNAPIDelegate, TransposeSimpleTest) {
  TransposeSimpleModel model({2, 3, 4}, {3}, {2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                   2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
1580 
1581 class ElementwiseOpBaseModel : public SingleOpModelWithNNAPI {
1582  public:
input() const1583   int input() const { return input_; }
output() const1584   int output() const { return output_; }
1585 
1586  protected:
1587   int input_;
1588   int output_;
1589 };
1590 
1591 class ElementwiseOpFloatModel : public ElementwiseOpBaseModel {
1592  public:
ElementwiseOpFloatModel(BuiltinOperator op,std::initializer_list<int> input_shape)1593   ElementwiseOpFloatModel(BuiltinOperator op,
1594                           std::initializer_list<int> input_shape) {
1595     input_ = AddInput(TensorType_FLOAT32);
1596     output_ = AddOutput(TensorType_FLOAT32);
1597     SetBuiltinOp(op, BuiltinOptions_NONE, 0);
1598     BuildInterpreterWithNNAPI({input_shape});
1599   }
1600 };
1601 
// ABS: |x| element-wise over a 1x2x4x1 tensor.
TEST(Elementwise, Abs) {
  ElementwiseOpFloatModel model(BuiltinOperator_ABS, {1, 2, 4, 1});
  model.PopulateTensor<float>(model.input(),
                              {0.f, -6.2f, 2.f, 4.f, 3.f, -2.f, 10.f, 1.f});
  model.Invoke();
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear({0.f, 6.2f, 2.f, 4.f, 3.f, 2.f, 10.f, 1.f}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 2, 4, 1}));
}
1615 
// EXP: e^x element-wise.
TEST(Elementwise, Exp) {
  ElementwiseOpFloatModel model(BuiltinOperator_EXP, {3, 1, 2});
  model.PopulateTensor<float>(model.input(), {1.0, 0.0, -1.0, 1.0, 1.0, -1.0});
  model.Invoke();
  EXPECT_THAT(
      model.ExtractVector<float>(model.output()),
      NnapiArrayFloatNear({2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({3, 1, 2}));
}
1625 
// LOG: natural logarithm element-wise (ln(pi) ~= 1.14473).
TEST(Elementwise, Log) {
  ElementwiseOpFloatModel model(BuiltinOperator_LOG, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {1, 3.1415926, 1, 1});
  model.Invoke();
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear({0, 1.14473, 0, 0}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}
1634 
// RSQRT: 1/sqrt(x) element-wise.
TEST(Elementwise, Rsqrt) {
  ElementwiseOpFloatModel model(BuiltinOperator_RSQRT, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {1, 2, 4, 9});
  model.Invoke();
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear({1, 0.7071, 0.5, 0.33333}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}
1643 
// SIN: sine element-wise (sin(+-pi) ~= 0, sin(1) ~= 0.84147).
TEST(Elementwise, Sin) {
  ElementwiseOpFloatModel model(BuiltinOperator_SIN, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {0, 3.1415926, -3.1415926, 1});
  model.Invoke();
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear({0, 0, 0, 0.84147}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}
1652 
// SQRT: square root element-wise.
TEST(Elementwise, Sqrt) {
  ElementwiseOpFloatModel model(BuiltinOperator_SQRT, {1, 1, 4, 1});
  model.PopulateTensor<float>(model.input(), {0, 1, 2, 4});
  model.Invoke();
  EXPECT_THAT(model.ExtractVector<float>(model.output()),
              NnapiArrayFloatNear({0, 1, 1.41421, 2}));
  EXPECT_THAT(model.GetTensorShape(model.output()),
              ElementsAreArray({1, 1, 4, 1}));
}
1661 
1662 class FloatSubOpModel : public SingleOpModelWithNNAPI {
1663  public:
FloatSubOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)1664   FloatSubOpModel(const TensorData& input1, const TensorData& input2,
1665                   const TensorData& output,
1666                   ActivationFunctionType activation_type) {
1667     input1_ = AddInput(input1);
1668     input2_ = AddInput(input2);
1669     output_ = AddOutput(output);
1670     SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions,
1671                  CreateMulOptions(builder_, activation_type).Union());
1672     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
1673   }
1674 
input1()1675   int input1() { return input1_; }
input2()1676   int input2() { return input2_; }
1677 
GetOutput()1678   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1679 
1680  protected:
1681   int input1_;
1682   int input2_;
1683   int output_;
1684 };
1685 
// SUB with no fused activation: element-wise input1 - input2.
TEST(NNAPIDelegate, SubWithNoActivation) {
  FloatSubOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({-2.1, 0.0, 0.4, 0.3}));
}
1695 
1696 class FloatDivOpModel : public SingleOpModelWithNNAPI {
1697  public:
FloatDivOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)1698   FloatDivOpModel(const TensorData& input1, const TensorData& input2,
1699                   const TensorData& output,
1700                   ActivationFunctionType activation_type) {
1701     input1_ = AddInput(input1);
1702     input2_ = AddInput(input2);
1703     output_ = AddOutput(output);
1704     SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions,
1705                  CreateMulOptions(builder_, activation_type).Union());
1706     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
1707   }
1708 
input1()1709   int input1() { return input1_; }
input2()1710   int input2() { return input2_; }
1711 
GetOutput()1712   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1713 
1714  protected:
1715   int input1_;
1716   int input2_;
1717   int output_;
1718 };
1719 
// DIV with no fused activation: element-wise input1 / input2.
TEST(NNAPIDelegate, DivWithNoActivation) {
  FloatDivOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.8, 0.8});
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.4, 0.2});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({-20, 1, 2, 4}));
}
1729 
1730 class BaseConcatenationOpModel : public SingleOpModelWithNNAPI {
1731  public:
BaseConcatenationOpModel()1732   BaseConcatenationOpModel() {}
BaseConcatenationOpModel(const TensorData & input_template,int axis,int num_inputs)1733   BaseConcatenationOpModel(const TensorData& input_template, int axis,
1734                            int num_inputs) {
1735     std::vector<std::vector<int>> all_input_shapes;
1736     for (int i = 0; i < num_inputs; ++i) {
1737       all_input_shapes.push_back(input_template.shape);
1738       AddInput(input_template);
1739     }
1740     output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
1741                          input_template.max});
1742     SetBuiltinOp(
1743         BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
1744         CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
1745             .Union());
1746     BuildInterpreterWithNNAPI(all_input_shapes);
1747   }
1748 
1749  protected:
1750   int output_;
1751 };
1752 
1753 class ConcatenationOpModel : public BaseConcatenationOpModel {
1754  public:
1755   using BaseConcatenationOpModel::BaseConcatenationOpModel;
SetInput(int index,std::initializer_list<float> data)1756   void SetInput(int index, std::initializer_list<float> data) {
1757     PopulateTensor(index, data);
1758   }
GetOutput()1759   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1760 };
1761 
// Concatenating a single input is an identity regardless of axis.
TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) {
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
                             /*num_inputs=*/1);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 3, 4, 7}));
}
1769 
// Four 2x1x2 inputs concatenated along the last axis -> 2x1x8, interleaved
// per row.
TEST(NNAPIDelegate, ConcatenationFourInputs) {
  ConcatenationOpModel model({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
                             /*num_inputs=*/4);
  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              }));
}
1784 
1785 class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
1786  public:
1787   using BaseConcatenationOpModel::BaseConcatenationOpModel;
QuantizedConcatenationOpModel(const std::vector<TensorData> & input_template,int axis,int num_inputs,const TensorData & output_template)1788   QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
1789                                 int axis, int num_inputs,
1790                                 const TensorData& output_template) {
1791     std::vector<std::vector<int>> all_input_shapes;
1792     CHECK_EQ(input_template.size(), num_inputs);
1793     for (int i = 0; i < num_inputs; ++i) {
1794       all_input_shapes.push_back(input_template[i].shape);
1795       AddInput(input_template[i]);
1796     }
1797     output_ = AddOutput({output_template.type, /*shape=*/{},
1798                          output_template.min, output_template.max});
1799     SetBuiltinOp(
1800         BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
1801         CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
1802             .Union());
1803     BuildInterpreterWithNNAPI(all_input_shapes);
1804   }
SetInput(int index,std::initializer_list<float> data)1805   void SetInput(int index, std::initializer_list<float> data) {
1806     QuantizeAndPopulate<uint8_t>(index, data);
1807   }
GetOutput()1808   std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
GetDequantizedOutput()1809   std::vector<float> GetDequantizedOutput() {
1810     return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
1811                                GetScale(output_), GetZeroPoint(output_));
1812   }
1813 };
1814 
// Four quantized inputs sharing one range: checks both the dequantized
// values and the exact quantized bytes.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) {
  QuantizedConcatenationOpModel model(
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
      /*axis=*/2,
      /*num_inputs=*/4);

  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     137, 157, 138, 158, 139, 159, 140, 160,  //
                                     167, 197, 168, 198, 169, 199, 170, 200,  //
                                 }));
}
1835 
// Same concat but each input has its own quantization range; values must be
// requantized into the shared output range.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) {
  QuantizedConcatenationOpModel model(
      {{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
       {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
       {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
      /*axis=*/2, /*num_inputs=*/4,
      {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});

  model.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  model.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  model.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  model.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     137, 157, 138, 158, 139, 159, 140, 160,  //
                                     167, 197, 168, 198, 169, 199, 170, 200,  //
                                 }));
}
1859 
1860 class DequantizeOpModel : public SingleOpModelWithNNAPI {
1861  public:
DequantizeOpModel(TensorType inputType,std::initializer_list<int> shape,float min,float max)1862   DequantizeOpModel(TensorType inputType, std::initializer_list<int> shape,
1863                     float min, float max) {
1864     input_ = AddInput({inputType, shape, min, max});
1865     output_ = AddOutput({TensorType_FLOAT32, shape});
1866     SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
1867                  CreateDequantizeOptions(builder_).Union());
1868 
1869     BuildInterpreterWithNNAPI({GetShape(input_)});
1870   }
1871 
1872   template <typename T>
SetInput(std::initializer_list<T> data)1873   void SetInput(std::initializer_list<T> data) {
1874     PopulateTensor(input_, data);
1875   }
1876 
GetOutput()1877   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1878 
1879  private:
1880   int input_;
1881   int output_;
1882 };
1883 
// UINT8 dequantize over [-63.5, 64]: scale = 0.5, zero_point = 127.
TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) {
  DequantizeOpModel model(TensorType_UINT8, {2, 5}, -63.5, 64);

  model.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64})));
}
1893 
// INT8 symmetric dequantize: [-64, 63.5] -> scale = 0.5, zero_point = 0.
TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) {
  DequantizeOpModel model(TensorType_INT8, {2, 5}, -64, 63.5);

  model.SetInput<int8_t>(
      {-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-64, -63.5, -63, -62.5, -62, 61.5, 62, 62.5, 63, 63.5})));
}
1904 
1905 class FloorOpModel : public SingleOpModelWithNNAPI {
1906  public:
FloorOpModel(std::initializer_list<int> input_shape,TensorType input_type)1907   FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
1908     input_ = AddInput(TensorType_FLOAT32);
1909     output_ = AddOutput(TensorType_FLOAT32);
1910     SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
1911     BuildInterpreterWithNNAPI({
1912         input_shape,
1913     });
1914   }
1915 
input()1916   int input() { return input_; }
1917 
GetOutput()1918   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1919   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1920 
1921  private:
1922   int input_;
1923   int output_;
1924 };
1925 
// FLOOR over a 1-D tensor.
TEST(NNAPIDelegate, FloorSingleDim) {
  FloorOpModel model({2}, TensorType_FLOAT32);
  model.PopulateTensor<float>(model.input(), {8.5, 0.0});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({8, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
}
1933 
// FLOOR over a 4-D tensor; negatives round toward negative infinity.
TEST(NNAPIDelegate, FloorMultiDims) {
  FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32);
  model.PopulateTensor<float>(model.input(),
                              {0.0001, 8.0001, 0.9999, 9.9999, 0.5,  //
                               -0.0001, -8.0001, -0.9999, -9.9999, -0.5});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
}
1953 
1954 class LocalResponseNormOpModel : public SingleOpModelWithNNAPI {
1955  public:
LocalResponseNormOpModel(std::initializer_list<int> input_shape,int radius,float bias,float alpha,float beta)1956   LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
1957                            float bias, float alpha, float beta) {
1958     input_ = AddInput(TensorType_FLOAT32);
1959     output_ = AddOutput(TensorType_FLOAT32);
1960     SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
1961                  BuiltinOptions_LocalResponseNormalizationOptions,
1962                  CreateLocalResponseNormalizationOptions(builder_, radius, bias,
1963                                                          alpha, beta)
1964                      .Union());
1965     BuildInterpreterWithNNAPI({input_shape});
1966   }
1967 
SetInput(std::initializer_list<float> data)1968   void SetInput(std::initializer_list<float> data) {
1969     PopulateTensor(input_, data);
1970   }
1971 
GetOutput()1972   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1973 
1974  private:
1975   int input_;
1976   int output_;
1977 };
1978 
// With bias=0, alpha=1, beta=0.5 and a radius spanning the whole depth,
// LRN degenerates to L2 normalization.
TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) {
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                                 /*alpha=*/1.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  // The result is every input divided by 2.
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
1988 
// alpha=4 scales the squared-sum term so the divisor becomes 3.
TEST(NNAPIDelegate, LocalResponseNormWithAlpha) {
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  // The result is every input divided by 3.
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}));
}
1998 
// Adding bias=9 on top of alpha=4 makes the divisor 5.
TEST(NNAPIDelegate, LocalResponseNormWithBias) {
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  // The result is every input divided by 5.
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}));
}
2008 
// With a small radius each element is normalized by only its neighborhood,
// so divisors differ per element.
TEST(NNAPIDelegate, LocalResponseNormSmallRadius) {
  LocalResponseNormOpModel model({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
                                 /*alpha=*/4.0, /*beta=*/0.5);
  model.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({-0.264926, 0.125109, 0.140112, 0.267261,
                                   -0.161788, 0.0244266}));
}
2018 
2019 class LSHProjectionOpModel : public SingleOpModelWithNNAPI {
2020  public:
LSHProjectionOpModel(LSHProjectionType type,std::initializer_list<int> hash_shape,std::initializer_list<int> input_shape,std::initializer_list<int> weight_shape)2021   LSHProjectionOpModel(LSHProjectionType type,
2022                        std::initializer_list<int> hash_shape,
2023                        std::initializer_list<int> input_shape,
2024                        std::initializer_list<int> weight_shape) {
2025     hash_ = AddInput(TensorType_FLOAT32);
2026     input_ = AddInput(TensorType_INT32);
2027     if (weight_shape.size() > 0) {
2028       weight_ = AddInput(TensorType_FLOAT32);
2029     }
2030     output_ = AddOutput(TensorType_INT32);
2031 
2032     SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
2033                  BuiltinOptions_LSHProjectionOptions,
2034                  CreateLSHProjectionOptions(builder_, type).Union());
2035     if (weight_shape.size() > 0) {
2036       BuildInterpreterWithNNAPI({hash_shape, input_shape, weight_shape});
2037     } else {
2038       BuildInterpreterWithNNAPI({hash_shape, input_shape});
2039     }
2040 
2041     output_size_ = 1;
2042     for (int i : hash_shape) {
2043       output_size_ *= i;
2044       if (type == LSHProjectionType_SPARSE) {
2045         break;
2046       }
2047     }
2048   }
SetInput(std::initializer_list<int> data)2049   void SetInput(std::initializer_list<int> data) {
2050     PopulateTensor(input_, data);
2051   }
2052 
SetHash(std::initializer_list<float> data)2053   void SetHash(std::initializer_list<float> data) {
2054     PopulateTensor(hash_, data);
2055   }
2056 
SetWeight(std::initializer_list<float> f)2057   void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }
2058 
GetOutput()2059   std::vector<int> GetOutput() { return ExtractVector<int>(output_); }
2060 
2061  private:
2062   int input_;
2063   int hash_;
2064   int weight_;
2065   int output_;
2066 
2067   int output_size_;
2068 };
2069 
// DENSE projection: one output bit per hash-table entry (3x2 hash -> 6).
TEST(NNAPIDelegate, LSHProjectionDense1DInputs) {
  LSHProjectionOpModel model(LSHProjectionType_DENSE, {3, 2}, {5}, {5});

  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});

  model.Invoke();

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(model.GetOutput(), ElementsAre(0, 0, 1, 1, 1, 0));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
#endif
}
2087 
// SPARSE projection: one index per hash row, offset by 2^bits per row.
TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5}, {});

  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});

  model.Invoke();

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
#endif
}
2104 
// SPARSE projection over a 3-D value tensor with per-row weights.
TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});

  model.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
                  9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});

  model.Invoke();

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // Hash returns differently on machines with different endianness
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
#endif
}
2123 
// Base model for single-activation-op tests (RELU, LOGISTIC, ...) running
// through the NNAPI delegate. Stores the input/output tensor indices for use
// by the typed subclasses below.
class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
 public:
  // Most activations don't take any options, so this constructor works for
  // them.
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input) {
    input_ = AddInput(input);
    if (input.type == TensorType_UINT8) {
      // For uint8 inputs the output gets a fixed quantization of
      // zero_point 0 / scale 1/256, mapping the uint8 range onto [0, ~1).
      output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
    } else {
      output_ = AddOutput({input.type, {}});
    }
    SetBuiltinOp(type, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Constructor for ops that need an explicitly-specified output tensor.
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
                         const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(type, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

 protected:
  int input_;   // index of the op's input tensor
  int output_;  // index of the op's output tensor
};
2151 
// Float variant: reads and writes the tensors directly as float.
class FloatActivationsOpModel : public BaseActivationsOpModel {
 public:
  using BaseActivationsOpModel::BaseActivationsOpModel;

  // Writes `data` into the input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }
  // Reads the output tensor back as floats.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
2161 
2162 const float kQuantizedTolerance = 2 * (1. / 256);
2163 
// Quantized variant: quantizes float inputs on the way in and can extract the
// output either raw (as T) or dequantized back to float.
class QuantizedActivationsOpModel : public BaseActivationsOpModel {
 public:
  using BaseActivationsOpModel::BaseActivationsOpModel;

  // Quantizes `data` with the input tensor's params and writes it.
  template <typename T>
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<T>(input_, data);
  }
  template <typename T>

  // Returns the raw quantized output values.
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }
  // Returns the output dequantized with the output tensor's scale/zero-point.
  template <typename T>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }
};
2183 
// RELU clamps negatives to zero and passes positives through unchanged.
TEST(NNAPIDelegate, Relu) {
  FloatActivationsOpModel model(BuiltinOperator_RELU,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({0, -6, 2, 4,    //
                  3, -2, 10, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 0, 2, 4,    //
                                   3, 0, 10, 1}));
}
2197 
// RELU_N1_TO_1 clamps every element into [-1, 1].
TEST(NNAPIDelegate, Relu1) {
  FloatActivationsOpModel model(BuiltinOperator_RELU_N1_TO_1,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({0.0, -0.6, 0.2, -0.4,    //
                  0.3, -2.0, 1.1, -0.1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0.0, -0.6, 0.2, -0.4,    //
                                   0.3, -1.0, 1.0, -0.1}));
}
2211 
// RELU6 clamps every element into [0, 6].
TEST(NNAPIDelegate, Relu6) {
  FloatActivationsOpModel model(BuiltinOperator_RELU6,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({0, -6, 2, 4,    //
                  3, -2, 10, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 0, 2, 4,  //
                                   3, 0, 6, 1}));
}
2225 
// Float sigmoid: checks 1 / (1 + e^-x) over a mix of signs and magnitudes.
TEST(NNAPIDelegate, LogisticFloat) {
  FloatActivationsOpModel model(BuiltinOperator_LOGISTIC,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({0, -6, 2, 4,    //
                  3, -2, 10, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0.5, 0.002473, 0.880797, 0.982014,        //
                                   0.952574, 0.119203, 0.999955, 0.731059}));
}
2239 
// Quantized sigmoid: verifies both the dequantized values (within the shared
// tolerance) and the raw uint8 codes (within one quantization step).
TEST(NNAPIDelegate, LogisticQuantized) {
  QuantizedActivationsOpModel model(
      BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
  model.SetInput<uint8_t>({0, -6, 2, 4,    //
                           3, -2, 10, 1});
  model.Invoke();
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {0.5, 0.002473, 0.880797, 0.982014,        //
                   0.952574, 0.119203, 0.999955, 0.731059},  //
                  kQuantizedTolerance)));
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              testing::Pointwise(QuantizedNear(),
                                 {128, 1, 227, 251, 244, 32, 255, 188}));
}
2260 
// Model wrapping a single RESIZE_BILINEAR op. If `size_data` is non-empty,
// the target size is baked in as a constant tensor; otherwise it becomes a
// second dynamic input that must be filled via SetSize() before Invoke().
class ResizeBilinearOpModel : public SingleOpModelWithNNAPI {
 public:
  ResizeBilinearOpModel(const TensorData& input,
                        std::initializer_list<int> size_data) {
    bool const_size = size_data.size() != 0;
    input_ = AddInput(input);
    if (const_size) {
      size_ = AddConstInput(TensorType_INT32, size_data, {2});
    } else {
      size_ = AddInput({TensorType_INT32, {2}});
    }
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
    if (const_size) {
      BuildInterpreterWithNNAPI({GetShape(input_)});
    } else {
      BuildInterpreterWithNNAPI({GetShape(input_), GetShape(size_)});
    }
  }

  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  // Only meaningful when the size was added as a dynamic input tensor.
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
  int size_;
  int output_;
};
2299 
// Upscales a 1x2 row to 1x3 with a dynamically-supplied size.
TEST(ResizeBilinear, Horizontal) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
  model.SetInput<float>({3, 6});
  model.SetSize({1, 3});
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 5, 6}));
}
2307 
// Same horizontal upscale, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, HorizontalConstant) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  model.SetInput<float>({3, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 5, 6}));
}
2314 
// Upscales a 2x1 column to 3x1 with a dynamically-supplied size.
TEST(ResizeBilinear, Vertical) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
  model.SetInput<float>({3, 9});
  model.SetSize({3, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 7, 9}));
}
2322 
// Same vertical upscale, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, VerticalConstant) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  model.SetInput<float>({3, 9});
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 7, 9}));
}
2329 
// Upscales a 2x2 patch to 3x3, interpolating along both spatial axes.
TEST(ResizeBilinear, TwoDimensional) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
  model.SetInput<float>({3, 6,     //
                         9, 12});  //
  model.SetSize({3, 3});
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(),
              NnapiArrayFloatNear({3, 5, 6,     //
                                   7, 9, 10,    //
                                   9, 11, 12}));
}
2344 
// Same 2-D upscale, but with the size baked in as a constant tensor.
TEST(ResizeBilinear, TwoDimensionalConstant) {
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  model.SetInput<float>({3, 6,     //
                         9, 12});  //
  model.Invoke();
  EXPECT_THAT(model.GetOutput<float>(),
              NnapiArrayFloatNear({3, 5, 6,     //
                                   7, 9, 10,    //
                                   9, 11, 12}));
}
2358 
// Base model for PAD tests, parameterized on element type T. Subclasses add
// the input/paddings (and, for PADV2, constant value) tensors and build the
// interpreter; this class only provides the populate/extract helpers.
template <typename T>
class PadOpModel : public SingleOpModelWithNNAPI {
 public:
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor<T>(input_, data);
  }

  // Quantizes `data` with the input tensor's params and writes it.
  template <typename QuantizedInputOutput>
  void SetQuantizedInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<QuantizedInputOutput>(input_, data);
  }

  // Quantizes the scalar pad value into the constant-values tensor.
  // NOTE(review): constant_values_ is only set up by subclasses that add it.
  template <typename QuantizedInputOutput>
  void SetQuantizedPadValue(float data) {
    QuantizeAndPopulate<QuantizedInputOutput>(constant_values_, {data});
  }

  // Sets the paddings tensor; tests pass two entries per input dimension
  // (a {num_dims, 2}-shaped tensor flattened).
  void SetPaddings(std::initializer_list<int> paddings) {
    PopulateTensor<int>(paddings_, paddings);
  }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Returns the output dequantized with the output tensor's scale/zero-point.
  template <typename QuantizedInputOutput>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<QuantizedInputOutput>(
        ExtractVector<QuantizedInputOutput>(output_), GetScale(output_),
        GetZeroPoint(output_));
  }

 protected:
  int input_;
  int output_;
  int paddings_;
  int constant_values_;
};
2396 
// Float PAD with the paddings baked in as a constant tensor.
class PadOpConstModel : public PadOpModel<float> {
 public:
  PadOpConstModel(const TensorData& input,
                  std::initializer_list<int> paddings_shape,
                  std::initializer_list<int> paddings,
                  const TensorData& output) {
    input_ = AddInput(input);
    paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
                 CreatePadOptions(builder_).Union());
    BuildInterpreterWithNNAPI({input.shape});
  }
};
2412 
// Pads a 1x2x3x1 tensor asymmetrically along the height and width axes using
// constant paddings, and checks both the values and the resulting shape.
TEST(NNAPIDelegate, PadAdvancedConstTest) {
  PadOpConstModel model({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                        {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
2423 
// Base model for SPACE_TO_BATCH_ND tests; subclasses add the input,
// block-shape and paddings tensors and build the interpreter.
class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI {
 public:
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  void SetBlockShape(std::initializer_list<int> data) {
    PopulateTensor<int>(block_shape_, data);
  }

  void SetPaddings(std::initializer_list<int> data) {
    PopulateTensor<int>(paddings_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
  int input_;
  int block_shape_;
  int paddings_;
  int output_;
};
2447 
// SPACE_TO_BATCH_ND with block shape ({2}) and paddings ({2, 2}) baked in as
// constant tensors.
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
 public:
  SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
                             std::initializer_list<int> block_shape,
                             std::initializer_list<int> paddings) {
    input_ = AddInput(TensorType_FLOAT32);
    block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
    paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
                 BuiltinOptions_SpaceToBatchNDOptions,
                 CreateSpaceToBatchNDOptions(builder_).Union());
    BuildInterpreterWithNNAPI({input_shape});
  }
};
2464 
// 2x2 block over a 4x4 spatial grid with no padding: batch grows 1 -> 4.
TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) {
  SpaceToBatchNDOpConstModel model({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({1, 3, 9, 11, 2, 4, 10, 12,    //
                                   5, 7, 13, 15, 6, 8, 14, 16}));
}
2473 
// Same block size, but starting from two input batches: batch grows 2 -> 8.
TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) {
  SpaceToBatchNDOpConstModel model({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({1, 3, 9, 11, 2, 4, 10, 12,    //
                                   5, 7, 13, 15, 6, 8, 14, 16}));
}
2482 
// 3x2 block with padding {1, 0, 2, 0}; padded positions come out as zeros.
TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) {
  SpaceToBatchNDOpConstModel model({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,     //
                                   0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10}));
}
2493 
// 3x2 block with asymmetric padding on both spatial axes.
TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) {
  SpaceToBatchNDOpConstModel model({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear(
                  {0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,     //
                   0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,     //
                   0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0}));
}
2505 
// Model wrapping a single STRIDED_SLICE op. The begin/end/strides tensors are
// baked in as constants; the element type and TensorType are template
// parameters so the same model serves float and quantized tests.
template <typename input_type = float,
          TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModelWithNNAPI {
 public:
  StridedSliceOpModel(std::initializer_list<int> input_shape,
                      std::initializer_list<int> begin_shape,
                      std::initializer_list<int> begin_data,
                      std::initializer_list<int> end_shape,
                      std::initializer_list<int> end_data,
                      std::initializer_list<int> strides_shape,
                      std::initializer_list<int> strides_data, int begin_mask,
                      int end_mask, int ellipsis_mask, int new_axis_mask,
                      int shrink_axis_mask) {
    input_ = AddInput(tensor_input_type);
    begin_ = AddConstInput(TensorType_INT32, begin_data, begin_shape);
    end_ = AddConstInput(TensorType_INT32, end_data, end_shape);
    strides_ = AddConstInput(TensorType_INT32, strides_data, strides_shape);
    output_ = AddOutput(tensor_input_type);
    // The five masks are carried in the op's options, not as tensors.
    SetBuiltinOp(
        BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
        CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
                                  new_axis_mask, shrink_axis_mask)
            .Union());
    BuildInterpreterWithNNAPI(
        {input_shape, begin_shape, end_shape, strides_shape});
  }

  void SetInput(std::initializer_list<input_type> data) {
    PopulateTensor<input_type>(input_, data);
  }

  std::vector<input_type> GetOutput() {
    return ExtractVector<input_type>(output_);
  }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int begin_;
  int end_;
  int strides_;
  int output_;
};
2549 
// Basic 1-D slice [1:3] with stride 1.
TEST(StridedSliceOpTest, In1D) {
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 0, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({2, 3}));
}
2557 
// begin_mask bit 0 set: the begin value for dim 0 is ignored, so the slice
// starts at the beginning of the tensor.
TEST(StridedSliceOpTest, In1D_BeginMask) {
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 1, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 2, 3}));
}
2565 
// Stride 2 along both dimensions of a 2x3 tensor.
TEST(StridedSliceOpTest, In2D_Stride2) {
  StridedSliceOpModel<> model({2, 3}, {2}, {0, 0}, {2}, {2, 3}, {2}, {2, 2},
                              0, 0, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 3}));
}
2574 
// end_mask bit 1 set: the end value for dim 1 is ignored, so the slice runs
// to the end of the second dimension.
TEST(StridedSliceOpTest, In2D_EndMask) {
  StridedSliceOpModel<> model({2, 3}, {2}, {1, 0}, {2}, {2, 2}, {2}, {1, 1},
                              0, 2, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({4, 5, 6}));
}
2583 
// shrink_axis_mask bit 2 set: the last dimension is collapsed, turning the
// 2x3x2 input into a 2x3 output of every other element.
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
  StridedSliceOpModel<> model({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 1}, {3},
                              {1, 1, 1}, 0, 0, 0, 0, 4);
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 3, 5, 7, 9, 11}));
}
2592 
// Input sequence for RnnBlackBoxTest: the test consumes input_size() (= 8)
// floats per step and feeds the same slice to both batches.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
2620 
// Golden activations for RnnBlackBoxTest: num_units() (= 16) values per step
// (one blank-line-separated group per step); the test duplicates each group
// for the second batch before comparing.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
2685 
// Input-to-hidden weights for RnnBlackBoxTest: 128 values, matching the
// {units, input_size} = {16, 8} weights tensor shape the model builds.
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};
2709 
// Recurrent (hidden-to-hidden) weights for RnnBlackBoxTest: a 16x16 matrix
// laid out flat, equal to 0.1 * identity (each 0.1 is followed by 17 zeros,
// i.e. one row width plus one, which places the next 0.1 on the diagonal).
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};
2727 
// Bias for RnnBlackBoxTest: 16 values, one per unit.
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
2732 
// Model wrapping a single RNN op with RELU activation. The tensors are added
// in the op's expected input order: input, weights, recurrent weights, bias,
// then the hidden-state tensor (a variable input, so its contents carry over
// between Invoke() calls).
class RNNOpModel : public SingleOpModelWithNNAPI {
 public:
  RNNOpModel(int batches, int units, int size,
             const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32)
      : batches_(batches), units_(units), input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},  // input tensor
        {units_, input_size_},    // weights tensor
        {units_, units_},         // recurrent weights tensor
        {units_},                 // bias tensor
        {batches_, units_}        // hidden state tensor
    });
  }

  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  void SetWeights(std::initializer_list<float> f) {
    PopulateTensor(weights_, f);
  }

  void SetRecurrentWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_weights_, f);
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Writes [begin, end) into the input tensor starting at element `offset`,
  // so callers can fill one batch at a time.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  int input_;
  int weights_;
  int recurrent_weights_;
  int bias_;
  int hidden_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
2793 
// Runs the recorded input sequence through a 2-batch, 16-unit, 8-input RNN
// and checks every step's output against the precomputed golden activations.
TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel model(2, 16, 8);
  model.SetWeights(rnn_weights);
  model.SetBias(rnn_bias);
  model.SetRecurrentWeights(rnn_recurrent_weights);

  const int sequence_steps = sizeof(rnn_input) / sizeof(float) /
                             (model.input_size() * model.num_batches());

  for (int step = 0; step < sequence_steps; ++step) {
    float* const input_begin = rnn_input + step * model.input_size();
    float* const input_end = input_begin + model.input_size();
    // Both batches receive the same input slice.
    model.SetInput(0, input_begin, input_end);
    model.SetInput(model.input_size(), input_begin, input_end);

    model.Invoke();

    // Golden output for one batch, duplicated for the second batch.
    float* const golden_begin = rnn_golden_output + step * model.num_units();
    float* const golden_end = golden_begin + model.num_units();
    std::vector<float> expected(golden_begin, golden_end);
    expected.insert(expected.end(), golden_begin, golden_end);

    EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear(expected));
  }
}
2820 
// SVDF reference input: 10 time steps, each holding 2 batches x 3 features
// (one blank-line-separated group per step). Consumed by VerifyGoldens.
static float svdf_input[] = {
    0.12609188,  -0.46347019, -0.89598465,
    0.35867718,  0.36897406,  0.73463392,

    0.14278367,  -1.64410412, -0.75222826,
    -0.57290924, 0.12729003,  0.7567004,

    0.49837467,  0.19278903,  0.26584083,
    0.17660543,  0.52949083,  -0.77931279,

    -0.11186574, 0.13164264,  -0.05349274,
    -0.72674477, -0.5683046,  0.55900657,

    -0.68892461, 0.37783599,  0.18263303,
    -0.63690937, 0.44483393,  -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898,  0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608,  -0.24936394, -0.77526885,
    2.06740379,  -1.51439476, 1.43768692,

    0.11771342,  -0.23761693, -0.65898693,
    0.31088525,  -1.55601168, -0.87661445,

    -0.89477462, 1.67204106,  -0.53235275,
    -0.6230064,  0.29819036,  1.06939757,
};
2852 
// Expected rank-1 SVDF output: 10 time steps of 2 batches x 4 units
// (one blank-line-separated group per step), matching svdf_input.
static float svdf_golden_output_rank_1[] = {
    0.014899,    -0.0517661,  -0.143725,   -0.00271883,
    -0.03004015, 0.09565311,  0.1587342,   0.00784263,

    0.068281,    -0.162217,   -0.152268,   0.00323521,
    0.01582633,  0.03858774,  -0.03001583, -0.02671271,

    -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
    -0.01432795, 0.05524484,  0.1101355,   -0.02382665,

    -0.00623099, -0.077701,   -0.391193,   -0.0136691,
    -0.02333033, 0.02293761,  0.12338032,  0.04326871,

    0.201551,    -0.164607,   -0.179462,   -0.0592739,
    0.01064911,  -0.17503069, 0.07821996,  -0.00224009,

    0.0886511,   -0.0875401,  -0.269283,   0.0281379,
    -0.02282338, 0.09741908,  0.32973239,  0.12281385,

    -0.201174,   -0.586145,   -0.628624,   -0.0330412,
    0.24780814,  -0.39304617, -0.22473189, 0.02589256,

    -0.0839096,  -0.299329,   0.108746,    0.109808,
    0.10084175,  -0.06416984, 0.28936723,  0.0026358,

    0.419114,    -0.237824,   -0.422627,   0.175115,
    -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,

    0.36726,     -0.522303,   -0.456502,   -0.175475,
    0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
};
2884 
// Expected rank-2 SVDF output: 10 time steps of 2 batches x 4 units
// (one blank-line-separated group per step), matching svdf_input.
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
    0.1141196,   0.12965347,  -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    0.10132131,  -0.06143532, -0.00924693, 0.10084561,

    0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    -0.02033747, 0.22673416,  0.15487903,  0.02525555,

    -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    0.09607603,  -0.0141301,  -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    0.00331409,  0.11167502,  0.02218599,  -0.07309391,

    0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    -0.00075242, 0.33691186,  -0.1536046,  0.16572715,

    -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    0.00732617,  0.46737891,  0.26449674,  0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043,  -0.28260678, 0.15666828,  1.14949894,

    -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
};
2916 
// Model wrapper for a single SVDF op run through the NNAPI delegate.
// The AddInput/AddVariableInput calls below must stay in this exact order:
// their sequence defines the tensor indices the SVDF kernel reads.
class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                  int rank,
                  TensorType weights_feature_type = TensorType_FLOAT32,
                  TensorType weights_time_type = TensorType_FLOAT32)
      : batches_(batches),
        units_(units),
        input_size_(input_size),
        memory_size_(memory_size),
        rank_(rank) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_feature_ = AddInput(weights_feature_type);
    weights_time_ = AddInput(weights_time_type);
    // TODO(b/121383394) : figure out why optional bias causes TFLite segfault
    // when using NNAPI delegate.
    bias_ = AddInput(TensorType_FLOAT32);
    const int num_filters = units * rank;
    activation_state_ = AddVariableInput(
        TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}});
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},              // input tensor
        {units_ * rank, input_size_},         // weights_feature tensor
        {units_ * rank, memory_size_},        // weights_time tensor
        {units_},                             // bias tensor
        {batches, memory_size * num_filters}  // activation_state tensor
    });
    // TODO(b/121383394) : remove once the optional bias bug is fixed.
    // Until then the bias tensor is filled with zeros.
    PopulateTensor(bias_, std::vector<float>(units_));
  }

  // Populates the weights_feature tensor.
  void SetWeightsFeature(std::initializer_list<float> f) {
    PopulateTensor(weights_feature_, f);
  }

  // Populates the weights_time tensor.
  void SetWeightsTime(std::initializer_list<float> f) {
    PopulateTensor(weights_time_, f);
  }

  // Populates the input tensor with [begin, end) starting at `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Extracts the output tensor from the SVDF op.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  // Model-dimension accessors used by VerifyGoldens.
  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices registered with the interpreter.
  int input_;
  int weights_feature_;
  int weights_time_;
  int bias_;
  int activation_state_;
  int output_;

  // Model dimensions supplied at construction time.
  int batches_;
  int units_;
  int input_size_;
  int memory_size_;
  int rank_;
};
2988 
// Float SVDF model; inherits all constructors from the base model unchanged.
class SVDFOpModel : public BaseSVDFOpModel {
 public:
  using BaseSVDFOpModel::BaseSVDFOpModel;
};
2993 
2994 class SVDFOpTest : public ::testing::Test {
2995  protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)2996   void VerifyGoldens(float golden_input[], float golden_output[],
2997                      int golden_size, BaseSVDFOpModel* svdf,
2998                      float tolerance = 1e-5) {
2999     const int svdf_num_batches = svdf->num_batches();
3000     const int svdf_input_size = svdf->input_size();
3001     const int svdf_num_units = svdf->num_units();
3002     const int input_sequence_size =
3003         golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
3004     // Going over each input batch, setting the input tensor, invoking the SVDF
3005     // op and checking the output with the expected golden values.
3006     for (int i = 0; i < input_sequence_size; i++) {
3007       float* batch_start =
3008           golden_input + i * svdf_input_size * svdf_num_batches;
3009       float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
3010       svdf->SetInput(0, batch_start, batch_end);
3011 
3012       svdf->Invoke();
3013 
3014       const float* golden_start =
3015           golden_output + i * svdf_num_units * svdf_num_batches;
3016       const float* golden_end =
3017           golden_start + svdf_num_units * svdf_num_batches;
3018       std::vector<float> expected;
3019       expected.insert(expected.end(), golden_start, golden_end);
3020 
3021       EXPECT_THAT(svdf->GetOutput(),
3022                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
3023     }
3024   }
3025 };
3026 
// Black-box check of the rank-1 SVDF against the precomputed golden outputs.
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/1);
  // weights_feature: {units * rank, input_size} = {4, 3}.
  svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                          0.22197971, 0.12416199, 0.27901134, 0.27557442,
                          0.3905206, -0.36137494, -0.06634006, -0.10640851});

  // weights_time: {units * rank, memory_size} = {4, 10}, one filter per group.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  // sizeof(svdf_input) is the byte size VerifyGoldens divides into steps.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
3050 
// Black-box check of the rank-2 SVDF against the precomputed golden outputs.
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/2);
  // weights_feature: {units * rank, input_size} = {8, 3}.
  svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
                          0.12416199,  0.15785322,  0.27901134,  0.3905206,
                          0.21931258,  -0.36137494, -0.10640851, 0.31053296,
                          -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
                          0.15294972,  0.38031587,  0.27557442,  0.39635518,
                          -0.21580373, -0.06634006, -0.02702999, 0.27072677});

  // weights_time: {units * rank, memory_size} = {8, 10}, one filter per group.
  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

       -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
       0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

       -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
       0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

       -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
       -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

       0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
       0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});

  // sizeof(svdf_input) is the byte size VerifyGoldens divides into steps.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
3089 
// Model wrapper for a single LSTM op run through the NNAPI delegate.
// The AddInput/AddNullInput calls in the constructor must stay in this exact
// order: their sequence defines the tensor indices the LSTM kernel reads, and
// optional tensors are represented as null inputs to keep the indices stable.
class LSTMOpModel : public SingleOpModelWithNNAPI {
 public:
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
              const TensorType weight_type)
      : n_batch_(n_batch),
        n_input_(n_input),
        n_cell_(n_cell),
        n_output_(n_output),
        weight_type_(weight_type) {
    input_ = AddInput(TensorType_FLOAT32);

    // With CIFG the input gate has no dedicated weights/bias, so those
    // tensors are registered as null inputs.
    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput(weight_type);
    }

    input_to_forget_weights_ = AddInput(weight_type);
    input_to_cell_weights_ = AddInput(weight_type);
    input_to_output_weights_ = AddInput(weight_type);

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput(weight_type);
    }

    recurrent_to_forget_weights_ = AddInput(weight_type);
    recurrent_to_cell_weights_ = AddInput(weight_type);
    recurrent_to_output_weights_ = AddInput(weight_type);

    // Peephole (cell-to-gate) weights are optional.
    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(weight_type);
      }
      cell_to_forget_weights_ = AddInput(weight_type);
      cell_to_output_weights_ = AddInput(weight_type);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
    cell_bias_ = AddInput(TensorType_FLOAT32);
    output_gate_bias_ = AddInput(TensorType_FLOAT32);

    if (use_projection_weights) {
      projection_weights_ = AddInput(weight_type);
      if (use_projection_bias) {
        projection_bias_ = AddInput(TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 input state tensors.
    input_activation_state_ = AddVariableInput(TensorType_FLOAT32);
    input_cell_state_ = AddVariableInput(TensorType_FLOAT32);

    // More than 20 shapes signals the layer-norm variant, whose coefficient
    // tensors live at indices 20-23 of input_shapes.
    const bool use_layer_norm = input_shapes.size() > 20;
    // Layer norm weights.
    if (use_layer_norm) {
      const int kInputLayerNormCoeffsIndex = 20;
      const int kForgetLayerNormCoeffsIndex = 21;
      const int kCellLayerNormCoeffsIndex = 22;
      const int kOutputLayerNormCoeffsIndex = 23;

      if (use_cifg) {
        input_layer_norm_coefficients_ = AddNullInput();
      } else {
        input_layer_norm_coefficients_ =
            AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
      }
      forget_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
      cell_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
      output_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
    }

    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                                   cell_clip, proj_clip)
                     .Union());
    BuildInterpreterWithNNAPI(input_shapes);
  }

  // Weight setters route through SetData with weight_type_ so the same model
  // class serves both float and quantized weight tests.
  void SetInputToInputWeights(const std::vector<float>& f) {
    SetData(input_to_input_weights_, weight_type_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetData(input_to_forget_weights_, weight_type_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetData(input_to_cell_weights_, weight_type_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetData(input_to_output_weights_, weight_type_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_input_weights_, weight_type_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetData(recurrent_to_forget_weights_, weight_type_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetData(recurrent_to_cell_weights_, weight_type_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_output_weights_, weight_type_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetData(cell_to_input_weights_, weight_type_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetData(cell_to_forget_weights_, weight_type_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetData(cell_to_output_weights_, weight_type_, f);
  }

  // Bias and layer-norm tensors are always float.
  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetData(projection_weights_, weight_type_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(input_layer_norm_coefficients_, f);
  }

  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(forget_layer_norm_coefficients_, f);
  }

  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(cell_layer_norm_coefficients_, f);
  }

  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(output_layer_norm_coefficients_, f);
  }

  // Copies [begin, end) into the input tensor starting at `offset`.
  // const_cast is only to satisfy PopulateTensor's non-const signature.
  void SetInput(int offset, const float* begin, const float* end) {
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  // Model-dimension accessors used by VerifyGoldens.
  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  // Tensor indices registered with the interpreter, in kernel input order.
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;
  int input_activation_state_;
  int input_cell_state_;

  int input_layer_norm_coefficients_;
  int forget_layer_norm_coefficients_;
  int cell_layer_norm_coefficients_;
  int output_layer_norm_coefficients_;

  int output_;
  // NOTE(review): output_state_ and cell_state_ are never assigned in this
  // class — possibly kept for subclasses; confirm before removing.
  int output_state_;
  int cell_state_;

  // Model dimensions supplied at construction time.
  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;

 private:
  const TensorType weight_type_;

  // Registers a layer-norm coefficient tensor, or a null input when the
  // corresponding shape's first dimension is 0 (tensor omitted).
  int AddLayerNormCoeffsTensor(
      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
    if (input_shapes[tensor_index][0] != 0) {
      return AddInput(TensorType_FLOAT32);
    } else {
      return AddNullInput();
    }
  }
};
3342 
3343 class BaseLstmTest : public ::testing::Test {
3344  protected:
3345   // Weights of the LSTM model. Some are optional.
3346   std::vector<float> input_to_input_weights_;
3347   std::vector<float> input_to_cell_weights_;
3348   std::vector<float> input_to_forget_weights_;
3349   std::vector<float> input_to_output_weights_;
3350   std::vector<float> input_gate_bias_;
3351   std::vector<float> cell_gate_bias_;
3352   std::vector<float> forget_gate_bias_;
3353   std::vector<float> output_gate_bias_;
3354   std::vector<float> recurrent_to_input_weights_;
3355   std::vector<float> recurrent_to_cell_weights_;
3356   std::vector<float> recurrent_to_forget_weights_;
3357   std::vector<float> recurrent_to_output_weights_;
3358   std::vector<float> cell_to_input_weights_;
3359   std::vector<float> cell_to_forget_weights_;
3360   std::vector<float> cell_to_output_weights_;
3361   std::vector<float> projection_weights_;
3362   std::vector<float> input_layer_norm_coefficients_;
3363   std::vector<float> forget_layer_norm_coefficients_;
3364   std::vector<float> cell_layer_norm_coefficients_;
3365   std::vector<float> output_layer_norm_coefficients_;
3366 
3367   // LSTM input is stored as num_batch x num_inputs vector.
3368   std::vector<std::vector<float>> lstm_input_;
3369   // LSTM output is stored as num_batch x num_outputs vector.
3370   std::vector<std::vector<float>> lstm_golden_output_;
3371 
3372   // Compares output up to tolerance to the result of the lstm given the input.
VerifyGoldens(const std::vector<std::vector<float>> & input,const std::vector<std::vector<float>> & output,LSTMOpModel * lstm,float tolerance=1e-5)3373   void VerifyGoldens(const std::vector<std::vector<float>>& input,
3374                      const std::vector<std::vector<float>>& output,
3375                      LSTMOpModel* lstm, float tolerance = 1e-5) {
3376     const int num_batches = input.size();
3377     EXPECT_GT(num_batches, 0);
3378     const int num_inputs = lstm->num_inputs();
3379     EXPECT_GT(num_inputs, 0);
3380     const int input_sequence_size = input[0].size() / num_inputs;
3381     EXPECT_GT(input_sequence_size, 0);
3382     for (int i = 0; i < input_sequence_size; ++i) {
3383       for (int b = 0; b < num_batches; ++b) {
3384         const float* batch_start = input[b].data() + i * num_inputs;
3385         const float* batch_end = batch_start + num_inputs;
3386 
3387         lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
3388       }
3389 
3390       lstm->Invoke();
3391 
3392       const int num_outputs = lstm->num_outputs();
3393       std::vector<float> expected;
3394       for (int b = 0; b < num_batches; ++b) {
3395         const float* golden_start_batch = output[b].data() + i * num_outputs;
3396         const float* golden_end_batch = golden_start_batch + num_outputs;
3397         expected.insert(expected.end(), golden_start_batch, golden_end_batch);
3398       }
3399       EXPECT_THAT(lstm->GetOutput(),
3400                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
3401     }
3402   }
3403 };
3404 
// Fixture for the basic LSTM variant: no CIFG, no peephole, no projection,
// no clipping. Dimensions used by the test below: n_batch=1, n_input=2,
// n_cell=n_output=4.
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input weights: one {n_cell, n_input} = {4, 2} matrix per gate.
    input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
                               -0.34550029, 0.04266912,  -0.15680569,
                               -0.34856534, 0.43890524};
    input_to_cell_weights_ = {-0.50013041, 0.1370284,  0.11810488, 0.2013163,
                              -0.20583314, 0.44344562, 0.22077113, -0.29909778};
    input_to_forget_weights_ = {0.09701663,  0.20334584,  -0.50592935,
                                -0.31343272, -0.40032279, 0.44781327,
                                0.01387155,  -0.35593212};
    input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
                                0.40525138,  0.44272184,  0.03897077,
                                -0.1556896,  0.19487578};
    // Gate biases ({n_cell}); forget bias of 1.0 is the common initialization.
    input_gate_bias_ = {0., 0., 0., 0.};
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent weights: one {n_cell, n_output} = {4, 4} matrix per gate.
    recurrent_to_input_weights_ = {
        -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
        0.28902304,  0.08183324,  -0.16555229, 0.02286911,
        -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
        0.24077177,  -0.51332325, -0.33502164, 0.10629296};

    recurrent_to_cell_weights_ = {
        -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
        0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
        -0.06418842, -0.13502428, -0.501764,   0.22830659,
        -0.46367589, 0.26016325,  -0.03894562, -0.16368064};

    recurrent_to_forget_weights_ = {
        -0.48684245, -0.06655136, 0.42224967,  0.2112639,
        0.27654213,  0.20864892,  -0.07646349, 0.45877004,
        0.00141793,  -0.14609534, 0.36447752,  0.09196436,
        0.28053468,  0.01560611,  -0.20127171, -0.01140004};

    recurrent_to_output_weights_ = {
        0.43385774,  -0.17194885, 0.2718237,  0.09215671,
        0.24107647,  -0.39835793, 0.18212086, 0.01301402,
        0.48572797,  -0.50656658, 0.20047462, -0.20607421,
        -0.51818722, -0.15390486, 0.0468148,  0.39922136};

    // One batch, three time steps of two inputs each, with the matching
    // golden outputs (three steps of four outputs).
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
                            -0.03716109, 0.12507336, 0.41193449, -0.20860538,
                            -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
  }
};
3453 
// Black-box check of the basic (no-CIFG/no-peephole/no-projection/no-clip)
// LSTM against the fixture's golden outputs. The shape list below must match
// the tensor registration order in LSTMOpModel's constructor.
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {n_batch, n_input},  // input tensor

                       {n_cell, n_input},  // input_to_input_weight tensor
                       {n_cell, n_input},  // input_to_forget_weight tensor
                       {n_cell, n_input},  // input_to_cell_weight tensor
                       {n_cell, n_input},  // input_to_output_weight tensor

                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
                       {n_cell, n_output},  // recurrent_to_output_weight_tensor

                       // {0} marks an omitted optional tensor.
                       {0},  // cell_to_input_weight tensor
                       {0},  // cell_to_forget_weight tensor
                       {0},  // cell_to_output_weight tensor

                       {n_cell},  // input_gate_bias tensor
                       {n_cell},  // forget_gate_bias tensor
                       {n_cell},  // cell_bias tensor
                       {n_cell},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor

                       {n_batch, n_output},  // activation_state tensor
                       {n_batch, n_cell},    // cell_state tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
3513 
// Reuses the basic LSTM fixture's weights/goldens for the variant where the
// layer-norm coefficient tensors are present in the shape list but omitted.
class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
    : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
3516 
// Black-box test of a float LSTM with no CIFG, no peephole, no projection and
// no clipping, where the four optional layer-norm coefficient tensors are
// supplied explicitly as zero-sized inputs (layer normalization omitted).
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
       LstmBlackBoxTest) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  // The shape list is positional: zero-sized entries ({0} / {0, 0}) mark
  // optional tensors that are absent in this configuration.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/false, /*use_peephole=*/false,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {n_batch, n_input},  // input tensor

                       {n_cell, n_input},  // input_to_input_weight tensor
                       {n_cell, n_input},  // input_to_forget_weight tensor
                       {n_cell, n_input},  // input_to_cell_weight tensor
                       {n_cell, n_input},  // input_to_output_weight tensor

                       {n_cell, n_output},  // recurrent_to_input_weight_tensor
                       {n_cell, n_output},  // recurrent_to_forget_weight_tensor
                       {n_cell, n_output},  // recurrent_to_cell_weight_tensor
                       {n_cell, n_output},  // recurrent_to_output_weight_tensor

                       {0},  // cell_to_input_weight tensor
                       {0},  // cell_to_forget_weight tensor
                       {0},  // cell_to_output_weight tensor

                       {n_cell},  // input_gate_bias tensor
                       {n_cell},  // forget_gate_bias tensor
                       {n_cell},  // cell_bias tensor
                       {n_cell},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor

                       {n_batch, n_output},  // activation_state tensor
                       {n_batch, n_cell},    // cell_state tensor

                       // Zero-sized layer-norm coefficients: layer
                       // normalization is omitted for all four gates.
                       {0},  // input_layer_norm_coefficient tensor
                       {0},  // forget_layer_norm_coefficient tensor
                       {0},  // cell_layer_norm_coefficient tensor
                       {0},  // output_layer_norm_coefficient tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Weights and biases are populated by the base fixture's SetUp().
  lstm.SetInputToInputWeights(input_to_input_weights_);
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetInputGateBias(input_gate_bias_);
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Runs the model over the fixture's input sequence and compares each step's
  // output against the golden output.
  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
3582 
// Fixture data for a CIFG LSTM: no input-gate weights or bias are defined
// because with CIFG the input gate is coupled to (derived from) the forget
// gate.
// NOTE(review): the name says "NoPeephole", yet this fixture populates
// cell_to_forget_weights_ / cell_to_output_weights_ and the companion test
// constructs the model with /*use_peephole=*/true — confirm whether the name
// or the configuration is the intended one.
class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  // Populates the BaseLstmTest weight/bias members and the input/golden
  // sequences consumed by VerifyGoldens().
  void SetUp() override {
    input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
                              0.05100781,  0.04717243,  0.48944736,
                              -0.38535351, -0.17212132};

    input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365,  -0.22755712, 0.28253698,
                                0.24407166,  0.33826375};

    input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556,  0.42751634};
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    recurrent_to_cell_weights_ = {
        0.54066205,  -0.32668582, -0.43562764, -0.56094903,
        0.42957711,  0.01841056,  -0.32764608, -0.33027974,
        -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
        -0.54532051, 0.33003211,  0.44901288,  0.21193194};

    recurrent_to_forget_weights_ = {
        -0.13832897, -0.0515101,  -0.2359007, -0.16661474,
        -0.14340827, 0.36986142,  0.23414481, 0.55899,
        0.10798943,  -0.41174671, 0.17751795, -0.34484994,
        -0.35874045, -0.11352962, 0.27268326, 0.54058349};

    recurrent_to_output_weights_ = {
        0.41613156, 0.42610586,  -0.16495961, -0.5663873,
        0.30579174, -0.05115908, -0.33941799, 0.23364776,
        0.11178309, 0.09481031,  -0.26424935, 0.46261835,
        0.50248802, 0.26114327,  -0.43736315, 0.33149987};

    // Peephole connections (cell -> forget/output gates).
    cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
                               0.31544167};
    cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
                               -0.77109635};

    // One batch; the input sequence and the expected per-step outputs.
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
                            -0.42312205, -0.01218222, 0.24201041, -0.08124574,
                            -0.358325, -0.04621704, 0.21641694, -0.06471302}};
  }
};
3629 
// Black-box test of a float CIFG LSTM with no projection and no clipping.
// NOTE(review): despite the fixture name ("NoPeephole"), the model is built
// with /*use_peephole=*/true and peephole weights are set below — confirm
// whether the name or the configuration is the intended one.
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
  const int n_batch = 1;
  const int n_input = 2;
  // n_cell and n_output have the same size when there is no projection.
  const int n_cell = 4;
  const int n_output = 4;

  // The shape list is positional: zero-sized entries ({0} / {0, 0}) mark
  // optional tensors absent in this configuration — here all input-gate
  // tensors, since CIFG couples the input gate to the forget gate.
  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                   /*use_cifg=*/true, /*use_peephole=*/true,
                   /*use_projection_weights=*/false,
                   /*use_projection_bias=*/false,
                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                   {
                       {n_batch, n_input},  // input tensor

                       {0, 0},             // input_to_input_weight tensor
                       {n_cell, n_input},  // input_to_forget_weight tensor
                       {n_cell, n_input},  // input_to_cell_weight tensor
                       {n_cell, n_input},  // input_to_output_weight tensor

                       {0, 0},              // recurrent_to_input_weight tensor
                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
                       {n_cell, n_output},  // recurrent_to_output_weight tensor

                       {0},       // cell_to_input_weight tensor
                       {n_cell},  // cell_to_forget_weight tensor
                       {n_cell},  // cell_to_output_weight tensor

                       {0},       // input_gate_bias tensor
                       {n_cell},  // forget_gate_bias tensor
                       {n_cell},  // cell_bias tensor
                       {n_cell},  // output_gate_bias tensor

                       {0, 0},  // projection_weight tensor
                       {0},     // projection_bias tensor

                       {n_batch, n_output},  // activation_state tensor
                       {n_batch, n_cell},    // cell_state tensor
                   },
                   /*weight_type=*/TensorType_FLOAT32);

  // Weights and biases are populated by the fixture's SetUp(); input-gate
  // setters are intentionally skipped (CIFG).
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole connections (cell -> forget/output gates).
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Runs the model over the fixture's input sequence and compares each step's
  // output against the golden output.
  VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
3689 
3690 class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
SetUp()3691   void SetUp() override {
3692     input_to_input_weights_ = {
3693         0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
3694         0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
3695         -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
3696         -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
3697         -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
3698         -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
3699         -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
3700         0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
3701         0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
3702         0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
3703         -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
3704         0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
3705         -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
3706         -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
3707         -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
3708         0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
3709         -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
3710         -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
3711         -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
3712         -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677};
3713 
3714     input_to_forget_weights_ = {
3715         -0.0018401089, -0.004852237, 0.03698424,    0.014181704,
3716         0.028273236,   -0.016726194, -0.05249759,   -0.10204261,
3717         0.00861066,    -0.040979505, -0.009899187,  0.01923892,
3718         -0.028177269,  -0.08535103,  -0.14585495,   0.10662567,
3719         -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
3720         0.0030784295,  0.076784775,  0.07463696,    0.094531395,
3721         0.0814421,     -0.12257899,  -0.033945758,  -0.031303465,
3722         0.045630626,   0.06843887,   -0.13492945,   -0.012480007,
3723         -0.0811829,    -0.07224499,  -0.09628791,   0.045100946,
3724         0.0012300825,  0.013964662,  0.099372394,   0.02543059,
3725         0.06958324,    0.034257296,  0.0482646,     0.06267997,
3726         0.052625068,   0.12784666,   0.07077897,    0.025725935,
3727         0.04165009,    0.07241905,   0.018668644,   -0.037377294,
3728         -0.06277783,   -0.08833636,  -0.040120605,  -0.011405586,
3729         -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
3730         0.05483423,    0.11449111,   0.11289652,    0.10939839,
3731         0.13396506,    -0.08402166,  -0.01901462,   -0.044678304,
3732         -0.07720565,   0.014350063,  -0.11757958,   -0.0652038,
3733         -0.08185733,   -0.076754324, -0.092614375,  0.10405491,
3734         0.052960336,   0.035755895,  0.035839386,   -0.012540553,
3735         0.036881298,   0.02913376,   0.03420159,    0.05448447,
3736         -0.054523353,  0.02582715,   0.02327355,    -0.011857179,
3737         -0.0011980024, -0.034641717, -0.026125094,  -0.17582615,
3738         -0.15923657,   -0.27486774,  -0.0006143371, 0.0001771948,
3739         -8.470171e-05, 0.02651807,   0.045790765,   0.06956496};
3740 
3741     input_to_cell_weights_ = {
3742         -0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
3743         -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
3744         -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
3745         -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
3746         -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
3747         0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
3748         -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
3749         0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
3750         -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
3751         -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
3752         -0.025174323,  0.0396852,     0.081777506,   0.06157468,
3753         0.10210095,    -0.009658194,  0.046511717,   0.03603906,
3754         0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
3755         0.053568836,   0.06408714,    0.12835667,    -0.008714329,
3756         -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
3757         -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
3758         -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
3759         -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
3760         -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
3761         -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
3762         0.05453865,    0.091149814,   0.06387331,    0.007518393,
3763         0.055960953,   0.069779344,   0.046411168,   0.10509911,
3764         0.07463894,    0.0075130584,  0.012850982,   0.04555431,
3765         0.056955688,   0.06555285,    0.050801456,   -0.009862683,
3766         0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042};
3767 
3768     input_to_output_weights_ = {
3769         -0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
3770         -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
3771         0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
3772         -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
3773         -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
3774         0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
3775         -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
3776         -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
3777         -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
3778         -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
3779         0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
3780         0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
3781         0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
3782         -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
3783         0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
3784         0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
3785         -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
3786         0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
3787         -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
3788         -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956};
3789 
3790     input_gate_bias_ = {0.02234832,   0.14757581,  0.18176508,  0.10380666,
3791                         0.053110216,  -0.06928846, -0.13942584, -0.11816189,
3792                         0.19483899,   0.03652339,  -0.10250295, 0.036714908,
3793                         -0.18426876,  0.036065217, 0.21810818,  0.02383196,
3794                         -0.043370757, 0.08690144,  -0.04444982, 0.00030581196};
3795 
3796     forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
3797                          0.11098921,  0.15378423,   0.09263801,  0.09790885,
3798                          0.09508917,  0.061199076,  0.07665568,  -0.015443159,
3799                          -0.03499149, 0.046190713,  0.08895977,  0.10899629,
3800                          0.40694186,  0.06030037,   0.012413437, -0.06108739};
3801 
3802     cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
3803                        -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
3804                        -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
3805                        -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
3806                        0.016178843,  0.1749513,    0.13975595,   0.92058027};
3807 
3808     output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469,   0.12648113,
3809                          0.027195795, 0.35373217,    -0.018957434, 0.008907322,
3810                          -0.0762701,  0.12018895,    0.04216877,   0.0022856654,
3811                          0.040952638, 0.3147856,     0.08225149,   -0.057416286,
3812                          -0.14995944, -0.008040261,  0.13208859,   0.029760877};
3813 
3814     recurrent_to_input_weights_ = {
3815         -0.001374326,   -0.078856036,   0.10672688,    0.029162422,
3816         -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
3817         -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
3818         -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
3819         0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
3820         0.08981,        -0.045407712,   0.08682226,    -0.06867011,
3821         -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
3822         0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
3823         -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
3824         0.009352075,    0.22920375,     0.0016303885,  0.11583097,
3825         -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
3826         0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
3827         -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
3828         0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
3829         -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
3830         -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
3831         -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
3832         -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
3833         -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
3834         0.01068115,     0.032956902,    0.022433773,   0.0026891115,
3835         0.08944216,     -0.0685835,     0.010513544,   0.07228705,
3836         0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
3837         0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
3838         0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
3839         -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
3840         -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
3841         0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
3842         -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
3843         -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
3844         -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
3845         -0.017142897,   0.03312627,     0.009205989,   0.024138335,
3846         -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
3847         -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
3848         0.0365468,      0.07590991,     0.08838724,    0.021681072,
3849         -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
3850         0.023646897,    -0.095322326,   0.02233014,    0.09756986,
3851         -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
3852         -0.09801813,    0.019894179,    0.08502348,    0.004032281,
3853         0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
3854         -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
3855         -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
3856         0.010889619,    0.0047078193,   0.038385306,   0.08540671,
3857         -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
3858         0.015963363,    0.00871737,     0.060130805,   0.028611384,
3859         0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
3860         0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
3861         0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
3862         0.019899689,    0.006106124,    -0.027092824,  0.0786356,
3863         0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
3864         -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
3865         -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
3866         -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
3867         -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
3868         -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
3869         0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
3870         0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
3871         -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
3872         0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
3873         0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
3874         0.058618143,    -0.08598433,    0.00972939,    0.023867095,
3875         -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
3876         -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
3877         0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
3878         -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
3879         -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
3880         0.06358255,     0.18531723,     0.07759293,    0.12006465,
3881         0.1305557,      0.058638252,    -0.03393652,   0.09622831,
3882         -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
3883         -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
3884         0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
3885         0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
3886         0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
3887         0.08184801,     -0.019164372,   0.06791302,    0.034257166,
3888         -0.10307039,    0.021943003,    0.046745934,   0.0790918,
3889         -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
3890         -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
3891         -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
3892         0.026351685,    0.012641483,    0.07466548,    0.044301085,
3893         -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
3894         -0.04106223,    -0.028126027,   0.028473156,   0.10467447};
3895 
3896     recurrent_to_cell_weights_ = {
3897         -0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
3898         0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
3899         0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
3900         -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
3901         0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
3902         0.08089997,     0.05143358,    0.038261272,   0.03339287,
3903         -0.027673481,   0.044746667,   0.028349208,   0.020090483,
3904         -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
3905         -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
3906         -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
3907         0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
3908         -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
3909         -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
3910         0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
3911         0.010868644,    -0.031489216,  0.09525667,    0.013939797,
3912         0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
3913         -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
3914         0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
3915         0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
3916         -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
3917         0.02786344,     -0.014179351,  0.005264273,   0.14376344,
3918         0.015983658,    0.03406988,    -0.06939408,   0.040699873,
3919         0.02111075,     0.09669095,    0.041345075,   -0.08316494,
3920         -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
3921         0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
3922         -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
3923         0.06760663,     -0.027437469,  0.07216407,    0.06977076,
3924         -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
3925         0.043184172,    -0.037189785,  0.10420091,    0.00882477,
3926         -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
3927         0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
3928         0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
3929         -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
3930         0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
3931         -0.008264958,   0.042035464,   0.05891794,    0.029673764,
3932         0.0063542654,   0.044788733,   0.054816857,   0.062257513,
3933         -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
3934         -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
3935         -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
3936         -0.007376126,   0.003533447,   0.006570588,   0.056037236,
3937         0.12436656,     0.051817212,   0.028532185,   -0.08686856,
3938         0.11868599,     0.07663395,    -0.07323171,   0.03463402,
3939         -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
3940         0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
3941         0.023029093,    0.086124025,   0.006445803,   -0.03496501,
3942         0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
3943         -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
3944         0.09465633,     0.008115513,   -0.02171956,   0.08304309,
3945         0.071401566,    0.019622514,   0.032163795,   -0.004167056,
3946         0.02295182,     0.030739572,   0.056506045,   0.004612461,
3947         0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
3948         -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
3949         0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
3950         -0.0329582,     0.07922767,    0.029322514,   0.026405897,
3951         0.04207835,     -0.07073373,   0.063781224,   0.0859677,
3952         -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
3953         -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
3954         -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
3955         -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
3956         0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
3957         0.15978073,     0.10185836,    0.10298046,    -0.015476589,
3958         -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
3959         -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
3960         -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
3961         -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
3962         -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
3963         0.012962922,    -0.031234352,  0.07029052,    0.016418684,
3964         0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
3965         -0.054761406,   0.029065743,   0.052404847,   0.020238016,
3966         0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
3967         0.06262858,     0.009184685,   0.020785125,   -0.043904778,
3968         -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
3969         -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
3970         0.09232601,     -0.035886683,  0.06000002,    0.05229691,
3971         -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
3972         -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
3973         0.031502828,    0.036232427,   -0.031581745,  0.023051167,
3974         -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
3975         -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
3976         -0.008799762,   0.056595087,   0.0022273948,  0.055752404};
3977 
3978     recurrent_to_forget_weights_ = {
3979         -0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
3980         0.14811787,    0.10826372,    0.09471067,     0.03987225,
3981         -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
3982         0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
3983         0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
3984         -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
3985         -0.06193199,   0.055729095,   0.03736828,     0.020123724,
3986         0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
3987         -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
3988         -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
3989         0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
3990         -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
3991         -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
3992         -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
3993         0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
3994         0.013454138,   0.028934088,   0.01685226,     -0.086110644,
3995         -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
3996         0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
3997         0.03761666,    0.008096139,   -0.014454086,   0.014361001,
3998         -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
3999         -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
4000         0.060212336,   0.055259194,   0.06974018,     0.049454916,
4001         -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
4002         0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
4003         -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
4004         0.0042065294,  0.03881498,    0.019844765,    0.041858196,
4005         -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
4006         0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
4007         0.012321099,   0.082840554,   -0.029899208,   0.044217527,
4008         0.059855383,   0.07711018,    -0.045319796,   0.0948846,
4009         -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
4010         -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
4011         -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
4012         0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
4013         0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
4014         0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
4015         0.052958444,   0.07558703,    0.04817258,     0.044462286,
4016         -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
4017         0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
4018         0.024734668,   0.024614193,   -0.042046934,   0.09597743,
4019         -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
4020         -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
4021         -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
4022         0.04383914,    -0.046476185,  0.028658995,    0.060410924,
4023         0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
4024         0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
4025         0.015898481,   0.021362653,   -0.030262267,   0.016587038,
4026         -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
4027         -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
4028         0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
4029         -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
4030         -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
4031         -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
4032         -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
4033         0.15443139,    0.07684145,    0.036571592,    -0.035900835,
4034         -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
4035         -0.03858649,   0.01849943,    0.13872518,     0.01503974,
4036         0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
4037         -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
4038         0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
4039         0.05866852,    0.023947537,   -0.09445152,    0.035450947,
4040         0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
4041         0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
4042         0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
4043         0.051808182,   0.05875331,    -0.04536488,    0.001626336,
4044         -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
4045         0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
4046         -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
4047         -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
4048         0.11475477,    -0.023854522,  0.10071741,     0.0686208,
4049         -0.014250481,  0.034261297,   0.047418304,    0.08562733,
4050         -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
4051         0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
4052         0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
4053         0.014410365,   0.020995233,   0.17040324,     0.11511526,
4054         0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
4055         -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
4056         -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
4057         0.007076659,   0.10964551,    0.0409152,      0.008275321,
4058         -0.07283536,   0.07937492,    0.04192024,     -0.1075027};
4059 
4060     recurrent_to_output_weights_ = {
4061         0.025825322,   -0.05813119,   0.09495884,     -0.045984812,
4062         -0.01255415,   -0.0026479573, -0.08196161,    -0.054914974,
4063         -0.0046604523, -0.029587349,  -0.044576716,   -0.07480124,
4064         -0.082868785,  0.023254942,   0.027502948,    -0.0039728214,
4065         -0.08683098,   -0.08116779,   -0.014675607,   -0.037924774,
4066         -0.023314456,  -0.007401714,  -0.09255757,    0.029460307,
4067         -0.08829125,   -0.005139627,  -0.08989442,    -0.0555066,
4068         0.13596267,    -0.025062224,  -0.048351806,   -0.03850004,
4069         0.07266485,    -0.022414139,  0.05940088,     0.075114764,
4070         0.09597592,    -0.010211725,  -0.0049794707,  -0.011523867,
4071         -0.025980417,  0.072999895,   0.11091378,     -0.081685916,
4072         0.014416728,   0.043229222,   0.034178585,    -0.07530371,
4073         0.035837382,   -0.085607,     -0.007721233,   -0.03287832,
4074         -0.043848954,  -0.06404588,   -0.06632928,    -0.073643476,
4075         0.008214239,   -0.045984086,  0.039764922,    0.03474462,
4076         0.060612556,   -0.080590084,  0.049127717,    0.04151091,
4077         -0.030063879,  0.008801774,   -0.023021035,   -0.019558564,
4078         0.05158114,    -0.010947698,  -0.011825728,   0.0075720972,
4079         0.0699727,     -0.0039981045, 0.069350146,    0.08799282,
4080         0.016156472,   0.035502106,   0.11695009,     0.006217345,
4081         0.13392477,    -0.037875112,  0.025745004,    0.08940699,
4082         -0.00924166,   0.0046702605,  -0.036598757,   -0.08811812,
4083         0.10522024,    -0.032441203,  0.008176899,    -0.04454919,
4084         0.07058152,    0.0067963637,  0.039206743,    0.03259838,
4085         0.03725492,    -0.09515802,   0.013326398,    -0.052055415,
4086         -0.025676316,  0.03198509,    -0.015951829,   -0.058556724,
4087         0.036879618,   0.043357447,   0.028362012,    -0.05908629,
4088         0.0059240665,  -0.04995891,   -0.019187413,   0.0276265,
4089         -0.01628143,   0.0025863599,  0.08800015,     0.035250366,
4090         -0.022165963,  -0.07328642,   -0.009415526,   -0.07455109,
4091         0.11690406,    0.0363299,     0.07411125,     0.042103454,
4092         -0.009660886,  0.019076364,   0.018299393,    -0.046004917,
4093         0.08891175,    0.0431396,     -0.026327137,   -0.051502608,
4094         0.08979574,    -0.051670972,  0.04940282,     -0.07491107,
4095         -0.021240504,  0.022596184,   -0.034280192,   0.060163025,
4096         -0.058211457,  -0.051837247,  -0.01349775,    -0.04639988,
4097         -0.035936575,  -0.011681591,  0.064818054,    0.0073146066,
4098         -0.021745546,  -0.043124277,  -0.06471268,    -0.07053354,
4099         -0.029321948,  -0.05330136,   0.016933719,    -0.053782392,
4100         0.13747959,    -0.1361751,    -0.11569455,    0.0033329215,
4101         0.05693899,    -0.053219706,  0.063698,       0.07977434,
4102         -0.07924483,   0.06936997,    0.0034815092,   -0.007305279,
4103         -0.037325785,  -0.07251102,   -0.033633437,   -0.08677009,
4104         0.091591336,   -0.14165086,   0.021752775,    0.019683983,
4105         0.0011612234,  -0.058154266,  0.049996935,    0.0288841,
4106         -0.0024567875, -0.14345716,   0.010955264,    -0.10234828,
4107         0.1183656,     -0.0010731248, -0.023590032,   -0.072285876,
4108         -0.0724771,    -0.026382286,  -0.0014920527,  0.042667855,
4109         0.0018776858,  0.02986552,    0.009814309,    0.0733756,
4110         0.12289186,    0.018043943,   -0.0458958,     0.049412545,
4111         0.033632483,   0.05495232,    0.036686596,    -0.013781798,
4112         -0.010036754,  0.02576849,    -0.08307328,    0.010112348,
4113         0.042521734,   -0.05869831,   -0.071689695,   0.03876447,
4114         -0.13275425,   -0.0352966,    -0.023077697,   0.10285965,
4115         0.084736146,   0.15568255,    -0.00040734606, 0.027835453,
4116         -0.10292561,   -0.032401145,  0.10053256,     -0.026142767,
4117         -0.08271222,   -0.0030240538, -0.016368777,   0.1070414,
4118         0.042672627,   0.013456989,   -0.0437609,     -0.022309763,
4119         0.11576483,    0.04108048,    0.061026827,    -0.0190714,
4120         -0.0869359,    0.037901703,   0.0610107,      0.07202949,
4121         0.01675338,    0.086139716,   -0.08795751,    -0.014898893,
4122         -0.023771819,  -0.01965048,   0.007955471,    -0.043740474,
4123         0.03346837,    -0.10549954,   0.090567775,    0.042013682,
4124         -0.03176985,   0.12569028,    -0.02421228,    -0.029526481,
4125         0.023851605,   0.031539805,   0.05292009,     -0.02344001,
4126         -0.07811758,   -0.08834428,   0.10094801,     0.16594367,
4127         -0.06861939,   -0.021256343,  -0.041093912,   -0.06669611,
4128         0.035498552,   0.021757556,   -0.09302526,    -0.015403468,
4129         -0.06614931,   -0.051798206,  -0.013874718,   0.03630673,
4130         0.010412845,   -0.08077351,   0.046185967,    0.0035662893,
4131         0.03541868,    -0.094149634,  -0.034814864,   0.003128424,
4132         -0.020674974,  -0.03944324,   -0.008110165,   -0.11113267,
4133         0.08484226,    0.043586485,   0.040582247,    0.0968012,
4134         -0.065249965,  -0.028036479,  0.0050708856,   0.0017462453,
4135         0.0326779,     0.041296225,   0.09164146,     -0.047743853,
4136         -0.015952192,  -0.034451712,  0.084197424,    -0.05347844,
4137         -0.11768019,   0.085926116,   -0.08251791,    -0.045081906,
4138         0.0948852,     0.068401024,   0.024856757,    0.06978981,
4139         -0.057309967,  -0.012775832,  -0.0032452994,  0.01977615,
4140         -0.041040014,  -0.024264973,  0.063464895,    0.05431621,
4141     };
4142 
4143     cell_to_input_weights_ = {
4144         0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
4145         -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
4146         -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
4147         0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175};
4148 
4149     cell_to_forget_weights_ = {
4150         -0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
4151         -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
4152         -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
4153         0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355};
4154 
4155     cell_to_output_weights_ = {
4156         0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
4157         -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
4158         -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
4159         0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733};
4160 
4161     projection_weights_ = {
4162         -0.009802181, 0.09401916,   0.0717386,     -0.13895074,
4163         0.09641832,   0.060420845,  0.08539281,    0.054285463,
4164         0.061395317,  0.034448683,  -0.042991187,  0.019801661,
4165         -0.16840284,  -0.015726732, -0.23041931,   -0.024478018,
4166         -0.10959692,  -0.013875541, 0.18600968,    -0.061274476,
4167         0.0138165,    -0.08160894,  -0.07661644,   0.032372914,
4168         0.16169067,   0.22465782,   -0.03993472,   -0.004017731,
4169         0.08633481,   -0.28869787,  0.08682067,    0.17240396,
4170         0.014975425,  0.056431185,  0.031037588,   0.16702051,
4171         0.0077946745, 0.15140012,   0.29405436,    0.120285,
4172         -0.188994,    -0.027265169, 0.043389652,   -0.022061434,
4173         0.014777949,  -0.20203483,  0.094781205,   0.19100232,
4174         0.13987629,   -0.036132768, -0.06426278,   -0.05108664,
4175         0.13221376,   0.009441198,  -0.16715929,   0.15859416,
4176         -0.040437475, 0.050779544,  -0.022187516,  0.012166504,
4177         0.027685808,  -0.07675938,  -0.0055694645, -0.09444123,
4178         0.0046453946, 0.050794356,  0.10770313,    -0.20790008,
4179         -0.07149004,  -0.11425117,  0.008225835,   -0.035802525,
4180         0.14374903,   0.15262283,   0.048710253,   0.1847461,
4181         -0.007487823, 0.11000021,   -0.09542012,   0.22619456,
4182         -0.029149994, 0.08527916,   0.009043713,   0.0042746216,
4183         0.016261552,  0.022461696,  0.12689082,    -0.043589946,
4184         -0.12035478,  -0.08361797,  -0.050666027,  -0.1248618,
4185         -0.1275799,   -0.071875185, 0.07377272,    0.09944291,
4186         -0.18897448,  -0.1593054,   -0.06526116,   -0.040107165,
4187         -0.004618631, -0.067624845, -0.007576253,  0.10727444,
4188         0.041546922,  -0.20424393,  0.06907816,    0.050412357,
4189         0.00724631,   0.039827548,  0.12449835,    0.10747581,
4190         0.13708383,   0.09134148,   -0.12617786,   -0.06428341,
4191         0.09956831,   0.1208086,    -0.14676677,   -0.0727722,
4192         0.1126304,    0.010139365,  0.015571211,   -0.038128063,
4193         0.022913318,  -0.042050496, 0.16842307,    -0.060597885,
4194         0.10531834,   -0.06411776,  -0.07451711,   -0.03410368,
4195         -0.13393489,  0.06534304,   0.003620307,   0.04490757,
4196         0.05970546,   0.05197996,   0.02839995,    0.10434969,
4197         -0.013699693, -0.028353551, -0.07260381,   0.047201227,
4198         -0.024575593, -0.036445823, 0.07155557,    0.009672501,
4199         -0.02328883,  0.009533515,  -0.03606021,   -0.07421458,
4200         -0.028082801, -0.2678904,   -0.13221288,   0.18419984,
4201         -0.13012612,  -0.014588381, -0.035059117,  -0.04824723,
4202         0.07830115,   -0.056184657, 0.03277091,    0.025466874,
4203         0.14494097,   -0.12522776,  -0.098633975,  -0.10766018,
4204         -0.08317623,  0.08594209,   0.07749552,    0.039474737,
4205         0.1776665,    -0.07409566,  -0.0477268,    0.29323658,
4206         0.10801441,   0.1154011,    0.013952499,   0.10739139,
4207         0.10708251,   -0.051456142, 0.0074137426,  -0.10430189,
4208         0.10034707,   0.045594677,  0.0635285,     -0.0715442,
4209         -0.089667566, -0.10811871,  0.00026344223, 0.08298446,
4210         -0.009525053, 0.006585689,  -0.24567553,   -0.09450807,
4211         0.09648481,   0.026996298,  -0.06419476,   -0.04752702,
4212         -0.11063944,  -0.23441927,  -0.17608605,   -0.052156363,
4213         0.067035615,  0.19271925,   -0.0032889997, -0.043264326,
4214         0.09663576,   -0.057112187, -0.10100678,   0.0628376,
4215         0.04447668,   0.017961001,  -0.10094388,   -0.10190601,
4216         0.18335468,   0.10494553,   -0.052095775,  -0.0026118709,
4217         0.10539724,   -0.04383912,  -0.042349473,  0.08438151,
4218         -0.1947263,   0.02251204,   0.11216432,    -0.10307853,
4219         0.17351969,   -0.039091777, 0.08066188,    -0.00561982,
4220         0.12633002,   0.11335965,   -0.0088127935, -0.019777594,
4221         0.06864014,   -0.059751723, 0.016233567,   -0.06894641,
4222         -0.28651384,  -0.004228674, 0.019708522,   -0.16305895,
4223         -0.07468996,  -0.0855457,   0.099339016,   -0.07580735,
4224         -0.13775392,  0.08434318,   0.08330512,    -0.12131499,
4225         0.031935584,  0.09180414,   -0.08876437,   -0.08049874,
4226         0.008753825,  0.03498998,   0.030215185,   0.03907079,
4227         0.089751154,  0.029194152,  -0.03337423,   -0.019092513,
4228         0.04331237,   0.04299654,   -0.036394123,  -0.12915532,
4229         0.09793732,   0.07512415,   -0.11319543,   -0.032502122,
4230         0.15661901,   0.07671967,   -0.005491124,  -0.19379048,
4231         -0.218606,    0.21448623,   0.017840758,   0.1416943,
4232         -0.07051762,  0.19488361,   0.02664691,    -0.18104725,
4233         -0.09334311,  0.15026465,   -0.15493552,   -0.057762887,
4234         -0.11604192,  -0.262013,    -0.01391798,   0.012185008,
4235         0.11156489,   -0.07483202,  0.06693364,    -0.26151478,
4236         0.046425626,  0.036540434,  -0.16435726,   0.17338543,
4237         -0.21401681,  -0.11385144,  -0.08283257,   -0.069031075,
4238         0.030635102,  0.010969227,  0.11109743,    0.010919218,
4239         0.027526086,  0.13519906,   0.01891392,    -0.046839405,
4240         -0.040167913, 0.017953383,  -0.09700955,   0.0061885654,
4241         -0.07000971,  0.026893595,  -0.038844477,  0.14543656};
4242 
4243     lstm_input_ = {
4244         {// Batch0: 4 (input_sequence_size) * 5 (n_input)
4245          0.787926, 0.151646, 0.071352, 0.118426, 0.458058,   // step 0
4246          0.596268, 0.998386, 0.568695, 0.864524, 0.571277,   // step 1
4247          0.073204, 0.296072, 0.743333, 0.069199, 0.045348,   // step 2
4248          0.867394, 0.291279, 0.013714, 0.482521, 0.626339},  // step 3
4249 
4250         {// Batch1: 4 (input_sequence_size) * 5 (n_input)
4251          0.295743, 0.544053, 0.690064, 0.858138, 0.497181,  // step 0
4252          0.642421, 0.524260, 0.134799, 0.003639, 0.162482,  // step 1
4253          0.640394, 0.930399, 0.050782, 0.432485, 0.988078,  // step 2
4254          0.082922, 0.563329, 0.865614, 0.333232, 0.259916}  // step 3
4255     };
4256 
4257     lstm_golden_output_ = {
4258         {// Batch0: 4 (input_sequence_size) * 16 (n_output)
4259          -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
4260          -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
4261          -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
4262          0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
4263          -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
4264          -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
4265          0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
4266          0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
4267          0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
4268          0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
4269          -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
4270          -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
4271          0.0286833,   0.00824207,   0.0264887,   0.0305169},
4272         {// Batch1: 4 (input_sequence_size) * 16 (n_output)
4273          -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
4274          -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
4275          0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
4276          0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
4277          -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
4278          -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
4279          0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
4280          0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
4281          0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
4282          0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
4283          -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
4284          -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
4285          0.0412031,    0.0118723,   0.0239643,   0.0394009}};
4286   }
4287 };
4288 
TEST_F(NoCifgPeepholeProjectionClippingLstmTest,LstmBlackBoxTest)4289 TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
4290   const int n_batch = 2;
4291   const int n_input = 5;
4292   const int n_cell = 20;
4293   const int n_output = 16;
4294 
4295   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
4296                    /*use_cifg=*/false, /*use_peephole=*/true,
4297                    /*use_projection_weights=*/true,
4298                    /*use_projection_bias=*/false,
4299                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
4300                    {
4301                        {n_batch, n_input},  // input tensor
4302 
4303                        {n_cell, n_input},  // input_to_input_weight tensor
4304                        {n_cell, n_input},  // input_to_forget_weight tensor
4305                        {n_cell, n_input},  // input_to_cell_weight tensor
4306                        {n_cell, n_input},  // input_to_output_weight tensor
4307 
4308                        {n_cell, n_output},  // recurrent_to_input_weight tensor
4309                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
4310                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
4311                        {n_cell, n_output},  // recurrent_to_output_weight tensor
4312 
4313                        {n_cell},  // cell_to_input_weight tensor
4314                        {n_cell},  // cell_to_forget_weight tensor
4315                        {n_cell},  // cell_to_output_weight tensor
4316 
4317                        {n_cell},  // input_gate_bias tensor
4318                        {n_cell},  // forget_gate_bias tensor
4319                        {n_cell},  // cell_bias tensor
4320                        {n_cell},  // output_gate_bias tensor
4321 
4322                        {n_output, n_cell},  // projection_weight tensor
4323                        {0},                 // projection_bias tensor
4324 
4325                        {n_batch, n_output},  // activation_state tensor
4326                        {n_batch, n_cell},    // cell_state tensor
4327                    },
4328                    /*weight_type=*/TensorType_FLOAT32);
4329 
4330   lstm.SetInputToInputWeights(input_to_input_weights_);
4331   lstm.SetInputToCellWeights(input_to_cell_weights_);
4332   lstm.SetInputToForgetWeights(input_to_forget_weights_);
4333   lstm.SetInputToOutputWeights(input_to_output_weights_);
4334 
4335   lstm.SetInputGateBias(input_gate_bias_);
4336   lstm.SetCellBias(cell_gate_bias_);
4337   lstm.SetForgetGateBias(forget_gate_bias_);
4338   lstm.SetOutputGateBias(output_gate_bias_);
4339 
4340   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
4341   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4342   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4343   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4344 
4345   lstm.SetCellToInputWeights(cell_to_input_weights_);
4346   lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4347   lstm.SetCellToOutputWeights(cell_to_output_weights_);
4348 
4349   lstm.SetProjectionWeights(projection_weights_);
4350 
4351   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
4352 }
4353 
// Fixture for a layer-normalized LSTM: all four gates present (no CIFG),
// peephole connections, projection weights without a projection bias, and
// clipping disabled. Vector sizes correspond to n_cell=4, n_input=5,
// n_output=3 (the shapes the matching TEST_F passes to LSTMOpModel); the
// golden outputs are declared inside the TEST_F itself.
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
    : public BaseLstmTest {
  // Populates the BaseLstmTest members with small hand-picked constants.
  void SetUp() override {
    // Input-to-gate weight matrices: 20 values each (n_cell x n_input).
    input_to_input_weights_ = {0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,
                               0.3,  -0.4, 0.5,  -0.8, 0.7,  -0.6, 0.5,
                               -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};

    input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
                                -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
                                -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};

    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
                              -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
                              -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};

    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
                                -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

    // Gate biases, one value per cell (n_cell = 4).
    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};

    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};

    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};

    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weight matrices: 12 values each (n_cell x n_output).
    recurrent_to_input_weights_ = {-0.2, -0.3, 0.4,  0.1,  -0.5, 0.9,
                                   -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};

    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};

    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};

    recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole (cell-state-to-gate) weight vectors, n_cell values each.
    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Per-gate layer normalization coefficients, n_cell values each.
    input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // Projection matrix: 12 values (n_output x n_cell).
    projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                           0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};

    // Two batches of 3 timesteps x 5 inputs each.
    lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
4420 
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,LayerNormLstmBlackBoxTest)4421 TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
4422        LayerNormLstmBlackBoxTest) {
4423   const int n_batch = 2;
4424   const int n_input = 5;
4425   const int n_cell = 4;
4426   const int n_output = 3;
4427   const float ceil_clip = 0.0;
4428   const float proj_clip = 0.0;
4429 
4430   LSTMOpModel layer_norm_lstm(
4431       n_batch, n_input, n_cell, n_output,
4432       /*use_cifg=*/false, /*use_peephole=*/true,
4433       /*use_projection_weights=*/true,
4434       /*use_projection_bias=*/false, ceil_clip, proj_clip,
4435       {
4436           {n_batch, n_input},  // input tensor
4437 
4438           {n_cell, n_input},  // input_to_input_weight tensor
4439           {n_cell, n_input},  // input_to_forget_weight tensor
4440           {n_cell, n_input},  // input_to_cell_weight tensor
4441           {n_cell, n_input},  // input_to_output_weight tensor
4442 
4443           {n_cell, n_output},  // recurrent_to_input_weight tensor
4444           {n_cell, n_output},  // recurrent_to_forget_weight tensor
4445           {n_cell, n_output},  // recurrent_to_cell_weight tensor
4446           {n_cell, n_output},  // recurrent_to_output_weight tensor
4447 
4448           {n_cell},  // cell_to_input_weight tensor
4449           {n_cell},  // cell_to_forget_weight tensor
4450           {n_cell},  // cell_to_output_weight tensor
4451 
4452           {n_cell},  // input_gate_bias tensor
4453           {n_cell},  // forget_gate_bias tensor
4454           {n_cell},  // cell_bias tensor
4455           {n_cell},  // output_gate_bias tensor
4456 
4457           {n_output, n_cell},  // projection_weight tensor
4458           {0},                 // projection_bias tensor
4459 
4460           {n_batch, n_output},  // activation_state tensor
4461           {n_batch, n_cell},    // cell_state tensor
4462 
4463           {n_cell},  // input_layer_norm_coefficient tensor
4464           {n_cell},  // forget_layer_norm_coefficient tensor
4465           {n_cell},  // cell_layer_norm_coefficient tensor
4466           {n_cell},  // output_layer_norm_coefficient tensor
4467       },
4468       /*weight_type=*/TensorType_FLOAT32);
4469 
4470   layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
4471   layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
4472   layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
4473   layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
4474 
4475   layer_norm_lstm.SetInputGateBias(input_gate_bias_);
4476   layer_norm_lstm.SetCellBias(cell_gate_bias_);
4477   layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
4478   layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
4479 
4480   layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
4481   layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4482   layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4483   layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4484 
4485   layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
4486   layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4487   layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
4488 
4489   layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
4490   layer_norm_lstm.SetForgetLayerNormCoefficients(
4491       forget_layer_norm_coefficients_);
4492   layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
4493   layer_norm_lstm.SetOutputLayerNormCoefficients(
4494       output_layer_norm_coefficients_);
4495 
4496   layer_norm_lstm.SetProjectionWeights(projection_weights_);
4497 
4498   // Verify the final output.
4499   const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
4500       {
4501           // Batch0: 3 (input_sequence_size) * 3 (n_output)
4502           0.0244077, 0.128027, -0.00170918,  // seq 0
4503           0.0137642, 0.140751, 0.0395835,    // seq 1
4504           -0.00459231, 0.155278, 0.0837377,  // seq 2
4505       },
4506       {
4507           // Batch1: 3 (input_sequence_size) * 3 (n_output)
4508           -0.00692428, 0.0848741, 0.063445,  // seq 0
4509           -0.00403912, 0.139963, 0.072681,   // seq 1
4510           0.00752706, 0.161903, 0.0561371,   // seq 2
4511       }};
4512 
4513   VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
4514 }
4515 
// Fixture for the CIFG (coupled input-forget gate) layer-normalized LSTM:
// the matching TEST_F constructs the model with use_cifg=true and passes
// empty shapes for the input-gate tensors, so no input-gate weights,
// input-gate bias, input peephole weights, or input layer-norm coefficients
// are populated here. Remaining vectors are sized for n_cell=4, n_input=5,
// n_output=3.
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
  // Populates the BaseLstmTest members used by the CIFG configuration.
  void SetUp() override {
    // Input-to-gate weights (forget/cell/output only): n_cell x n_input each.
    input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
                                -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
                                -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
                              -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
                              -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};
    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
                                -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

    // Gate biases, n_cell values each (no input-gate bias in CIFG).
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weight matrices: n_cell x n_output each.
    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};
    recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole weight vectors (forget/output gates only).
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Layer normalization coefficients, n_cell values each.
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
    // Projection matrix: n_output x n_cell.
    projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                           0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};

    // Two batches of 3 timesteps x 5 inputs each.
    lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
4561 
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,LayerNormLstmBlackBoxTest)4562 TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
4563        LayerNormLstmBlackBoxTest) {
4564   const int n_batch = 2;
4565   const int n_input = 5;
4566   const int n_cell = 4;
4567   const int n_output = 3;
4568   const float ceil_clip = 0.0;
4569   const float proj_clip = 0.0;
4570 
4571   LSTMOpModel layer_norm_lstm(
4572       n_batch, n_input, n_cell, n_output,
4573       /*use_cifg=*/true, /*use_peephole=*/true,
4574       /*use_projection_weights=*/true,
4575       /*use_projection_bias=*/false, ceil_clip, proj_clip,
4576       {
4577           {n_batch, n_input},  // input tensor
4578 
4579           {0, 0},             // input_to_input_weight tensor
4580           {n_cell, n_input},  // input_to_forget_weight tensor
4581           {n_cell, n_input},  // input_to_cell_weight tensor
4582           {n_cell, n_input},  // input_to_output_weight tensor
4583 
4584           {0, 0},              // recurrent_to_input_weight tensor
4585           {n_cell, n_output},  // recurrent_to_forget_weight tensor
4586           {n_cell, n_output},  // recurrent_to_cell_weight tensor
4587           {n_cell, n_output},  // recurrent_to_output_weight tensor
4588 
4589           {0},       // cell_to_input_weight tensor
4590           {n_cell},  // cell_to_forget_weight tensor
4591           {n_cell},  // cell_to_output_weight tensor
4592 
4593           {0},       // input_gate_bias tensor
4594           {n_cell},  // forget_gate_bias tensor
4595           {n_cell},  // cell_bias tensor
4596           {n_cell},  // output_gate_bias tensor
4597 
4598           {n_output, n_cell},  // projection_weight tensor
4599           {0},                 // projection_bias tensor
4600 
4601           {n_batch, n_output},  // activation_state tensor
4602           {n_batch, n_cell},    // cell_state tensor
4603 
4604           {0},       // input_layer_norm_coefficient tensor
4605           {n_cell},  // forget_layer_norm_coefficient tensor
4606           {n_cell},  // cell_layer_norm_coefficient tensor
4607           {n_cell},  // output_layer_norm_coefficient tensor
4608       },
4609       /*weight_type=*/TensorType_FLOAT32);
4610 
4611   layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
4612   layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
4613   layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
4614 
4615   layer_norm_lstm.SetCellBias(cell_gate_bias_);
4616   layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
4617   layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
4618 
4619   layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4620   layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4621   layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4622 
4623   layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4624   layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
4625 
4626   layer_norm_lstm.SetForgetLayerNormCoefficients(
4627       forget_layer_norm_coefficients_);
4628   layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
4629   layer_norm_lstm.SetOutputLayerNormCoefficients(
4630       output_layer_norm_coefficients_);
4631 
4632   layer_norm_lstm.SetProjectionWeights(projection_weights_);
4633 
4634   // Verify the final output.
4635   const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
4636       {
4637           // Batch0: 3 (input_sequence_size) * 3 (n_output)
4638           0.02129706, 0.140816242, 0.0112733059,     // seq 0
4639           0.0132302344, 0.152308047, 0.0346313119,   // seq 1
4640           -0.0123688057, 0.165790111, 0.0893077999,  // seq 2
4641       },
4642       {
4643           // Batch1: 3 (input_sequence_size) * 3 (n_output)
4644           -0.0226350538, 0.0916948169, 0.0769175813,  // seq 0
4645           -0.0269966982, 0.149707705, 0.094149217,    // seq 1
4646           -0.0103429332, 0.173016444, 0.0720508844,   // seq 2
4647       }};
4648 
4649   VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
4650 }
4651 
// Shared harness for reduction ops (e.g. MEAN) run through the NNAPI
// delegate. Subclasses create the op and assign the input_/axis_/output_
// tensor indices.
class BaseReduceOpModel : public SingleOpModelWithNNAPI {
 public:
  // Populates the (non-const) axis tensor with the axes to reduce over.
  void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }

  // Populates the input tensor.
  template <class T>
  void SetInput(const std::vector<T>& data) {
    PopulateTensor(input_, data);
  }

  // Reads back the output tensor as a vector of T.
  template <class T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  // Reads back a uint8 quantized output, dequantized with the output
  // tensor's scale and zero point.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }

  // Returns the shape of the output tensor after Invoke().
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Returns the input tensor's index in the interpreter.
  int Input() { return input_; }

 protected:
  // Tensor indices, assigned by the subclass constructor.
  int input_;
  int axis_;
  int output_;
};
4680 
// Model for the test case where axis is a dynamic (non-const) tensor.
class MeanOpDynamicModel : public BaseReduceOpModel {
 public:
  // `input` and `axis` become model inputs; `keep_dims` controls whether
  // reduced dimensions are retained with size 1.
  MeanOpDynamicModel(const TensorData& input, const TensorData& output,
                     const TensorData& axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddInput(axis);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 CreateReducerOptions(builder_, keep_dims).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};
4694 
// Mean over a dynamic axis tensor. The axis list {1, 0, -3, -3} contains
// duplicates and a negative alias (-3 == axis 0 for a rank-3 input), so the
// op must deduplicate and reduce over axes {0, 1}.
TEST(DynamicFloatMeanOpTest, NotKeepDims) {
  std::vector<float> data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                             9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                             17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
                       {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
                       false);
  std::vector<int> axis = {1, 0, -3, -3};
  m.SetAxis(axis);
  m.SetInput(data);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
  // Compare through NnapiArrayFloatNear for consistency with the sibling
  // mean tests (MeanFloatNotKeepDims/MeanFloatKeepDims) in this file.
  EXPECT_THAT(m.GetOutput<float>(), NnapiArrayFloatNear({12, 13}));
}
4709 
// Model for the test case where axis is a const tensor.
class MeanOpConstModel : public BaseReduceOpModel {
 public:
  // The axis values are baked into the model as a const INT32 tensor of
  // shape `axis_shape`; `keep_dims` controls whether reduced dimensions are
  // retained with size 1.
  MeanOpConstModel(const TensorData& input, const TensorData& output,
                   std::initializer_list<int> axis_shape,
                   std::initializer_list<int> axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 CreateReducerOptions(builder_, keep_dims).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};
4724 
// Tests for reduce_mean.
// Const-axis mean over axes {1, 0} (the list {1, 0, -3, -3} deduplicates to
// those axes for a rank-3 input), dropping the reduced dimensions.
TEST(NNAPIDelegate, MeanFloatNotKeepDims) {
  std::vector<float> data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                             9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                             17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
                     {4}, {1, 0, -3, -3}, false);
  m.SetInput(data);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(m.GetOutput<float>(), NnapiArrayFloatNear({12, 13}));
}
4737 
// Const-axis mean over axes {0, 2} with keep_dims=true: the {4, 3, 2} input
// reduces to a {1, 3, 1} output.
TEST(NNAPIDelegate, MeanFloatKeepDims) {
  std::vector<float> data = {1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                             9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                             17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
                     {2}, {0, 2}, true);
  m.SetInput(data);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(m.GetOutput<float>(), NnapiArrayFloatNear({10.5, 12.5, 14.5}));
}
4749 
// Harness for the EMBEDDING_LOOKUP op: gathers rows of the weight tensor
// selected by the INT32 index input.
class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                             std::initializer_list<int> weight_shape,
                             TensorType weight_type = TensorType_FLOAT32) {
    input_ = AddInput(TensorType_INT32);
    weight_ = AddInput(weight_type);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({index_shape, weight_shape});
  }

  // Sets the row indices to look up.
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  // Tensor indices in the interpreter.
  int input_;
  int weight_;
  int output_;
};
4773 
4774 class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
4775  public:
4776   using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
4777 
Set3DWeightMatrix(const std::function<float (int,int,int)> & function)4778   void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
4779     TfLiteTensor* tensor = interpreter_->tensor(weight_);
4780     int rows = tensor->dims->data[0];
4781     int columns = tensor->dims->data[1];
4782     int features = tensor->dims->data[2];
4783     for (int i = 0; i < rows; i++) {
4784       for (int j = 0; j < columns; j++) {
4785         for (int k = 0; k < features; k++) {
4786           tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
4787         }
4788       }
4789     }
4790   }
4791 };
4792 
// Fills a {3, 2, 4} weight table with i + j/10 + k/100 so every value encodes
// its own (row, col, feat) coordinates, then checks that looking up rows
// {1, 0, 2} returns those rows in lookup order.
TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) {
  EmbeddingLookupOpModel m({3}, {3, 2, 4});
  m.SetInput({1, 0, 2});
  m.Set3DWeightMatrix(
      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({
                  1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                  0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
                  2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
              }));
}
4808 
// Harness for the HASHTABLE_LOOKUP op: each entry of `lookup` is searched in
// `key`, and the matching row of `value` is copied to the output. A second
// UINT8 output marks, per entry, whether the key was found.
class HashtableLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
                         std::initializer_list<int> key_shape,
                         std::initializer_list<int> value_shape,
                         TensorType type) {
    lookup_ = AddInput(TensorType_INT32);
    key_ = AddInput(TensorType_INT32);
    value_ = AddInput(type);
    output_ = AddOutput(type);
    // Per-lookup hit mask: 1 when the key was found, 0 otherwise.
    hit_ = AddOutput(TensorType_UINT8);
    SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({lookup_shape, key_shape, value_shape});
  }

  // Sets the keys to look up.
  void SetLookup(std::initializer_list<int> data) {
    PopulateTensor<int>(lookup_, data);
  }

  // Sets the table's keys.
  void SetHashtableKey(std::initializer_list<int> data) {
    PopulateTensor<int>(key_, data);
  }

  // Sets string table values.
  void SetHashtableValue(const std::vector<string>& content) {
    PopulateStringTensor(value_, content);
  }

  // Fills a 1-D float value table: value[i] = function(i).
  void SetHashtableValue(const std::function<float(int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    for (int i = 0; i < rows; i++) {
      tensor->data.f[i] = function(i);
    }
  }

  // Fills a 2-D float value table: value[i][j] = function(i, j).
  void SetHashtableValue(const std::function<float(int, int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    int features = tensor->dims->data[1];
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < features; j++) {
        tensor->data.f[i * features + j] = function(i, j);
      }
    }
  }

  // Extracts the output tensor as a vector of strings.
  std::vector<string> GetStringOutput() {
    TfLiteTensor* output = interpreter_->tensor(output_);
    int num = GetStringCount(output);
    std::vector<string> result(num);
    for (int i = 0; i < num; i++) {
      auto ref = GetString(output, i);
      result[i] = string(ref.str, ref.len);
    }
    return result;
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }

 private:
  // Tensor indices in the interpreter.
  int lookup_;
  int key_;
  int value_;
  int output_;
  int hit_;
};
4876 
// 2-D value table: each found key yields a row of two floats; the unmatched
// key (-292) yields zeros and a 0 in the hit mask.
TEST(NNAPIDelegate, HashtableLookupTest2DInput) {
  HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32);

  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({
                                 2.0, 2.1,  // 2-nd item
                                 0, 0,      // Not found
                                 0.0, 0.1,  // 0-th item
                                 1.0, 1.1,  // 1-st item
                             }));
  EXPECT_THAT(m.GetHit(), ElementsAreArray({
                              1,
                              0,
                              1,
                              1,
                          }));
}
4899 
// 1-D value table: each found key yields a single float (i*i/10); the
// unmatched key (-292) yields zero and a 0 in the hit mask.
TEST(NNAPIDelegate, HashtableLookupTest1DInput) {
  HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32);

  m.SetLookup({1234, -292, -11, 0});
  m.SetHashtableKey({-11, 0, 1234});
  m.SetHashtableValue([](int i) { return i * i / 10.0f; });

  m.Invoke();

  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({
                                 0.4,  // 2-nd item
                                 0,    // Not found
                                 0.0,  // 0-th item
                                 0.1,  // 1-st item
                             }));
  EXPECT_THAT(m.GetHit(), ElementsAreArray({
                              1,
                              0,
                              1,
                              1,
                          }));
}
4922 
// A base class of PRelu op model. It provides the constructor for
// FloatPReluOpModel and QuantizedPReluOpModel.
class PReluOpModel : public SingleOpModelWithNNAPI {
 public:
  // The output tensor reuses the input's type and quantization range
  // (min/max), so float and quantized variants share this one constructor.
  PReluOpModel(const TensorData& input, const TensorData& alpha)
      : input_type_(input.type) {
    input_ = AddInput(input);
    alpha_ = AddInput(alpha);
    output_ = AddOutput({input.type, input.shape, input.min, input.max});
    SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_), GetShape(alpha_)});
  }

  // Sets the input values, quantizing them first if input_type_ is quantized.
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Sets the per-channel alpha (negative-slope) values.
  void SetAlpha(std::initializer_list<float> data) {
    SetData(alpha_, input_type_, data);
  }

  // Reads back the output as floats (dequantizing if needed).
  std::vector<float> GetOutput() {
    std::vector<float> output;
    GetData(output_, input_type_, &output);
    return output;
  }

 protected:
  int input_;
  int alpha_;
  int output_;

  const TensorType input_type_;
};
4957 
// Float PRelu: positive inputs pass through; negative inputs are scaled by
// the per-channel alpha, which broadcasts over the batch/spatial dims.
TEST(NNAPIDelegate, PReluFloat) {
  PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
                 {TensorType_FLOAT32, {1, 1, 3}});

  m.SetInput({
      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
      -2.0f, -2.0f, -2.0f,  // Row 1, Column 2
  });
  m.SetAlpha({0.0f, 1.0f, 2.0f});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({
                                 0.0f, 0.0f, 0.0f,    // Row 1, Column 1
                                 1.0f, 1.0f, 1.0f,    // Row 1, Column 2
                                 0.0f, -1.0f, -2.0f,  // Row 2, Column 1
                                 0.0f, -2.0f, -4.0f,  // Row 1, Column 2
                             }));
}
4977 
// Quantized (uint8) PRelu over the range [-1, 127/128]; the comparison uses
// kQuantizedTolerance (defined elsewhere in this file) to absorb rounding.
TEST(NNAPIDelegate, PReluQuantized) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  PReluOpModel m({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
                 {TensorType_UINT8, {1, 1, 3}, kMin, kMax});
  m.SetInput({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 1, Column 2
  });
  // Note the negative alpha (-0.5f): negative inputs flip sign.
  m.SetAlpha({0.0f, 0.5f, -0.5f});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                 {
                                     0.0f, 0.0f, 0.0f,       // Row 1, Column 1
                                     0.5f, 0.5f, 0.5f,       // Row 1, Column 2
                                     0.0f, -0.5f, 0.5f,      // Row 2, Column 1
                                     0.0f, -0.125f, 0.125f,  // Row 1, Column 2
                                 },
                                 kQuantizedTolerance)));
}
5000 
// Tests the case where paddings is a const tensor. Type T1 is the dtype.
template <typename T1>
class PadV2OpConstModel : public PadOpModel<T1> {
 public:
  // Variant where the pad value is a scalar baked in as a const input.
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings, T1 constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ =
        this->AddConstInput(GetTensorType<T1>(), {constant_values}, {1});

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }

  // Variant where the pad value is itself a (non-const) input tensor; used by
  // the quantized tests so the value can be set after quantization params are
  // known.
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings,
                    const TensorData& constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ = this->AddInput(constant_values);

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }
};
5039 
// Test case where paddings is a non-const tensor.
template <typename RegularInputOutput>
class PadV2OpDynamicModel : public PadOpModel<RegularInputOutput> {
 public:
  // Variant where the pad value is a scalar baked in as a const input; the
  // paddings themselves are supplied at run time via SetPaddings().
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      RegularInputOutput constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    this->constant_values_ = this->AddConstInput(
        GetTensorType<RegularInputOutput>(), {constant_values}, {1});
    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }
  // Variant where the pad value is also a (non-const) input tensor.
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      const TensorData& constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    this->constant_values_ = this->AddInput(constant_values);
    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }
};
5072 
// Const paddings, zero pad value: a {1, 2, 2, 1} input grows to {1, 4, 4, 1}
// with a one-element border of zeros on the two middle dimensions.
TEST(PadV2OpTest, SimpleConstTest) {
  // Padding is represented as four 2-D lists representing above padding and
  // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                             {0, 0, 1, 1, 1, 1, 0, 0}, 0.0,
                             {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({0, 0, 0, 0, 0, 1, 2, 0, 0, 3,
                                                  4, 0, 0, 0, 0, 0}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5085 
// Same as SimpleConstTest but with a nonzero pad value (5).
// NOTE(review): despite the "Uint8" in the test name, this instantiates the
// float model — presumably a historical naming slip; confirm before renaming
// (test filters may reference the current name).
TEST(PadV2OpTest, SimpleConstFloat32ValuedTestUint8) {
  // Padding is represented as four 2-D lists representing above padding and
  // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                             {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({5, 5, 5, 5, 5, 1, 2, 5, 5, 3,
                                                  4, 5, 5, 5, 5, 5}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5097 
// Pads on the outermost and innermost dimensions ({0,1} and {0,1}) with pad
// value 5: {1, 1, 2, 1} becomes {2, 1, 2, 2}.
TEST(PadV2OpTest, Simple4DConstFloat32ValuedTest) {
  // Padding is represented as four 2-D lists representing above padding and
  // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2},
                             {0, 1, 0, 0, 0, 0, 0, 1}, 5, {TensorType_FLOAT32});
  m.SetInput({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({3, 5, 3, 5, 5, 5, 5, 5}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
}
5108 
// Same padding as SimpleConstTest, but the paddings tensor is supplied at run
// time instead of being const.
TEST(PadV2OpTest, SimpleDynamicTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 0.0,
                               {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({0, 0, 0, 0, 0, 1, 2, 0, 0, 3,
                                                  4, 0, 0, 0, 0, 0}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5119 
// Dynamic paddings with a nonzero pad value (5).
TEST(PadV2OpTest, SimpleDynamicValuedTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 5,
                               {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4});
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({5, 5, 5, 5, 5, 1, 2, 5, 5, 3,
                                                  4, 5, 5, 5, 5, 5}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5130 
// Asymmetric const padding ({0,2} rows below, {1,3} columns around):
// {1, 2, 3, 1} grows to {1, 4, 7, 1}, padded with zeros.
TEST(PadV2OpTest, AdvancedConstTest) {
  PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                             {0, 0, 0, 2, 1, 3, 0, 0}, 0, {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
5141 
// Same asymmetric padding as AdvancedConstTest, but with a dynamic paddings
// tensor.
TEST(PadV2OpTest, AdvancedDynamicTest) {
  PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, 0,
                               {TensorType_FLOAT32});
  m.SetInput({1, 2, 3, 4, 5, 6});
  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
5153 
DequantizedArrayNear(const std::vector<float> & values,const float min,const float max)5154 std::vector<testing::Matcher<float>> DequantizedArrayNear(
5155     const std::vector<float>& values, const float min, const float max) {
5156   const float quantization_tolerance = (max - min) / 255.0;
5157   return ArrayFloatNear(values, quantization_tolerance);
5158 }
5159 
// Quantized PadV2 with const paddings and a pad value of 0, parameterized
// over the integer storage type (uint8/int8) and its TensorType.
template <typename integer_type, TensorType tensor_dtype>
void SimpleConstTestV2() {
  // Padding is represented as four 2-D lists representing above padding and
  // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.Invoke();
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(
                  {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
                  -1.0, 1.0)));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5176 
// Instantiates the quantized const-padding test for uint8 and int8.
TEST(QuantizedPadV2OpTest, UInt8SimpleConstTest) {
  SimpleConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstTest) {
  SimpleConstTestV2<int8_t, TensorType_INT8>();
}
5183 
// Quantized PadV2 with run-time (dynamic) paddings and a pad value of 0,
// parameterized over the integer storage type and its TensorType.
template <typename integer_type, TensorType tensor_dtype>
void SimpleDynamicTestV2() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(
                  {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
                  -1.0, 1.0)));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
5199 
// Instantiates the quantized dynamic-padding test for uint8 and int8.
TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicTest) {
  SimpleDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicTest) {
  SimpleDynamicTestV2<int8_t, TensorType_INT8>();
}
5206 
// Quantized PadV2 with asymmetric const paddings ({0,2} rows, {1,3} columns)
// and a pad value of 0, parameterized over the integer storage type.
template <typename integer_type, TensorType tensor_dtype>
void AdvancedConstTestV2() {
  PadV2OpConstModel<integer_type> m(
      {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
      {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.Invoke();
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(
                  {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
                   0, 0,    0,   0,   0, 0, 0, 0, 0,   0,   0,    0, 0, 0},
                  -1.0, 1.0)));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
5222 
// Instantiates the quantized asymmetric const-padding test for uint8/int8.
TEST(QuantizedPadV2OpTest, UInt8AdvancedConstTest) {
  AdvancedConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstTest) {
  AdvancedConstTestV2<int8_t, TensorType_INT8>();
}
5229 
// Quantized PadV2 with asymmetric run-time paddings and a pad value of 0,
// parameterized over the integer storage type.
template <typename integer_type, TensorType tensor_dtype>
void AdvancedDynamicTestV2() {
  PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
                                      {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
                                      {tensor_dtype, {}, -1.0, 1.0});
  m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
  m.template SetQuantizedPadValue<integer_type>(0);
  m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
              ElementsAreArray(DequantizedArrayNear(
                  {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
                   0, 0,    0,   0,   0, 0, 0, 0, 0,   0,   0,    0, 0, 0},
                  -1.0, 1.0)));
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
5246 
// Instantiates the dynamic asymmetric-paddings PadV2 test for both
// quantized types supported by the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<int8_t, TensorType_INT8>();
}
5253 
5254 template <typename integer_type, TensorType tensor_dtype>
SimpleConstValuedTest()5255 void SimpleConstValuedTest() {
5256   // Padding is represented as four 2-D lists representing above padding and
5257   // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
5258   PadV2OpConstModel<integer_type> m(
5259       {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
5260       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5261   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5262   m.template SetQuantizedPadValue<integer_type>(-0.5);
5263   m.Invoke();
5264   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5265               ElementsAreArray(DequantizedArrayNear(
5266                   {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
5267                    0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
5268                   -1.0, 1.0)));
5269   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5270 }
5271 
// Instantiates the constant-paddings, non-zero pad value PadV2 test for
// both quantized types supported by the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8SimpleConstValuedTest) {
  SimpleConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstValuedTest) {
  SimpleConstValuedTest<int8_t, TensorType_INT8>();
}
5278 
5279 template <typename integer_type, TensorType tensor_dtype>
SimpleDynamicValuedTest()5280 void SimpleDynamicValuedTest() {
5281   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
5282                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5283                                       {tensor_dtype, {}, -1.0, 1.0});
5284   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5285   m.template SetQuantizedPadValue<integer_type>(-0.5);
5286   m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
5287   m.Invoke();
5288   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5289               ElementsAreArray(DequantizedArrayNear(
5290                   {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
5291                    0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
5292                   -1.0, 1.0)));
5293   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5294 }
5295 
// Instantiates the dynamic-paddings, non-zero pad value PadV2 test for
// both quantized types supported by the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<int8_t, TensorType_INT8>();
}
5302 
5303 template <typename integer_type, TensorType tensor_dtype>
AdvancedConstValuedTest()5304 void AdvancedConstValuedTest() {
5305   PadV2OpConstModel<integer_type> m(
5306       {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
5307       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5308   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5309   m.template SetQuantizedPadValue<integer_type>(-0.5);
5310   m.Invoke();
5311   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5312               ElementsAreArray(DequantizedArrayNear(
5313                   {-0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5, -0.5, 0.7,  0.1,
5314                    -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
5315                    -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
5316                   -1.0, 1.0)));
5317   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5318 }
5319 
// Instantiates the constant asymmetric-paddings, non-zero pad value PadV2
// test for both quantized types supported by the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<int8_t, TensorType_INT8>();
}
5326 
5327 template <typename integer_type, TensorType tensor_dtype>
AdvancedDynamicValuedTest()5328 void AdvancedDynamicValuedTest() {
5329   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
5330                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5331                                       {tensor_dtype, {}, -1.0, 1.0});
5332   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5333   m.template SetQuantizedPadValue<integer_type>(-0.5);
5334   m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
5335   m.Invoke();
5336   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5337               ElementsAreArray(DequantizedArrayNear(
5338                   {-0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5, -0.5, 0.7,  0.1,
5339                    -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
5340                    -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
5341                   -1.0, 1.0)));
5342   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5343 }
5344 
// Instantiates the dynamic asymmetric-paddings, non-zero pad value PadV2
// test for both quantized types supported by the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<int8_t, TensorType_INT8>();
}
5351 
5352 // A base class of Leaky ReLU op model. It provides the constructor for
5353 // FloatLeakyReluOpModel and QuantizedLeakyReluOpModel.
5354 class LeakyReluOpModel : public SingleOpModelWithNNAPI {
5355  public:
LeakyReluOpModel(const TensorData & input,const float & alpha)5356   LeakyReluOpModel(const TensorData& input, const float& alpha)
5357       : input_type_(input.type) {
5358     input_ = AddInput(input);
5359     output_ = AddOutput({input.type, input.shape, input.min, input.max});
5360 
5361     SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
5362                  CreateLeakyReluOptions(builder_, alpha).Union());
5363     BuildInterpreterWithNNAPI({GetShape(input_)});
5364   }
5365 
SetInput(std::initializer_list<float> data)5366   void SetInput(std::initializer_list<float> data) {
5367     SetData(input_, input_type_, data);
5368   }
5369 
GetOutput()5370   std::vector<float> GetOutput() {
5371     std::vector<float> output;
5372     GetData(output_, input_type_, &output);
5373     return output;
5374   }
5375 
5376  protected:
5377   int input_;
5378   int output_;
5379 
5380   const TensorType input_type_;
5381 };
5382 
TEST(NNAPIDelegate, LeakyReluFloat) {
  // Float model with alpha = 0.5: negatives are halved, non-negatives pass
  // through unchanged.
  LeakyReluOpModel model({TensorType_FLOAT32, {2, 3}}, 0.5f);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0.0f, 1.0f, 3.0f,    // Row 1
                                     1.0f, -0.5f, -1.0f,  // Row 2
                                 }));
}
5397 
TEST(NNAPIDelegate, LeakyReluQuantized) {
  // Quantized (uint8) model over the range [-8, 8 * 127/128] with
  // alpha = 0.5; results are compared with quantization tolerance.
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  LeakyReluOpModel model({TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax}, 0.5f);
  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  model.Invoke();
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         0.0f, 1.0f, 3.0f,    // Row 1
                                         1.0f, -0.5f, -1.0f,  // Row 2
                                     },
                                     kQuantizedTolerance)));
}
5414 
5415 }  // namespace
5416 }  // namespace tflite
5417