• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "frameworks/native/ops/hswish_builder.h"
17 
18 #include "ops_test.h"
19 
20 using namespace testing;
21 using namespace testing::ext;
22 using namespace OHOS::NeuralNetworkRuntime::Ops;
23 
24 namespace OHOS {
25 namespace NeuralNetworkRuntime {
26 namespace UnitTest {
27 class HswishBuilderTest : public OpsTest {
28 public:
29     void SetUp() override;
30     void TearDown() override;
31 
32 protected:
33     HswishBuilder m_hswish;
34     std::vector<uint32_t> m_inputs {0};
35     std::vector<uint32_t> m_outputs {1};
36     std::vector<uint32_t> m_params {};
37     std::vector<int32_t> m_inputDim {1, 5, 1, 1};
38     std::vector<int32_t> m_outputDim {1, 5, 1, 1};
39 };
40 
SetUp()41 void HswishBuilderTest::SetUp() {}
42 
TearDown()43 void HswishBuilderTest::TearDown() {}
44 
45 /**
46  * @tc.name: hswish_build_001
47  * @tc.desc: Verify that the build function returns a successful message.
48  * @tc.type: FUNC
49  */
50 HWTEST_F(HswishBuilderTest, hswish_build_001, TestSize.Level0)
51 {
52     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
53     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
54 
55     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
56     EXPECT_EQ(OH_NN_SUCCESS, ret);
57 }
58 
59 /**
60  * @tc.name: hswish_build_002
61  * @tc.desc: Verify that the build function returns a failed message with true m_isBuild.
62  * @tc.type: FUNC
63  */
64 HWTEST_F(HswishBuilderTest, hswish_build_002, TestSize.Level0)
65 {
66     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
67     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
68 
69     EXPECT_EQ(OH_NN_SUCCESS, m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors));
70     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
71     EXPECT_EQ(OH_NN_OPERATION_FORBIDDEN, ret);
72 }
73 
74 /**
75  * @tc.name: hswish_build_003
76  * @tc.desc: Verify that the build function returns a failed message with invalided input.
77  * @tc.type: FUNC
78  */
79 HWTEST_F(HswishBuilderTest, hswish_build_003, TestSize.Level0)
80 {
81     m_inputs = {0, 1};
82     m_outputs = {2};
83 
84     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
85     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
86 
87     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
88     EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
89 }
90 
91 /**
92  * @tc.name: hswish_build_004
93  * @tc.desc: Verify that the build function returns a failed message with invalided output.
94  * @tc.type: FUNC
95  */
96 HWTEST_F(HswishBuilderTest, hswish_build_004, TestSize.Level0)
97 {
98     std::vector<uint32_t> m_outputs = {1, 2};
99 
100     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
101     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
102 
103     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
104     EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
105 }
106 
107 /**
108  * @tc.name: hswish_build_005
109  * @tc.desc: Verify that the build function returns a failed message with empty allTensor.
110  * @tc.type: FUNC
111  */
112 HWTEST_F(HswishBuilderTest, hswish_build_005, TestSize.Level0)
113 {
114     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputs, m_outputs, m_allTensors);
115     EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
116 }
117 
118 /**
119  * @tc.name: hswish_build_006
120  * @tc.desc: Verify that the build function returns a failed message without output tensor.
121  * @tc.type: FUNC
122  */
123 HWTEST_F(HswishBuilderTest, hswish_build_006, TestSize.Level0)
124 {
125     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
126 
127     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputs, m_allTensors);
128     EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
129 }
130 
131 /**
132  * @tc.name: hswish_build_007
133  * @tc.desc: Verify that the build function returns a failed message with a virtual parameter.
134  * @tc.type: FUNC
135  */
136 HWTEST_F(HswishBuilderTest, hswish_build_007, TestSize.Level0)
137 {
138     std::vector<uint32_t> m_params = {2};
139     std::vector<int32_t> paramDim = {};
140 
141     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
142     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
143     std::shared_ptr<NNTensor> paramTensor;
144     paramTensor = TransToNNTensor(OH_NN_INT32, paramDim, nullptr, OH_NN_TENSOR);
145     m_allTensors.emplace_back(paramTensor);
146 
147     OH_NN_ReturnCode ret = m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors);
148     EXPECT_EQ(OH_NN_INVALID_PARAMETER, ret);
149 }
150 
151 /**
152  * @tc.name: hswish_getprimitive_001
153  * @tc.desc: Verify that the getPrimitive function returns a successful message
154  * @tc.type: FUNC
155  */
156 HWTEST_F(HswishBuilderTest, hswish_getprimitive_001, TestSize.Level0)
157 {
158     SaveInputTensor(m_inputs, OH_NN_FLOAT32, m_inputDim, nullptr);
159     SaveOutputTensor(m_outputs, OH_NN_FLOAT32, m_outputDim, nullptr);
160 
161     EXPECT_EQ(OH_NN_SUCCESS, m_hswish.Build(m_params, m_inputsIndex, m_outputsIndex, m_allTensors));
162     LiteGraphPrimitvePtr primitive = m_hswish.GetPrimitive();
163     LiteGraphPrimitvePtr expectPrimitive(nullptr, DestroyLiteGraphPrimitive);
164     EXPECT_NE(expectPrimitive, primitive);
165 
166     mindspore::lite::ActivationType activationType = mindspore::lite::ACTIVATION_TYPE_HSWISH;
167     auto returnValue = mindspore::lite::MindIR_Activation_GetActivationType(primitive.get());
168     EXPECT_EQ(returnValue, activationType);
169 }
170 
171 /**
172  * @tc.name: hswish_getprimitive_001
173  * @tc.desc: Verify that the getPrimitive function returns a failed message without build.
174  * @tc.type: FUNC
175  */
176 HWTEST_F(HswishBuilderTest, hswish_getprimitive_002, TestSize.Level0)
177 {
178     HswishBuilder hswish;
179     LiteGraphPrimitvePtr primitive = m_hswish.GetPrimitive();
180     LiteGraphPrimitvePtr expectPrimitive(nullptr, DestroyLiteGraphPrimitive);
181     EXPECT_EQ(expectPrimitive, primitive);
182 }
183 }
184 }
185 }