• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 TEST_SUITE("TensorflowLiteParser_MirrorPad")
9 {
10 struct MirrorPadFixture : public ParserFlatbuffersFixture
11 {
MirrorPadFixtureMirrorPadFixture12     explicit MirrorPadFixture(const std::string& inputShape,
13                               const std::string& outputShape,
14                               const std::string& padListShape,
15                               const std::string& padListData,
16                               const std::string& padMode,
17                               const std::string& dataType = "FLOAT32",
18                               const std::string& scale = "1.0",
19                               const std::string& offset = "0")
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": "MIRROR_PAD" } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {
28                             "shape": )" + inputShape + R"(,
29                             "type": )" + dataType + R"(,
30                             "buffer": 0,
31                             "name": "inputTensor",
32                             "quantization": {
33                                 "min": [ 0.0 ],
34                                 "max": [ 255.0 ],
35                                 "scale": [ )" + scale + R"( ],
36                                 "zero_point": [ )" + offset + R"( ],
37                             }
38                         },
39                         {
40                              "shape": )" + outputShape + R"(,
41                              "type": )" + dataType + R"(,
42                              "buffer": 1,
43                              "name": "outputTensor",
44                              "quantization": {
45                                 "min": [ 0.0 ],
46                                 "max": [ 255.0 ],
47                                 "scale": [ )" + scale + R"( ],
48                                 "zero_point": [ )" + offset + R"( ],
49                             }
50                         },
51                         {
52                              "shape": )" + padListShape + R"( ,
53                              "type": "INT32",
54                              "buffer": 2,
55                              "name": "padList",
56                              "quantization": {
57                                 "min": [ 0.0 ],
58                                 "max": [ 255.0 ],
59                                 "scale": [ 1.0 ],
60                                 "zero_point": [ 0 ],
61                              }
62                         }
63                     ],
64                     "inputs": [ 0 ],
65                     "outputs": [ 1 ],
66                     "operators": [
67                         {
68                             "opcode_index": 0,
69                             "inputs": [ 0, 2 ],
70                             "outputs": [ 1 ],
71                             "builtin_options_type": "MirrorPadOptions",
72                             "builtin_options": {
73                               "mode": )" + padMode + R"( ,
74                             },
75                             "custom_options_format": "FLEXBUFFERS"
76                         }
77                     ],
78                 } ],
79                 "buffers" : [
80                     { },
81                     { },
82                     { "data": )" + padListData + R"(, },
83                 ]
84             }
85         )";
86       SetupSingleInputSingleOutput("inputTensor", "outputTensor");
87     }
88 };
89 
90 struct SimpleMirrorPadSymmetricFixture : public MirrorPadFixture
91 {
SimpleMirrorPadSymmetricFixtureSimpleMirrorPadSymmetricFixture92     SimpleMirrorPadSymmetricFixture() : MirrorPadFixture("[ 3, 3 ]", "[ 7, 7 ]", "[ 2, 2 ]",
93                                                          "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 2,0,0,0 ]",
94                                                          "SYMMETRIC", "FLOAT32") {}
95 };
96 
97 TEST_CASE_FIXTURE(SimpleMirrorPadSymmetricFixture, "ParseMirrorPadSymmetric")
98 {
99     RunTest<2, armnn::DataType::Float32>
100             (0,
101              {{ "inputTensor",  { 1.0f, 2.0f, 3.0f,
102                                   4.0f, 5.0f, 6.0f,
103                                   7.0f, 8.0f, 9.0f }}},
104 
105              {{ "outputTensor", { 5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f,
106                                   2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f,
107                                   2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f,
108                                   5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f,
109                                   8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f,
110                                   8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f,
111                                   5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f }}});
112 }
113 
114 struct SimpleMirrorPadReflectFixture : public MirrorPadFixture
115 {
SimpleMirrorPadReflectFixtureSimpleMirrorPadReflectFixture116     SimpleMirrorPadReflectFixture() : MirrorPadFixture("[ 3, 3 ]", "[ 7, 7 ]", "[ 2, 2 ]",
117                                                         "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 2,0,0,0 ]",
118                                                         "REFLECT", "FLOAT32") {}
119 };
120 
121 TEST_CASE_FIXTURE(SimpleMirrorPadReflectFixture, "ParseMirrorPadRelfect")
122 {
123     RunTest<2, armnn::DataType::Float32>
124         (0,
125          {{ "inputTensor",  { 1.0f, 2.0f, 3.0f,
126                               4.0f, 5.0f, 6.0f,
127                               7.0f, 8.0f, 9.0f }}},
128 
129          {{ "outputTensor", { 9.0f, 8.0f, 7.0f, 8.0f, 9.0f, 8.0f, 7.0f,
130                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
131                               3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f,
132                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
133                               9.0f, 8.0f, 7.0f, 8.0f, 9.0f, 8.0f, 7.0f,
134                               6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f,
135                               3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f }}});
136 }
137 
138 }
139