1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct BiasAddFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
BiasAddFixtureBiasAddFixture14 explicit BiasAddFixture(const std::string& dataFormat)
15 {
16 m_Prototext = R"(
17 node {
18 name: "graphInput"
19 op: "Placeholder"
20 attr {
21 key: "dtype"
22 value {
23 type: DT_FLOAT
24 }
25 }
26 attr {
27 key: "shape"
28 value {
29 shape {
30 }
31 }
32 }
33 }
34 node {
35 name: "bias"
36 op: "Const"
37 attr {
38 key: "dtype"
39 value {
40 type: DT_FLOAT
41 }
42 }
43 attr {
44 key: "value"
45 value {
46 tensor {
47 dtype: DT_FLOAT
48 tensor_shape {
49 dim {
50 size: 3
51 }
52 }
53 float_val: 1
54 float_val: 2
55 float_val: 3
56 }
57 }
58 }
59 }
60 node {
61 name: "biasAdd"
62 op : "BiasAdd"
63 input: "graphInput"
64 input: "bias"
65 attr {
66 key: "T"
67 value {
68 type: DT_FLOAT
69 }
70 }
71 attr {
72 key: "data_format"
73 value {
74 s: ")" + dataFormat + R"("
75 }
76 }
77 }
78 )";
79
80 SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd");
81 }
82 };
83
84 struct BiasAddFixtureNCHW : BiasAddFixture
85 {
BiasAddFixtureNCHWBiasAddFixtureNCHW86 BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {}
87 };
88
89 struct BiasAddFixtureNHWC : BiasAddFixture
90 {
BiasAddFixtureNHWCBiasAddFixtureNHWC91 BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {}
92 };
93
// Runs the NCHW graph on nine value-initialised (zero) floats, so the expected
// output is the bias on its own. The expected data repeats each bias value three
// times in a row (1,1,1, 2,2,2, 3,3,3): with the fixture's { 1, 3, 1, 3 } shape read
// as N,C,H,W, the three-element bias is applied along the second dimension.
// NOTE(review): RunTest is declared in ParserPrototxtFixture.hpp; it presumably
// compares the network output against the second argument - confirm.
BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW)
{
    RunTest<4>(std::vector<float>(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 });
}
98
// Runs the NHWC graph on nine value-initialised (zero) floats, so the expected
// output is the bias on its own. The expected data cycles through the bias values
// (1,2,3, 1,2,3, 1,2,3): with the fixture's { 1, 3, 1, 3 } shape read as N,H,W,C,
// the three-element bias is applied along the last dimension.
// NOTE(review): RunTest is declared in ParserPrototxtFixture.hpp; it presumably
// compares the network output against the second argument - confirm.
BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC)
{
    RunTest<4>(std::vector<float>(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 });
}
103
104 BOOST_AUTO_TEST_SUITE_END()
105