/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/status.h"

namespace mindspore {
namespace parallel {
#define MIN_SLICE_NUM 1

using Dimensions = Shape;
using Strategys = std::vector<Dimensions>;
class Strategy;
using StrategyPtr = std::shared_ptr<Strategy>;

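// Describes how an operator's inputs are sharded: inputs_ holds one Dimensions entry per
// input tensor, giving the number of slices taken along each tensor axis, and stage_ is
// the pipeline stage the strategy belongs to.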
class Strategy {
 public:
  Strategy(int64_t stage, Strategys inputs)
      : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}

  Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
    inputs_ = another_stra.GetInputDim();
    internal_size_ = another_stra.GetInternalSize();
    if (internal_size_ != 0) {
      internal_stragies_ = another_stra.GetInternalStrategies();
    } else {
      internal_stragies_ = {};
    }
  }

  ~Strategy() = default;
  size_t GetInputNumber() const { return inputs_.size(); }
  Strategys GetInputDim() const { return inputs_; }
  int64_t GetInputStage() const { return stage_; }
  void ExpandInputDimFromOneToTwo() {
    if (inputs_.size() == 1) {
      inputs_.push_back(inputs_[0]);
    }
  }
  void ResetInputs(const Strategys &input) { inputs_ = input; }
  std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
  size_t GetInternalSize() const { return internal_size_; }

  // TODO(Xiaoda): needs fixing to take 'CoverStrategy' into account
  bool IsEqual(const StrategyPtr &another_stra) {
    if (another_stra == nullptr) {
      return false;
    }
    if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) {
      return false;
    }
    return true;
  }

  // Include 'another_stra' in this strategy as an internal strategy.
  void CoverStrategy(const StrategyPtr &another_stra) {
    internal_stragies_.push_back(another_stra);
    internal_size_++;
  }

 private:
  const int64_t stage_;

  // Each Dimensions entry must have the same size as the rank of the corresponding input tensor.
  Strategys inputs_;
  size_t internal_size_ = 0;
  std::vector<StrategyPtr> internal_stragies_;
};

inline StrategyPtr NewStrategy(const int64_t stage, const Strategys &inputs) {
  return std::make_shared<Strategy>(stage, inputs);
}
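
// Usage sketch (illustrative only; the shapes below are hypothetical and assume Shape
// behaves like std::vector<int64_t>): split a single 2-D input 4-way along its first
// axis in pipeline stage 0, then compare strategies before merging one into another.
//
//   Strategys dims = {{4, 1}};
//   StrategyPtr stra = NewStrategy(0, dims);
//   StrategyPtr other = NewStrategy(0, {{2, 2}});
//   if (!stra->IsEqual(other)) {
//     stra->CoverStrategy(other);  // 'other' is recorded as an internal strategy
//   }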
} // namespace parallel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_