1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <string> 17 #include <list> 18 #include <vector> 19 #include "common/common_test.h" 20 #include "frontend/parallel/strategy.h" 21 22 namespace mindspore { 23 namespace parallel { 24 25 class TestStrategy : public UT::Common { 26 public: 27 TestStrategy() {} 28 29 void SetUp() {} 30 void TearDown() {} 31 }; 32 33 TEST_F(TestStrategy, GetInputNumber) { 34 int32_t number = 2; 35 int32_t stage = 1; 36 Dimensions dimension1 = {2, 4}; 37 Dimensions dimension2 = {2, 2}; 38 Strategys inputs = {dimension1, dimension2}; 39 40 Strategy strategy(stage, inputs); 41 int32_t number_test = strategy.GetInputNumber(); 42 ASSERT_EQ(number, number_test); 43 } 44 45 TEST_F(TestStrategy, GetInputStage) { 46 int32_t stage = 1; 47 Dimensions dimension1 = {2, 4}; 48 Dimensions dimension2 = {2, 2}; 49 Strategys inputs = {dimension1, dimension2}; 50 51 Strategy strategy(stage, inputs); 52 int32_t stage_test = strategy.GetInputStage(); 53 ASSERT_EQ(stage, stage_test); 54 } 55 56 TEST_F(TestStrategy, GetInputDim) { 57 int32_t stage = 1; 58 Dimensions dimension1 = {2, 4}; 59 Dimensions dimension2 = {2, 2}; 60 Strategys inputs = {dimension1, dimension2}; 61 62 Strategy strategy(stage, inputs); 63 Strategys inputs_test = strategy.GetInputDim(); 64 ASSERT_EQ(inputs, inputs_test); 65 } 66 67 TEST_F(TestStrategy, IsEqual) { 68 int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0; 69 Dimensions dimension1 = {8, 1}; 70 Dimensions dimension2 = {1, 8}; 71 Strategys inputs1 = {dimension1}; 72 Strategys inputs2 = {dimension1}; 73 Strategys inputs3 = {dimension2}; 74 Strategys inputs4 = {dimension1, dimension2}; 75 76 StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1); 77 StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2); 78 StrategyPtr stra3 = std::make_shared<Strategy>(stage3, inputs3); 79 StrategyPtr stra4 = std::make_shared<Strategy>(stage4, inputs4); 80 81 ASSERT_EQ(stra1->IsEqual(stra2), true); 82 ASSERT_EQ(stra1->IsEqual(stra3), false); 83 ASSERT_EQ(stra1->IsEqual(stra4), false); 84 } 85 } // namespace parallel 86 } // namespace mindspore 87