/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>
#include <memory>
#include <utility>

#include "frontend/parallel/status.h"
#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace parallel {
using RankList = std::vector<int64_t>;
using Shape = std::vector<int64_t>;
using Shapes = std::vector<Shape>;

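// ShapeBase is the abstract interface shared by flat shapes (ShapeValue) and nested
// shapes (ShapeList). It records whether a node is a list and exposes element access,
// flattening via GetAllElements, and string conversion via ToString.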
class ShapeBase {
 public:
  explicit ShapeBase(bool is_list) { is_list_ = is_list; }
  virtual ~ShapeBase() = default;
  bool is_list() const { return is_list_; }
  virtual bool empty() const = 0;
  virtual int64_t GetBatchValue() = 0;
  virtual size_t size() = 0;
  virtual std::vector<int64_t> GetValue() = 0;
  virtual std::shared_ptr<ShapeBase> GetElement(int64_t idx) = 0;
  virtual std::vector<std::vector<int64_t>> GetAllElements() = 0;
  virtual void set_shape(const std::shared_ptr<ShapeBase> shape) = 0;
  std::string ToString() {
    std::ostringstream oss;
    ConvertShapeToStr(&oss);
    return oss.str();
  }
  virtual void ConvertShapeToStr(std::ostringstream *oss) { MS_LOG(WARNING) << "Please override this func"; }

 private:
  bool is_list_;
};

using ShapeBasePtr = std::shared_ptr<ShapeBase>;
using NewShapes = std::vector<ShapeBasePtr>;
using NewTensorMaps = std::vector<ShapeBasePtr>;

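// ShapeValue wraps a single flat shape, i.e. one vector of dimension sizes such as [2, 4, 8].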
class ShapeValue : public ShapeBase {
 public:
  explicit ShapeValue(std::vector<int64_t> s) : ShapeBase(false), _s(std::move(s)) {}
  ~ShapeValue() override = default;
  bool empty() const override { return _s.empty(); }
  int64_t GetBatchValue() override { return _s[0]; }
  size_t size() override { return _s.size(); }
  std::vector<int64_t> GetValue() override { return _s; }
  ShapeBasePtr GetElement(int64_t idx) override {
    MS_LOG(WARNING) << "Can not get element from ShapeValue, please use GetValue";
    return std::make_shared<ShapeValue>(_s);
  }
  std::vector<std::vector<int64_t>> GetAllElements() override {
    std::vector<std::vector<int64_t>> all_elements = {_s};
    return all_elements;
  }
  void set_shape(const std::shared_ptr<ShapeBase> shape) override {
    if (!shape->is_list()) {
      _s = shape->GetValue();
    } else {
      MS_LOG(EXCEPTION) << "Can not set list shape to value shape";
    }
  }

 private:
  void ConvertShapeToStr(std::ostringstream *oss) override {
    *oss << "[";
    for (size_t i = 0; i < _s.size(); ++i) {
      *oss << _s[i];
      if (i != _s.size() - 1) {
        *oss << ", ";
      }
    }
    *oss << "]";
  }
  std::vector<int64_t> _s;
};

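// ShapeList wraps a nested shape: an ordered list of ShapeBase elements, each of which
// may itself be a ShapeValue or another ShapeList. GetAllElements flattens the nesting
// into a list of plain shapes.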
class ShapeList : public ShapeBase {
 public:
  explicit ShapeList(std::vector<ShapeBasePtr> s_list) : ShapeBase(true), _s_list(std::move(s_list)) {}
  ~ShapeList() override = default;
  bool empty() const override { return _s_list.empty(); }
  int64_t GetBatchValue() override {
    MS_LOG(EXCEPTION) << "Can not get batch value from ShapeList";
    return 0;
  }
  size_t size() override { return _s_list.size(); }
  std::vector<int64_t> GetValue() override {
    MS_LOG(EXCEPTION) << "Can not get value from ShapeList, please use GetElement";
    return {};
  }
  ShapeBasePtr GetElement(int64_t idx) override {
    if (idx < 0 || LongToSize(idx) >= _s_list.size()) {
      MS_LOG(EXCEPTION) << "Index " << idx << " out of range " << _s_list.size();
    }
    return _s_list[LongToSize(idx)];
  }
  std::vector<std::vector<int64_t>> GetAllElements() override {
    std::vector<std::vector<int64_t>> all_elements;
    for (auto &s : _s_list) {
      auto elements = s->GetAllElements();
      all_elements.insert(all_elements.end(), elements.begin(), elements.end());
    }
    return all_elements;
  }
  void set_shape(const std::shared_ptr<ShapeBase> shape) override {
    if (shape->is_list()) {
      std::vector<ShapeBasePtr> new_list;
      for (size_t i = 0; i < shape->size(); ++i) {
        new_list.push_back(shape->GetElement(SizeToLong(i)));
      }
      _s_list = new_list;
    } else {
      MS_LOG(EXCEPTION) << "Can not set value shape to list shape";
    }
  }

 private:
  void ConvertShapeToStr(std::ostringstream *oss) override {
    *oss << "[";
    for (size_t i = 0; i < _s_list.size(); ++i) {
      _s_list[i]->ConvertShapeToStr(oss);
      if (i != _s_list.size() - 1) {
        *oss << ", ";
      }
    }
    *oss << "]";
  }
  std::vector<ShapeBasePtr> _s_list;
};

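// DeviceMatrix models the device arrangement seen by a given rank: it maps a rank list
// onto a multi-dimensional device shape and derives communication groups along one or
// several dimensions of that shape.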
class DeviceMatrix {
 public:
  DeviceMatrix(int64_t rank, RankList dev_list, Shape dev_shape);
  DeviceMatrix() = default;
  ~DeviceMatrix() = default;
  std::vector<RankList> group_list() const { return group_list_; }
  Status CreateGroupList();
  Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list);
  Status GetDevicesAlongDim(const uint64_t &dim, RankList *devices);
  Status GetDevicesAlongMultiDim(const std::vector<int64_t> &dims, RankList *devices);

 private:
  int64_t rank_ = -1;
  RankList dev_list_;
  // From low dim to high dim. eg: [D0 D1 D2 D3]
  Shape dev_shape_;
  std::vector<RankList> group_list_;
};
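// A minimal usage sketch. The rank list and device shape below are hypothetical; in
// practice they come from the parallel context at runtime, and the exact grouping
// semantics are defined in the corresponding .cc file.
//
//   RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
//   Shape dev_shape = {2, 4};                       // a 2 x 4 device matrix
//   DeviceMatrix dev_matrix(0, dev_list, dev_shape);
//   RankList group;
//   (void)dev_matrix.GetDevicesAlongDim(1, &group);  // ranks grouped with rank 0 along dim 1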

std::string ShapeToString(const Shape &shape);
std::string ShapesToString(const Shapes &shapes);
std::string ListToString(const RankList &list);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_