• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_InstanceNormalization")
12 {
13 struct InstanceNormalizationFixture : public ParserFlatbuffersSerializeFixture
14 {
InstanceNormalizationFixtureInstanceNormalizationFixture15     explicit InstanceNormalizationFixture(const std::string &inputShape,
16                                           const std::string &outputShape,
17                                           const std::string &gamma,
18                                           const std::string &beta,
19                                           const std::string &epsilon,
20                                           const std::string &dataType,
21                                           const std::string &dataLayout)
22     {
23         m_JsonString = R"(
24     {
25         inputIds: [0],
26         outputIds: [2],
27         layers: [
28            {
29             layer_type: "InputLayer",
30             layer: {
31                 base: {
32                     layerBindingId: 0,
33                     base: {
34                         index: 0,
35                         layerName: "InputLayer",
36                         layerType: "Input",
37                         inputSlots: [{
38                             index: 0,
39                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
40                             }],
41                         outputSlots: [{
42                             index: 0,
43                             tensorInfo: {
44                                 dimensions: )" + inputShape + R"(,
45                                 dataType: ")" + dataType + R"(",
46                                 quantizationScale: 0.5,
47                                 quantizationOffset: 0
48                                 },
49                             }]
50                         },
51                     }
52                 },
53             },
54         {
55         layer_type: "InstanceNormalizationLayer",
56         layer : {
57             base: {
58                 index:1,
59                 layerName: "InstanceNormalizationLayer",
60                 layerType: "InstanceNormalization",
61                 inputSlots: [{
62                         index: 0,
63                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
64                    }],
65                 outputSlots: [{
66                     index: 0,
67                     tensorInfo: {
68                         dimensions: )" + outputShape + R"(,
69                         dataType: ")" + dataType + R"("
70                     },
71                     }],
72                 },
73             descriptor: {
74                 dataLayout: ")" + dataLayout + R"(",
75                 gamma: ")" + gamma + R"(",
76                 beta: ")" + beta + R"(",
77                 eps: )" + epsilon + R"(
78                 },
79             },
80         },
81         {
82         layer_type: "OutputLayer",
83         layer: {
84             base:{
85                 layerBindingId: 0,
86                 base: {
87                     index: 2,
88                     layerName: "OutputLayer",
89                     layerType: "Output",
90                     inputSlots: [{
91                         index: 0,
92                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
93                     }],
94                     outputSlots: [ {
95                         index: 0,
96                         tensorInfo: {
97                             dimensions: )" + outputShape + R"(,
98                             dataType: ")" + dataType + R"("
99                         },
100                     }],
101                 }
102             }},
103         }]
104     }
105 )";
106         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
107     }
108 };
109 
110 struct InstanceNormalizationFloat32Fixture : InstanceNormalizationFixture
111 {
InstanceNormalizationFloat32FixtureInstanceNormalizationFloat32Fixture112     InstanceNormalizationFloat32Fixture():InstanceNormalizationFixture("[ 2, 2, 2, 2 ]",
113                                                                        "[ 2, 2, 2, 2 ]",
114                                                                        "1.0",
115                                                                        "0.0",
116                                                                        "0.0001",
117                                                                        "Float32",
118                                                                        "NHWC") {}
119 };
120 
121 TEST_CASE_FIXTURE(InstanceNormalizationFloat32Fixture, "InstanceNormalizationFloat32")
122 {
123     RunTest<4, armnn::DataType::Float32>(
124         0,
125          {
126              0.f,  1.f,
127              0.f,  2.f,
128 
129              0.f,  2.f,
130              0.f,  4.f,
131 
132              1.f, -1.f,
133             -1.f,  2.f,
134 
135             -1.f, -2.f,
136              1.f,  4.f
137         },
138         {
139              0.0000000f, -1.1470304f,
140              0.0000000f, -0.2294061f,
141 
142              0.0000000f, -0.2294061f,
143              0.0000000f,  1.6058424f,
144 
145              0.9999501f, -0.7337929f,
146             -0.9999501f,  0.5241377f,
147 
148             -0.9999501f, -1.1531031f,
149              0.9999501f,  1.3627582f
150         });
151 }
152 
153 }
154