1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <string> 17 #include <list> 18 #include <vector> 19 #include "common/common_test.h" 20 #include "frontend/parallel/strategy.h" 21 #include "frontend/parallel/ops_info/virtual_dataset_info.h" 22 #include "frontend/parallel/device_manager.h" 23 #include "frontend/parallel/step_parallel.h" 24 25 namespace mindspore { 26 namespace parallel { 27 28 class VirtualDatasetInfo; 29 using VirtualDatasetInfoPtr = std::shared_ptr<VirtualDatasetInfo>; 30 VirtualDatasetInfoPtr virtual_dataset; 31 32 class TestVirtualDatasetInfo : public UT::Common { 33 public: 34 TestVirtualDatasetInfo() {} 35 void SetUp(); 36 void TearDown() {} 37 }; 38 39 void TestVirtualDatasetInfo::SetUp() { 40 RankList dev_list; 41 42 for (int32_t i = 0; i < 130; i++) { 43 dev_list.push_back(i); 44 } 45 46 RankList stage_map; 47 stage_map.push_back(16); 48 stage_map.push_back(114); 49 50 int32_t local_dev = 0; 51 52 // create a new g_device_manager 53 g_device_manager = std::make_shared<DeviceManager>(); 54 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); 55 56 std::unordered_map<std::string, ValuePtr> attr; 57 58 Shapes inputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}}; 59 Shapes outputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}}; 60 61 virtual_dataset = std::make_shared<VirtualDatasetInfo>("virtual_dataset_info", inputs_shape, outputs_shape, attr); 62 } 63 64 TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) { 65 Strategys inputs = {{16, 1}, {16, 1}, {16, 1}}; 66 StrategyPtr strategy = NewStrategy(0, inputs); 67 virtual_dataset->Init(strategy); 68 Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape(); 69 70 Shape expect = {16, 1}; 71 ASSERT_EQ(dev_matrix_shape, expect); 72 } 73 74 TEST_F(TestVirtualDatasetInfo, GetForwardOp1) { 75 Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; 76 StrategyPtr strategy = NewStrategy(0, inputs); 77 78 virtual_dataset->Init(strategy); 79 OperatorVector forward_op = virtual_dataset->forward_op(); 80 size_t size = forward_op.size(); 81 82 ASSERT_EQ(size, 0); 83 } 84 85 TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) { 86 Strategys inputs = {{8, 1}, {8, 1}, {8, 1}}; 87 StrategyPtr strategy = NewStrategy(0, inputs); 88 89 virtual_dataset->Init(strategy); 90 MirrorOps mirror_ops = virtual_dataset->mirror_ops(); 91 92 size_t size = mirror_ops.size(); 93 // no broadcast 94 ASSERT_EQ(size, 0); 95 // ASSERT_EQ(size, 3); 96 } 97 98 } // namespace parallel 99 } // namespace mindspore 100