• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_conversion_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

34 namespace tflite {
35 namespace xnnpack {
36 
Test(TfLiteDelegate * delegate) const37 void ResizeBilinearTester::Test(TfLiteDelegate* delegate) const {
38   std::random_device random_device;
39   auto rng = std::mt19937(random_device());
40   auto input_rng =
41       std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
42 
43   std::vector<char> buffer = CreateTfLiteModel();
44   const Model* model = GetModel(buffer.data());
45 
46   std::unique_ptr<Interpreter> delegate_interpreter;
47   ASSERT_EQ(
48       InterpreterBuilder(
49           model,
50           ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
51           &delegate_interpreter),
52       kTfLiteOk);
53   std::unique_ptr<Interpreter> default_interpreter;
54   ASSERT_EQ(
55       InterpreterBuilder(
56           model,
57           ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())(
58           &default_interpreter),
59       kTfLiteOk);
60 
61   ASSERT_TRUE(delegate_interpreter);
62   ASSERT_TRUE(default_interpreter);
63 
64   ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
65   ASSERT_EQ(default_interpreter->inputs().size(), 1);
66 
67   ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
68   ASSERT_EQ(default_interpreter->outputs().size(), 1);
69 
70   ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
71   ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);
72 
73   ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk);
74 
75   float* default_input_data = default_interpreter->typed_tensor<float>(
76       default_interpreter->inputs()[0]);
77   std::generate(default_input_data,
78                 default_input_data +
79                     BatchSize() * InputHeight() * InputWidth() * Channels(),
80                 std::ref(input_rng));
81 
82   float* delegate_input_data = delegate_interpreter->typed_tensor<float>(
83       delegate_interpreter->inputs()[0]);
84   std::copy(default_input_data,
85             default_input_data +
86                 BatchSize() * InputHeight() * InputWidth() * Channels(),
87             delegate_input_data);
88 
89   ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk);
90   ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk);
91 
92   float* default_output_data = default_interpreter->typed_tensor<float>(
93       default_interpreter->outputs()[0]);
94   float* delegate_output_data = delegate_interpreter->typed_tensor<float>(
95       delegate_interpreter->outputs()[0]);
96 
97   for (int i = 0; i < BatchSize(); i++) {
98     for (int y = 0; y < OutputHeight(); y++) {
99       for (int x = 0; x < OutputWidth(); x++) {
100         for (int c = 0; c < Channels(); c++) {
101           const int index =
102               ((i * OutputHeight() + y) * OutputWidth() + x) * Channels() + c;
103           ASSERT_NEAR(default_output_data[index], delegate_output_data[index],
104                       std::max(std::abs(default_output_data[index]) * 1.0e-4f,
105                                10.0f * std::numeric_limits<float>::epsilon()))
106               << "batch " << i << " / " << BatchSize() << ", y position " << y
107               << " / " << OutputHeight() << ", x position " << x << " / "
108               << OutputWidth() << ", channel " << c << " / " << Channels();
109         }
110       }
111     }
112   }
113 }
114 
CreateTfLiteModel() const115 std::vector<char> ResizeBilinearTester::CreateTfLiteModel() const {
116   flatbuffers::FlatBufferBuilder builder;
117   flatbuffers::Offset<OperatorCode> operator_code =
118       CreateOperatorCode(builder, BuiltinOperator_RESIZE_BILINEAR);
119 
120   flatbuffers::Offset<tflite::ResizeBilinearOptions> resize_bilinear_options =
121       CreateResizeBilinearOptions(builder, AlignCorners(), HalfPixelCenters());
122 
123   const std::array<int32_t, 2> size_data{{OutputHeight(), OutputWidth()}};
124 
125   const std::array<flatbuffers::Offset<Buffer>, 2> buffers{{
126       CreateBuffer(builder, builder.CreateVector({})),
127       CreateBuffer(builder,
128                    builder.CreateVector(
129                        reinterpret_cast<const uint8_t*>(size_data.data()),
130                        size_data.size() * sizeof(int32_t))),
131   }};
132 
133   const std::array<int32_t, 4> input_shape{
134       {BatchSize(), InputHeight(), InputWidth(), Channels()}};
135   const std::array<int32_t, 4> output_shape{
136       {BatchSize(), OutputHeight(), OutputWidth(), Channels()}};
137   const std::array<int32_t, 1> size_shape{
138       {static_cast<int32_t>(size_data.size())}};
139 
140   const std::array<flatbuffers::Offset<Tensor>, 3> tensors{{
141       CreateTensor(
142           builder,
143           builder.CreateVector<int32_t>(input_shape.data(), input_shape.size()),
144           TensorType_FLOAT32),
145       CreateTensor(
146           builder,
147           builder.CreateVector<int32_t>(size_shape.data(), size_shape.size()),
148           TensorType_INT32, /*buffer=*/1),
149       CreateTensor(builder,
150                    builder.CreateVector<int32_t>(output_shape.data(),
151                                                  output_shape.size()),
152                    TensorType_FLOAT32),
153   }};
154 
155   const std::array<int32_t, 2> op_inputs{{0, 1}};
156   const std::array<int32_t, 1> op_outputs{{2}};
157   flatbuffers::Offset<Operator> op = CreateOperator(
158       builder, /*opcode_index=*/0,
159       builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
160       builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()),
161       BuiltinOptions_ResizeBilinearOptions, resize_bilinear_options.Union());
162 
163   const std::array<int32_t, 1> subgraph_inputs{{0}};
164   const std::array<int32_t, 1> subgraph_outputs{{2}};
165   flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
166       builder, builder.CreateVector(tensors.data(), tensors.size()),
167       builder.CreateVector<int32_t>(subgraph_inputs.data(),
168                                     subgraph_inputs.size()),
169       builder.CreateVector<int32_t>(subgraph_outputs.data(),
170                                     subgraph_outputs.size()),
171       builder.CreateVector(&op, 1));
172 
173   flatbuffers::Offset<flatbuffers::String> description =
174       builder.CreateString("Resize Bilinear model");
175 
176   flatbuffers::Offset<Model> model_buffer = CreateModel(
177       builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1),
178       builder.CreateVector(&subgraph, 1), description,
179       builder.CreateVector(buffers.data(), buffers.size()));
180 
181   builder.Finish(model_buffer);
182 
183   return std::vector<char>(builder.GetBufferPointer(),
184                            builder.GetBufferPointer() + builder.GetSize());
185 }
186 
187 }  // namespace xnnpack
188 }  // namespace tflite
189