1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_Clip") 10 { 11 struct ClipMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { ClipMainFixtureClipMainFixture13 ClipMainFixture(std::string min, std::string max) 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input" 25 type { 26 tensor_type { 27 elem_type: 1 28 shape { 29 dim { 30 dim_value: 5 31 } 32 } 33 } 34 } 35 } 36 node { 37 input: "Input" 38 input:")" + min + R"(" 39 input:")" + max + R"(" 40 output: "Output" 41 name: "ActivationLayer" 42 op_type: "Clip" 43 } 44 output { 45 name: "Output" 46 type { 47 tensor_type { 48 elem_type: 1 49 shape { 50 dim { 51 dim_value: 5 52 } 53 } 54 } 55 } 56 } 57 } 58 opset_import { 59 version: 7 60 })"; 61 Setup(); 62 } 63 }; 64 65 struct ClipAttributeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 66 { ClipAttributeFixtureClipAttributeFixture67 ClipAttributeFixture(std::string min, std::string max) 68 { 69 m_Prototext = R"( 70 ir_version: 3 71 producer_name: "CNTK" 72 producer_version: "2.5.1" 73 domain: "ai.cntk" 74 model_version: 1 75 graph { 76 name: "CNTKGraph" 77 input { 78 name: "Input" 79 type { 80 tensor_type { 81 elem_type: 1 82 shape { 83 dim { 84 dim_value: 5 85 } 86 } 87 } 88 } 89 } 90 node { 91 input: "Input" 92 output: "Output" 93 name: "ActivationLayer" 94 op_type: "Clip" 95 attribute { 96 name: "min" 97 f: )" + min + R"( 98 type: FLOAT 99 } 100 attribute { 101 name: "max" 102 f: )" + max + R"( 103 type: FLOAT 104 } 105 } 106 output { 107 name: "Output" 108 type { 109 tensor_type { 110 elem_type: 1 111 shape { 112 dim { 113 dim_value: 5 114 } 115 } 116 } 117 } 118 } 119 } 120 opset_import { 121 version: 7 122 })"; 123 Setup(); 124 } 125 }; 126 127 struct ClipFixture : ClipMainFixture 128 { ClipFixtureClipFixture129 ClipFixture() : ClipMainFixture("2", "3.5") {} 130 }; 131 132 TEST_CASE_FIXTURE(ClipFixture, "ValidClipTest") 133 { 134 RunTest<1>({{"Input", { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}}, 135 {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}}); 136 } 137 138 struct ClipNoMaxInputFixture : ClipMainFixture 139 { ClipNoMaxInputFixtureClipNoMaxInputFixture140 ClipNoMaxInputFixture() : ClipMainFixture("0", std::string()) {} 141 }; 142 143 TEST_CASE_FIXTURE(ClipNoMaxInputFixture, "ValidNoMaxInputClipTest") 144 { 145 RunTest<1>({{"Input", { -1.5f, -5.25f, -0.5f, 8.0f, std::numeric_limits<float>::max() }}}, 146 {{ "Output", { 0.0f, 0.0f, 0.0f, 8.0f, std::numeric_limits<float>::max() }}}); 147 } 148 149 struct ClipNoMinInputFixture : ClipMainFixture 150 { ClipNoMinInputFixtureClipNoMinInputFixture151 ClipNoMinInputFixture() : ClipMainFixture(std::string(), "6") {} 152 }; 153 154 TEST_CASE_FIXTURE(ClipNoMinInputFixture, "ValidNoMinInputClipTest") 155 { 156 RunTest<1>({{"Input", { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 8.0f, 200.0f }}}, 157 {{ "Output", { std::numeric_limits<float>::lowest(), -5.25f, -0.5f, 6.0f, 6.0f }}}); 158 } 159 160 struct ClipNoInputFixture : ClipMainFixture 161 { ClipNoInputFixtureClipNoInputFixture162 ClipNoInputFixture() : ClipMainFixture(std::string(), std::string()) {} 163 }; 164 165 TEST_CASE_FIXTURE(ClipNoInputFixture, "ValidNoInputClipTest") 166 { 167 RunTest<1>({{"Input", { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f, 168 std::numeric_limits<float>::max()}}}, 169 {{ "Output", { std::numeric_limits<float>::lowest(), -1.25f, 3.5f, 8.0f, 170 std::numeric_limits<float>::max()}}}); 171 } 172 173 struct ClipMinMaxAttributeFixture : ClipAttributeFixture 174 { ClipMinMaxAttributeFixtureClipMinMaxAttributeFixture175 ClipMinMaxAttributeFixture() : ClipAttributeFixture("2", "3.5") {} 176 }; 177 178 TEST_CASE_FIXTURE(ClipMinMaxAttributeFixture, "ValidClipAttributeTest") 179 { 180 RunTest<1>({{ "Input", { -1.5f, 1.25f, 3.5f, 8.0, 2.5}}}, 181 {{ "Output", { 2.0f, 2.0f, 3.5f, 3.5, 2.5}}}); 182 } 183 184 } 185