//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersSerializeFixture.hpp"
#include <armnnDeserializer/IDeserializer.hpp>

#include <string>

11 TEST_SUITE("Deserializer_Reshape")
12 {
13 struct ReshapeFixture : public ParserFlatbuffersSerializeFixture
14 {
ReshapeFixtureReshapeFixture15     explicit ReshapeFixture(const std::string &inputShape,
16                             const std::string &targetShape,
17                             const std::string &outputShape,
18                             const std::string &dataType)
19     {
20         m_JsonString = R"(
21         {
22                 inputIds: [0],
23                 outputIds: [2],
24                 layers: [
25                 {
26                     layer_type: "InputLayer",
27                     layer: {
28                           base: {
29                                 layerBindingId: 0,
30                                 base: {
31                                     index: 0,
32                                     layerName: "InputLayer",
33                                     layerType: "Input",
34                                     inputSlots: [{
35                                         index: 0,
36                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
37                                     }],
38                                     outputSlots: [ {
39                                         index: 0,
40                                         tensorInfo: {
41                                             dimensions: )" + inputShape + R"(,
42                                             dataType: )" + dataType + R"(
43                                             }}]
44                                     }
45                     }}},
46                     {
47                     layer_type: "ReshapeLayer",
48                     layer: {
49                           base: {
50                                index: 1,
51                                layerName: "ReshapeLayer",
52                                layerType: "Reshape",
53                                inputSlots: [{
54                                       index: 0,
55                                       connection: {sourceLayerIndex:0, outputSlotIndex:0 },
56                                }],
57                                outputSlots: [ {
58                                       index: 0,
59                                       tensorInfo: {
60                                            dimensions: )" + inputShape + R"(,
61                                            dataType: )" + dataType + R"(
62 
63                                }}]},
64                           descriptor: {
65                                targetShape: )" + targetShape + R"(,
66                                }
67 
68                     }},
69                     {
70                     layer_type: "OutputLayer",
71                     layer: {
72                         base:{
73                               layerBindingId: 2,
74                               base: {
75                                     index: 2,
76                                     layerName: "OutputLayer",
77                                     layerType: "Output",
78                                     inputSlots: [{
79                                         index: 0,
80                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
81                                     }],
82                                     outputSlots: [ {
83                                         index: 0,
84                                         tensorInfo: {
85                                             dimensions: )" + outputShape + R"(,
86                                             dataType: )" + dataType + R"(
87                                         },
88                                 }],
89                             }}},
90                 }]
91          }
92      )";
93      SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
94     }
95 };
96 
97 struct SimpleReshapeFixture : ReshapeFixture
98 {
SimpleReshapeFixtureSimpleReshapeFixture99     SimpleReshapeFixture() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]",
100                                             "QuantisedAsymm8") {}
101 };
102 
103 struct SimpleReshapeFixture2 : ReshapeFixture
104 {
SimpleReshapeFixture2SimpleReshapeFixture2105     SimpleReshapeFixture2() : ReshapeFixture("[ 2, 2, 1, 1 ]",
106                                              "[ 2, 2, 1, 1 ]",
107                                              "[ 2, 2, 1, 1 ]",
108                                              "Float32") {}
109 };
110 
111 TEST_CASE_FIXTURE(SimpleReshapeFixture, "ReshapeQuantisedAsymm8")
112 {
113     RunTest<2, armnn::DataType::QAsymmU8>(0,
114                                                 { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
115                                                 { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
116 }
117 
118 TEST_CASE_FIXTURE(SimpleReshapeFixture2, "ReshapeFloat32")
119 {
120     RunTest<4, armnn::DataType::Float32>(0,
121                                         { 111, 85, 226, 3 },
122                                         { 111, 85, 226, 3 });
123 }
124 
125 
126 }