• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <vector>
18 
19 #include "tensorflow/lite/interpreter.h"
20 #include "tensorflow/lite/kernels/perception/perception_ops.h"
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/testing/util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27 
28 namespace {
29 
30 using testing::ElementsAreArray;
31 
32 class DenseImageWarpOpModel : public SingleOpModel {
33  public:
DenseImageWarpOpModel(const TensorData & input,const TensorData & flow,const TensorData & output)34   DenseImageWarpOpModel(const TensorData& input, const TensorData& flow,
35                         const TensorData& output) {
36     input_ = AddInput(input);
37     flow_ = AddInput(flow);
38     output_ = AddOutput(output);
39 
40     std::vector<uint8_t> custom_option;
41     SetCustomOp("DenseImageWarp", custom_option, RegisterDenseImageWarp);
42     BuildInterpreter({GetShape(input_), GetShape(flow_)});
43   }
44 
SetInput(const std::vector<float> & data)45   void SetInput(const std::vector<float>& data) {
46     PopulateTensor(input_, data);
47   }
SetFlow(const std::vector<float> & data)48   void SetFlow(const std::vector<float>& data) { PopulateTensor(flow_, data); }
49 
GetOutput()50   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
51 
GetOutputShape()52   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
53 
54  protected:
55   int input_;
56   int flow_;
57   int output_;
58 };
59 
// Building the model must abort when dimension 2 of the image (4) and of the
// flow (2) disagree.
TEST(DenseImageWarpOpTest, MismatchedSizeTest) {
  const TensorData image_desc{TensorType_FLOAT32, {1, 4, 4, 1}};
  const TensorData flow_desc{TensorType_FLOAT32, {1, 4, 2, 2}};
  const TensorData output_desc{TensorType_FLOAT32, {}};

  EXPECT_DEATH_IF_SUPPORTED(
      DenseImageWarpOpModel model(image_desc, flow_desc, output_desc),
      "input_shape.Dims.2. != flow_shape.Dims.2. .4 != 2.");
}
68 
// Building the model must abort when the flow's last dimension is not 2.
TEST(DenseImageWarpOpTest, WrongFlowSizeTest) {
  const TensorData image_desc{TensorType_FLOAT32, {1, 4, 4, 1}};
  const TensorData flow_desc{TensorType_FLOAT32, {1, 4, 4, 1}};
  const TensorData output_desc{TensorType_FLOAT32, {}};

  EXPECT_DEATH_IF_SUPPORTED(
      DenseImageWarpOpModel model(image_desc, flow_desc, output_desc),
      "The last dimension of flow tensor must be 2.");
}
76 
// Single-batch, single-channel 4x4 image warped by an integer-valued flow
// field; the output keeps the image's shape and matches hard-coded values.
TEST(DenseImageWarpOpTest, SimpleTest) {
  DenseImageWarpOpModel model(
      /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
      /*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});

  const std::vector<float> image = {0, 1, 2,  3,  4,  5,  6,  7,
                                    8, 9, 10, 11, 12, 13, 14, 15};
  const std::vector<float> flow = {
      4, 10, 6,  10, 4, 2, 6, 6,  10, -4, 2,  -2, 6,  8, 6, 0,
      2, -2, 10, 6,  4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6};
  model.SetInput(image);
  model.SetFlow(flow);
  model.Invoke();

  const std::vector<int> expected_shape = {1, 4, 4, 1};
  const std::vector<float> expected_values = {0, 0, 0,  0,  3,  3,  0,  3,
                                              2, 0, 0,  3,  12, 15, 12, 0};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_values));
}
91 
// Same warp as SimpleTest, but with non-integer pixel values.
// The output tensor is produced by the kernel's float arithmetic, so it is
// compared element-wise within a small absolute tolerance (ArrayFloatNear)
// rather than with bit-exact equality, which would make the test depend on
// the exact rounding of intermediate computations.
TEST(DenseImageWarpOpTest, RoundTest) {
  DenseImageWarpOpModel model(
      /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
      /*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({0.2f, 1.5f, 2.4f, 3.5f, 4.6f, 5.1f, 6.3f, 7.2f, 8.5f, 9.6f,
                  10.9f, 11.6f, 12.8f, 13.2f, 14.4f, 15.5f});
  model.SetFlow({4, 10, 6,  10, 4, 2, 6, 6,  10, -4, 2,  -2, 6,  8, 6, 0,
                 2, -2, 10, 6,  4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {0.2f, 0.2f, 0.2f, 0.2f, 3.5f, 3.5f, 0.2f, 3.5f, 2.4f, 0.2f,
                   0.2f, 3.5f, 12.8f, 15.5f, 12.8f, 0.2f})));
}
108 
// Two batches of a 4x4 image with three channels. Each batch gets its own
// flow field, and the output keeps the image's shape.
TEST(DenseImageWarpOpTest, WithBatchandChannelTest) {
  DenseImageWarpOpModel model(
      /*input=*/{TensorType_FLOAT32, {2, 4, 4, 3}},
      /*flow=*/{TensorType_FLOAT32, {2, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});

  // 2 * 4 * 4 * 3 = 96 elements holding 0, 1, ..., 95, so every input value
  // is distinct.
  std::vector<float> input_data(96);
  for (std::size_t pos = 0; pos < input_data.size(); ++pos) {
    input_data[pos] = static_cast<float>(pos);
  }
  model.SetInput(input_data);

  const std::vector<float> flow_data = {
      2, -2, 10, 6,  4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
      4, 10, 6,  10, 4, 2, 6, 6,  10, -4, 2,  -2, 6,  8, 6, 0,
      2, -2, 10, 6,  4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
      4, 10, 6,  10, 4, 2, 6, 6,  10, -4, 2,  -2, 6,  8, 6, 0};
  model.SetFlow(flow_data);
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 4, 3}));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({6,  7,  8,  0,  1,  2,  0,  1,  2,  9,  10, 11, 36, 37,
                        38, 45, 46, 47, 36, 37, 38, 0,  1,  2,  0,  1,  2,  0,
                        1,  2,  0,  1,  2,  0,  1,  2,  9,  10, 11, 21, 22, 23,
                        0,  1,  2,  9,  10, 11, 54, 55, 56, 48, 49, 50, 48, 49,
                        50, 57, 58, 59, 84, 85, 86, 93, 94, 95, 84, 85, 86, 48,
                        49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50,
                        57, 58, 59, 69, 70, 71, 48, 49, 50, 57, 58, 59}));
}
135 }  // namespace
136 }  // namespace custom
137 }  // namespace ops
138 }  // namespace tflite
139