1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <vector>
18
19 #include "tensorflow/lite/interpreter.h"
20 #include "tensorflow/lite/kernels/perception/perception_ops.h"
21 #include "tensorflow/lite/kernels/test_util.h"
22 #include "tensorflow/lite/testing/util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace custom {
27
28 namespace {
29
30 using testing::ElementsAreArray;
31
32 class DenseImageWarpOpModel : public SingleOpModel {
33 public:
DenseImageWarpOpModel(const TensorData & input,const TensorData & flow,const TensorData & output)34 DenseImageWarpOpModel(const TensorData& input, const TensorData& flow,
35 const TensorData& output) {
36 input_ = AddInput(input);
37 flow_ = AddInput(flow);
38 output_ = AddOutput(output);
39
40 std::vector<uint8_t> custom_option;
41 SetCustomOp("DenseImageWarp", custom_option, RegisterDenseImageWarp);
42 BuildInterpreter({GetShape(input_), GetShape(flow_)});
43 }
44
SetInput(const std::vector<float> & data)45 void SetInput(const std::vector<float>& data) {
46 PopulateTensor(input_, data);
47 }
SetFlow(const std::vector<float> & data)48 void SetFlow(const std::vector<float>& data) { PopulateTensor(flow_, data); }
49
GetOutput()50 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
51
GetOutputShape()52 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
53
54 protected:
55 int input_;
56 int flow_;
57 int output_;
58 };
59
// Building the model must die when dimension 2 of the image (4) does not match
// dimension 2 of the flow field (2); the regex matches the reported error.
TEST(DenseImageWarpOpTest, MismatchedSizeTest) {
  EXPECT_DEATH_IF_SUPPORTED(
      {
        DenseImageWarpOpModel mismatched_model(
            /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
            /*flow=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
            /*output=*/{TensorType_FLOAT32, {}});
      },
      "input_shape.Dims.2. != flow_shape.Dims.2. .4 != 2.");
}
68
// Building the model must die when the flow tensor's last dimension is 1
// instead of the required 2.
TEST(DenseImageWarpOpTest, WrongFlowSizeTest) {
  EXPECT_DEATH_IF_SUPPORTED(
      {
        DenseImageWarpOpModel bad_flow_model(
            /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
            /*flow=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
            /*output=*/{TensorType_FLOAT32, {}});
      },
      "The last dimension of flow tensor must be 2.");
}
76
// Warps a single-channel 4x4 image whose pixel values equal their flat index,
// using a flow field of even integer displacements, and checks the result.
TEST(DenseImageWarpOpTest, SimpleTest) {
  DenseImageWarpOpModel warp_model(
      /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
      /*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});

  const std::vector<float> image = {0,  1,  2,  3,   //
                                    4,  5,  6,  7,   //
                                    8,  9,  10, 11,  //
                                    12, 13, 14, 15};
  // Two flow values per pixel, eight values per image row.
  const std::vector<float> flow = {4,  10, 6,  10, 4,  2, 6, 6,   //
                                   10, -4, 2,  -2, 6,  8, 6, 0,   //
                                   2,  -2, 10, 6,  4,  4, 2, -4,  //
                                   -4, 10, -4, -4, -2, 6, 4, 6};
  warp_model.SetInput(image);
  warp_model.SetFlow(flow);
  warp_model.Invoke();

  const std::vector<int> expected_shape = {1, 4, 4, 1};
  const std::vector<float> expected_output = {0,  0,  0,  0,  //
                                              3,  3,  0,  3,  //
                                              2,  0,  0,  3,  //
                                              12, 15, 12, 0};
  EXPECT_THAT(warp_model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(warp_model.GetOutput(), ElementsAreArray(expected_output));
}
91
// Same warp geometry exercised with non-integer pixel values: every expected
// output value is one of the input pixel values, reproduced exactly.
TEST(DenseImageWarpOpTest, RoundTest) {
  DenseImageWarpOpModel warp_model(
      /*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
      /*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});

  const std::vector<float> image = {0.2,  1.5,  2.4,  3.5,   //
                                    4.6,  5.1,  6.3,  7.2,   //
                                    8.5,  9.6,  10.9, 11.6,  //
                                    12.8, 13.2, 14.4, 15.5};
  // Two flow values per pixel, eight values per image row.
  const std::vector<float> flow = {4,  10, 6,  10, 4,  2, 6, 6,   //
                                   10, -4, 2,  -2, 6,  8, 6, 0,   //
                                   2,  -2, 10, 6,  4,  4, 2, -4,  //
                                   -4, 10, -4, -4, -2, 6, 4, 6};
  warp_model.SetInput(image);
  warp_model.SetFlow(flow);
  warp_model.Invoke();

  const std::vector<int> expected_shape = {1, 4, 4, 1};
  const std::vector<float> expected_output = {0.2,  0.2,  0.2,  0.2,  //
                                              3.5,  3.5,  0.2,  3.5,  //
                                              2.4,  0.2,  0.2,  3.5,  //
                                              12.8, 15.5, 12.8, 0.2};
  EXPECT_THAT(warp_model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(warp_model.GetOutput(), ElementsAreArray(expected_output));
}
108
// Warps a batch of two 4x4 images with three channels each. Input values are
// their flat index in the input buffer, and both batch entries share the same
// flow field, so the expected second batch equals the first plus 48.
TEST(DenseImageWarpOpTest, WithBatchandChannelTest) {
  DenseImageWarpOpModel batched_model(
      /*input=*/{TensorType_FLOAT32, {2, 4, 4, 3}},
      /*flow=*/{TensorType_FLOAT32, {2, 4, 4, 2}},
      /*output=*/{TensorType_FLOAT32, {}});

  // 0, 1, ..., 95 -- one value per element of the {2, 4, 4, 3} input.
  constexpr int kNumInputValues = 2 * 4 * 4 * 3;
  std::vector<float> ramp(kNumInputValues);
  for (int idx = 0; idx < kNumInputValues; ++idx) {
    ramp[idx] = static_cast<float>(idx);
  }
  batched_model.SetInput(ramp);

  // Flow values for one batch entry (4 x 4 x 2), repeated for both entries.
  const std::vector<float> single_flow = {2,  -2, 10, 6,  4,  4, 2, -4,  //
                                          -4, 10, -4, -4, -2, 6, 4, 6,   //
                                          4,  10, 6,  10, 4,  2, 6, 6,   //
                                          10, -4, 2,  -2, 6,  8, 6, 0};
  std::vector<float> batched_flow;
  batched_flow.reserve(2 * single_flow.size());
  for (int copy = 0; copy < 2; ++copy) {
    batched_flow.insert(batched_flow.end(), single_flow.begin(),
                        single_flow.end());
  }
  batched_model.SetFlow(batched_flow);
  batched_model.Invoke();

  const std::vector<int> expected_shape = {2, 4, 4, 3};
  // One image row (four 3-channel pixels) per line; batch 0 then batch 1.
  const std::vector<float> expected_output = {
      6,  7,  8,  0,  1,  2,  0,  1,  2,  9,  10, 11,  //
      36, 37, 38, 45, 46, 47, 36, 37, 38, 0,  1,  2,   //
      0,  1,  2,  0,  1,  2,  0,  1,  2,  0,  1,  2,   //
      9,  10, 11, 21, 22, 23, 0,  1,  2,  9,  10, 11,  //
      54, 55, 56, 48, 49, 50, 48, 49, 50, 57, 58, 59,  //
      84, 85, 86, 93, 94, 95, 84, 85, 86, 48, 49, 50,  //
      48, 49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50,  //
      57, 58, 59, 69, 70, 71, 48, 49, 50, 57, 58, 59};
  EXPECT_THAT(batched_model.GetOutputShape(),
              ElementsAreArray(expected_shape));
  EXPECT_THAT(batched_model.GetOutput(), ElementsAreArray(expected_output));
}
135 } // namespace
136 } // namespace custom
137 } // namespace ops
138 } // namespace tflite
139