1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct GatherFixture : public ParserFlatbuffersSerializeFixture
16 {
GatherFixtureGatherFixture17 explicit GatherFixture(const std::string& inputShape,
18 const std::string& indicesShape,
19 const std::string& input1Content,
20 const std::string& outputShape,
21 const std::string& axis,
22 const std::string dataType,
23 const std::string constDataType)
24 {
25 m_JsonString = R"(
26 {
27 inputIds: [0],
28 outputIds: [3],
29 layers: [
30 {
31 layer_type: "InputLayer",
32 layer: {
33 base: {
34 layerBindingId: 0,
35 base: {
36 index: 0,
37 layerName: "InputLayer",
38 layerType: "Input",
39 inputSlots: [{
40 index: 0,
41 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
42 }],
43 outputSlots: [ {
44 index: 0,
45 tensorInfo: {
46 dimensions: )" + inputShape + R"(,
47 dataType: )" + dataType + R"(
48 }}]
49 }
50 }}},
51 {
52 layer_type: "ConstantLayer",
53 layer: {
54 base: {
55 index:1,
56 layerName: "ConstantLayer",
57 layerType: "Constant",
58 outputSlots: [ {
59 index: 0,
60 tensorInfo: {
61 dimensions: )" + indicesShape + R"(,
62 dataType: "Signed32",
63 },
64 }],
65 },
66 input: {
67 info: {
68 dimensions: )" + indicesShape + R"(,
69 dataType: )" + dataType + R"(
70 },
71 data_type: )" + constDataType + R"(,
72 data: {
73 data: )" + input1Content + R"(,
74 } }
75 },},
76 {
77 layer_type: "GatherLayer",
78 layer: {
79 base: {
80 index: 2,
81 layerName: "GatherLayer",
82 layerType: "Gather",
83 inputSlots: [
84 {
85 index: 0,
86 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
87 },
88 {
89 index: 1,
90 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
91 }],
92 outputSlots: [ {
93 index: 0,
94 tensorInfo: {
95 dimensions: )" + outputShape + R"(,
96 dataType: )" + dataType + R"(
97
98 }}]},
99 descriptor: {
100 axis: )" + axis + R"(
101 }
102 }},
103 {
104 layer_type: "OutputLayer",
105 layer: {
106 base:{
107 layerBindingId: 0,
108 base: {
109 index: 3,
110 layerName: "OutputLayer",
111 layerType: "Output",
112 inputSlots: [{
113 index: 0,
114 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
115 }],
116 outputSlots: [ {
117 index: 0,
118 tensorInfo: {
119 dimensions: )" + outputShape + R"(,
120 dataType: )" + dataType + R"(
121 },
122 }],
123 }}},
124 }]
125 } )";
126
127 Setup();
128 }
129 };
130
131 struct SimpleGatherFixtureFloat32 : GatherFixture
132 {
SimpleGatherFixtureFloat32SimpleGatherFixtureFloat32133 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
134 "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
135 };
136
// With axis 0, each index picks one [2, 3] slice of the input. The indices
// [1, 2, 1, 2, 1, 0] therefore produce slices 1, 2, 1, 2, 1, 0 in that order.
BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
{
    // Input tensor [3, 2, 3]: three consecutive [2, 3] slices.
    const std::vector<float> inputValues = {  1,  2,  3,  4,  5,  6,    // slice 0
                                              7,  8,  9, 10, 11, 12,    // slice 1
                                             13, 14, 15, 16, 17, 18 };  // slice 2

    // Output tensor [2, 3, 2, 3]: one input slice per index value.
    const std::vector<float> expectedValues = {  7,  8,  9, 10, 11, 12,    // index 1
                                                13, 14, 15, 16, 17, 18,    // index 2
                                                 7,  8,  9, 10, 11, 12,    // index 1
                                                13, 14, 15, 16, 17, 18,    // index 2
                                                 7,  8,  9, 10, 11, 12,    // index 1
                                                 1,  2,  3,  4,  5,  6 };  // index 0

    RunTest<4, armnn::DataType::Float32>(0,
                                         {{ "InputLayer",  inputValues }},
                                         {{ "OutputLayer", expectedValues }});
}
159
160 BOOST_AUTO_TEST_SUITE_END()
161
162