1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <string>
17 #include <list>
18 #include <vector>
19 #include "common/common_test.h"
20 #include "frontend/parallel/strategy.h"
21 #include "frontend/parallel/ops_info/virtual_dataset_info.h"
22 #include "frontend/parallel/device_manager.h"
23 #include "frontend/parallel/step_parallel.h"
24
25 namespace mindspore {
26 namespace parallel {
27
// Forward declaration kept so the alias below is self-contained; the full
// definition comes from frontend/parallel/ops_info/virtual_dataset_info.h.
class VirtualDatasetInfo;
using VirtualDatasetInfoPtr = std::shared_ptr<VirtualDatasetInfo>;

// Operator under test. TestVirtualDatasetInfo::SetUp() assigns a fresh
// instance before every TEST_F in this file. `static` gives it internal
// linkage so it cannot collide with a same-named global defined in another
// test translation unit linked into the same binary.
static VirtualDatasetInfoPtr virtual_dataset;
31
32 class TestVirtualDatasetInfo : public UT::Common {
33 public:
TestVirtualDatasetInfo()34 TestVirtualDatasetInfo() {}
35 void SetUp();
TearDown()36 void TearDown() {}
37 };
38
SetUp()39 void TestVirtualDatasetInfo::SetUp() {
40 RankList dev_list;
41
42 for (int32_t i = 0; i < 130; i++) {
43 dev_list.push_back(i);
44 }
45
46 RankList stage_map;
47 stage_map.push_back(16);
48 stage_map.push_back(114);
49
50 int32_t local_dev = 0;
51
52 // create a new g_device_manager
53 g_device_manager = std::make_shared<DeviceManager>();
54 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
55
56 std::unordered_map<std::string, ValuePtr> attr;
57
58 Shapes inputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}};
59 Shapes outputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}};
60
61 virtual_dataset = std::make_shared<VirtualDatasetInfo>("virtual_dataset_info", inputs_shape, outputs_shape, attr);
62 }
63
// Splitting every input 16 ways along dimension 0 on stage 0 is expected to
// produce a {16, 1} device matrix.
TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
  Strategys inputs = {{16, 1}, {16, 1}, {16, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Check Init explicitly so a failure is reported as an Init failure rather
  // than as a confusing dev-matrix mismatch.
  ASSERT_EQ(virtual_dataset->Init(strategy), SUCCESS);

  Shape expect = {16, 1};
  ASSERT_EQ(virtual_dataset->dev_matrix_shape(), expect);
}
73
// An 8-way split of dimension 0 on stage 0 is expected to require no forward
// communication operators.
TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
  Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Guard against a vacuous pass: if Init failed, forward_op could be empty
  // for the wrong reason and the check below would still succeed.
  ASSERT_EQ(virtual_dataset->Init(strategy), SUCCESS);

  const OperatorVector &forward_op = virtual_dataset->forward_op();
  // Compare against an unsigned literal to avoid a signed/unsigned comparison
  // inside gtest's ASSERT_EQ template.
  ASSERT_EQ(forward_op.size(), 0U);
}
84
// An 8-way split of dimension 0 on stage 0 is expected to produce no mirror
// operators, because no broadcast is inserted for the virtual dataset.
TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) {
  Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Guard against a vacuous pass: if Init failed, mirror_ops could be empty
  // for the wrong reason and the check below would still succeed.
  ASSERT_EQ(virtual_dataset->Init(strategy), SUCCESS);

  const MirrorOps &mirror_ops = virtual_dataset->mirror_ops();
  // No broadcast, hence no mirror ops. The unsigned literal avoids a
  // signed/unsigned comparison inside gtest's ASSERT_EQ template.
  ASSERT_EQ(mirror_ops.size(), 0U);
}
97
98 } // namespace parallel
99 } // namespace mindspore
100