• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <string>
17 #include <list>
18 #include <vector>
19 #include "common/common_test.h"
20 #include "frontend/parallel/strategy.h"
21 #include "frontend/parallel/ops_info/virtual_dataset_info.h"
22 #include "frontend/parallel/device_manager.h"
23 #include "frontend/parallel/step_parallel.h"
24 
25 namespace mindspore {
26 namespace parallel {
27 
// Forward declaration and shared-pointer alias for the operator info under test.
class VirtualDatasetInfo;
using VirtualDatasetInfoPtr = std::shared_ptr<VirtualDatasetInfo>;

// Operator instance shared by the tests in this file; it is re-created in
// TestVirtualDatasetInfo::SetUp() before every test case.
// Declared static (internal linkage) so it cannot clash at link time with a
// same-named global defined by another test translation unit.
static VirtualDatasetInfoPtr virtual_dataset;
31 
32 class TestVirtualDatasetInfo : public UT::Common {
33  public:
34   TestVirtualDatasetInfo() {}
35   void SetUp();
36   void TearDown() {}
37 };
38 
39 void TestVirtualDatasetInfo::SetUp() {
40   RankList dev_list;
41 
42   for (int32_t i = 0; i < 130; i++) {
43     dev_list.push_back(i);
44   }
45 
46   RankList stage_map;
47   stage_map.push_back(16);
48   stage_map.push_back(114);
49 
50   int32_t local_dev = 0;
51 
52   // create a new g_device_manager
53   g_device_manager = std::make_shared<DeviceManager>();
54   g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
55 
56   std::unordered_map<std::string, ValuePtr> attr;
57 
58   Shapes inputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}};
59   Shapes outputs_shape = {{128, 32}, {1280, 320}, {12800, 3200}};
60 
61   virtual_dataset = std::make_shared<VirtualDatasetInfo>("virtual_dataset_info", inputs_shape, outputs_shape, attr);
62 }
63 
64 TEST_F(TestVirtualDatasetInfo, InferDevMatrixShape1) {
65   Strategys inputs = {{16, 1}, {16, 1}, {16, 1}};
66   StrategyPtr strategy = NewStrategy(0, inputs);
67   virtual_dataset->Init(strategy);
68   Shape dev_matrix_shape = virtual_dataset->dev_matrix_shape();
69 
70   Shape expect = {16, 1};
71   ASSERT_EQ(dev_matrix_shape, expect);
72 }
73 
74 TEST_F(TestVirtualDatasetInfo, GetForwardOp1) {
75   Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
76   StrategyPtr strategy = NewStrategy(0, inputs);
77 
78   virtual_dataset->Init(strategy);
79   OperatorVector forward_op = virtual_dataset->forward_op();
80   size_t size = forward_op.size();
81 
82   ASSERT_EQ(size, 0);
83 }
84 
85 TEST_F(TestVirtualDatasetInfo, GetMirrorOPs1) {
86   Strategys inputs = {{8, 1}, {8, 1}, {8, 1}};
87   StrategyPtr strategy = NewStrategy(0, inputs);
88 
89   virtual_dataset->Init(strategy);
90   MirrorOps mirror_ops = virtual_dataset->mirror_ops();
91 
92   size_t size = mirror_ops.size();
93   // no broadcast
94   ASSERT_EQ(size, 0);
95   // ASSERT_EQ(size, 3);
96 }
97 
98 }  // namespace parallel
99 }  // namespace mindspore
100