1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "frameworks/native/ops/hswish_builder.h"
17
18 #include "ops_test.h"
19
20 using namespace testing;
21 using namespace testing::ext;
22 using namespace OHOS::NeuralNetworkRuntime::Ops;
23
24 namespace OHOS {
25 namespace NeuralNetworkRuntime {
26 namespace UnitTest {
27 class HswishBuilderTest : public OpsTest {
28 public:
29 void SetUp() override;
30 void TearDown() override;
31
32 protected:
33 HswishBuilder m_hswish;
34 std::vector<uint32_t> m_inputs {0};
35 std::vector<uint32_t> m_outputs {1};
36 std::vector<uint32_t> m_params {};
37 std::vector<int32_t> m_inputDim {1, 5, 1, 1};
38 std::vector<int32_t> m_outputDim {1, 5, 1, 1};
39 };
40
SetUp()41 void HswishBuilderTest::SetUp() {}
42
TearDown()43 void HswishBuilderTest::TearDown() {}
44
45 /**
46 * @tc.name: hswish_build_001
47 * @tc.desc: Verify that the build function returns a successful message.
48 * @tc.type: FUNC
49 */
50 HWTEST_F(HswishBuilderTest, hswish_build_001, TestSize.Level0)
51 {
52 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
53 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
54
55 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
56 EXPECT_EQ(OH_NN_SUCCESS, ret);
57 }
58
59 /**
60 * @tc.name: hswish_build_002
61 * @tc.desc: Verify that the build function returns a failed message with true m_isBuild.
62 * @tc.type: FUNC
63 */
64 HWTEST_F(HswishBuilderTest, hswish_build_002, TestSize.Level0)
65 {
66 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
67 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
68
69 EXPECT_EQ(OH_NN_SUCCESS, m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors));
70 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
71 EXPECT_EQ(OH_NN_OPERATION_FORBIDDEN, ret);
72 }
73
74 /**
75 * @tc.name: hswish_build_003
76 * @tc.desc: Verify that the build function returns a failed message with invalided input.
77 * @tc.type: FUNC
78 */
79 HWTEST_F(HswishBuilderTest, hswish_build_003, TestSize.Level0)
80 {
81 m_inputs = {0, 1};
82 m_outputs = {2};
83
84 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
85 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
86
87 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
88 EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
89 }
90
91 /**
92 * @tc.name: hswish_build_004
93 * @tc.desc: Verify that the build function returns a failed message with invalided output.
94 * @tc.type: FUNC
95 */
96 HWTEST_F(HswishBuilderTest, hswish_build_004, TestSize.Level0)
97 {
98 std::vector<uint32_t> m_outputs = {1, 2};
99
100 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
101 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
102
103 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
104 EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
105 }
106
107 /**
108 * @tc.name: hswish_build_005
109 * @tc.desc: Verify that the build function returns a failed message with empty allTensor.
110 * @tc.type: FUNC
111 */
112 HWTEST_F(HswishBuilderTest, hswish_build_005, TestSize.Level0)
113 {
114 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputs, m_outputs, m_allTensors);
115 EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
116 }
117
118 /**
119 * @tc.name: hswish_build_006
120 * @tc.desc: Verify that the build function returns a failed message without output tensor.
121 * @tc.type: FUNC
122 */
123 HWTEST_F(HswishBuilderTest, hswish_build_006, TestSize.Level0)
124 {
125 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
126
127 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputs, m_allTensors);
128 EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
129 }
130
131 /**
132 * @tc.name: hswish_build_007
133 * @tc.desc: Verify that the build function returns a failed message with a virtual parameter.
134 * @tc.type: FUNC
135 */
136 HWTEST_F(HswishBuilderTest, hswish_build_007, TestSize.Level0)
137 {
138 std::vector<uint32_t> m_params = {2};
139 std::vector<int32_t> paramDim = {};
140
141 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
142 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
143 std::shared_ptr<NNTensor> paramTensor;
144 paramTensor = TransToNNTensor(OH_NN_INT32, paramDim, nullptr, OH_NN_TENSOR);
145 m_allTensors.emplace_back(paramTensor);
146
147 OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
148 EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
149 }
150
151 /**
152 * @tc.name: hswish_getprimitive_001
153 * @tc.desc: Verify that the getPrimitive function returns a successful message
154 * @tc.type: FUNC
155 */
156 HWTEST_F(HswishBuilderTest, hswish_getprimitive_001, TestSize.Level0)
157 {
158 SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
159 SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
160
161 EXPECT_EQ(OH_NN_SUCCESS, m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors));
162 LiteGraphPrimitvePtr primitive = m_hswish.GetPrimitive();
163 LiteGraphPrimitvePtr expectPrimitive(nullptr, DestroyLiteGraphPrimitive);
164 EXPECT_NE(expectPrimitive, primitive);
165
166 mindspore::lite::ActivationType activationType = mindspore::lite::ACTIVATION_TYPE_HSWISH;
167 auto returnValue = mindspore::lite::MindIR_Activation_GetActivationType(primitive.get());
168 EXPECT_EQ(returnValue, activationType);
169 }
170
171 /**
172 * @tc.name: hswish_getprimitive_001
173 * @tc.desc: Verify that the getPrimitive function returns a failed message without build.
174 * @tc.type: FUNC
175 */
176 HWTEST_F(HswishBuilderTest, hswish_getprimitive_002, TestSize.Level0)
177 {
178 HswishBuilder hswish;
179 LiteGraphPrimitvePtr primitive = m_hswish.GetPrimitive();
180 LiteGraphPrimitvePtr expectPrimitive(nullptr, DestroyLiteGraphPrimitive);
181 EXPECT_EQ(expectPrimitive, primitive);
182 }
183 }
184 }
185 }