1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <string> 18 #include <list> 19 #include <vector> 20 #include "common/common_test.h" 21 #include "frontend/parallel/strategy.h" 22 #include "frontend/parallel/ops_info/activation_info.h" 23 #include "frontend/parallel/device_manager.h" 24 25 namespace mindspore { 26 namespace parallel { 27 28 class Activation; 29 class Softmax; 30 using ActivationPtr = std::shared_ptr<ActivationInfo>; 31 using SoftmaxPtr = std::shared_ptr<Softmax>; 32 ActivationPtr act_ptr_; 33 SoftmaxPtr soft_ptr_; 34 35 class TestActivation : public UT::Common { 36 public: 37 TestActivation() {} 38 void SetUp(); 39 void TearDown() {} 40 }; 41 42 void TestActivation::SetUp() { 43 RankList dev_list; 44 45 for (int32_t i = 0; i < 1050; i++) { 46 dev_list.push_back(i); 47 } 48 49 RankList stage_map; 50 stage_map.push_back(1024); 51 stage_map.push_back(26); 52 53 int32_t local_dev = 0; 54 55 // create a new g_device_manager 56 g_device_manager = std::make_shared<DeviceManager>(); 57 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); 58 59 ValuePtr relu = MakeValue(std::string("relu")); 60 std::unordered_map<std::string, ValuePtr> relu_attr = {{"activation_type", relu}}; 61 ValuePtr sm = MakeValue(std::string("softmax")); 62 ValuePtr axix = MakeValue(std::int64_t(2)); 63 std::unordered_map<std::string, ValuePtr> softmax_attr = {{"activation_type", sm}, {"axis", axix}}; 64 65 Shapes relu_inputs_shape = {{2, 4, 8, 16}}; 66 Shapes relu_outputs_shape = {{2, 4, 8, 16}}; 67 Shapes sm_inputs_shape = {{8, 8, 8, 16}}; 68 Shapes sm_outputs_shape = {{8, 8, 8, 16}}; 69 70 act_ptr_ = std::make_shared<ActivationInfo>("relu_info", relu_inputs_shape, relu_outputs_shape, relu_attr); 71 soft_ptr_ = std::make_shared<Softmax>("softmax_info", sm_inputs_shape, sm_outputs_shape, softmax_attr); 72 } 73 74 TEST_F(TestActivation, test_activation_strategies) { 75 ASSERT_EQ(act_ptr_->GenerateStrategies(0), Status::SUCCESS); 76 std::vector<std::shared_ptr<StrategyWithCost>> sc = act_ptr_->GetStrategyCost(); 77 for (const auto& swc : sc) { 78 ASSERT_NE(swc, nullptr); 79 ASSERT_GT(swc->cost_list.size(), 0); 80 StrategyPtr sp = swc->strategy_ptr; 81 ASSERT_NE(sp, nullptr); 82 Cost cost = *(swc->cost_list[0]); 83 84 act_ptr_->InitForCostModel(sp); 85 std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info(); 86 std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info(); 87 ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), 88 cost.computation_cost_); 89 ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), 90 cost.communication_cost_); 91 } 92 } 93 94 TEST_F(TestActivation, test_softmax_strategies) { 95 ASSERT_EQ(soft_ptr_->GenerateStrategies(0), Status::SUCCESS); 96 std::vector<std::shared_ptr<StrategyWithCost>> sc = soft_ptr_->GetStrategyCost(); 97 for (const auto& swc : sc) { 98 ASSERT_NE(swc, nullptr); 99 ASSERT_GT(swc->cost_list.size(), 0); 100 StrategyPtr sp = swc->strategy_ptr; 101 ASSERT_NE(sp, nullptr); 102 Cost cost = *(swc->cost_list[0]); 103 104 Strategys stra = sp->GetInputDim(); 105 ASSERT_GT(stra.size(), 0); 106 Dimensions input0_stra = stra[0]; 107 ASSERT_GT(input0_stra.size(), 2); 108 ASSERT_EQ(input0_stra[2], 1); 109 soft_ptr_->InitForCostModel(sp); 110 std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info(); 111 std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info(); 112 ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), 113 cost.computation_cost_); 114 ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), 115 cost.communication_cost_); 116 } 117 } 118 119 } // namespace parallel 120 } // namespace mindspore 121