1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ 19 20 #include <cstdint> 21 #include <string> 22 #include <vector> 23 #include <memory> 24 #include <utility> 25 26 #include "frontend/parallel/status.h" 27 #include "include/common/utils/convert_utils.h" 28 29 namespace mindspore { 30 namespace parallel { 31 using RankList = std::vector<int64_t>; 32 using Shape = std::vector<int64_t>; 33 using Shapes = std::vector<Shape>; 34 35 class ShapeBase { 36 public: ShapeBase(bool is_list)37 explicit ShapeBase(bool is_list) { is_list_ = is_list; } 38 virtual ~ShapeBase() = default; is_list()39 bool is_list() const { return is_list_; } 40 virtual bool empty() const = 0; 41 virtual int64_t GetBatchValue() = 0; 42 virtual size_t size() = 0; 43 virtual std::vector<int64_t> GetValue() = 0; 44 virtual std::shared_ptr<ShapeBase> GetElement(int64_t idx) = 0; 45 virtual std::vector<std::vector<int64_t>> GetAllElements() = 0; 46 virtual void set_shape(const std::shared_ptr<ShapeBase> shape) = 0; ToString()47 std::string ToString() { 48 std::ostringstream oss; 49 ConvertShapeToStr(&oss); 50 return oss.str(); 51 } ConvertShapeToStr(std::ostringstream * oss)52 virtual void ConvertShapeToStr(std::ostringstream *oss) { MS_LOG(WARNING) << "Please override this func"; } 53 54 private: 55 bool is_list_; 56 }; 57 58 using ShapeBasePtr = std::shared_ptr<ShapeBase>; 59 using NewShapes = std::vector<ShapeBasePtr>; 60 using NewTensorMaps = std::vector<ShapeBasePtr>; 61 62 class ShapeValue : public ShapeBase { 63 public: ShapeValue(std::vector<int64_t> s)64 explicit ShapeValue(std::vector<int64_t> s) : ShapeBase(false), _s(std::move(s)) {} 65 ~ShapeValue() override = default; empty()66 bool empty() const override { return _s.empty(); } GetBatchValue()67 int64_t GetBatchValue() override { return _s[0]; } size()68 size_t size() override { return _s.size(); } GetValue()69 std::vector<int64_t> GetValue() override { return _s; } GetElement(int64_t idx)70 ShapeBasePtr GetElement(int64_t idx) override { 71 MS_LOG(WARNING) << "Can not get element from ShapeValue, please use GetValue"; 72 return std::make_shared<ShapeValue>(_s); 73 } GetAllElements()74 std::vector<std::vector<int64_t>> GetAllElements() override { 75 std::vector<std::vector<int64_t>> all_elements = {_s}; 76 return all_elements; 77 } set_shape(const std::shared_ptr<ShapeBase> shape)78 void set_shape(const std::shared_ptr<ShapeBase> shape) override { 79 if (!shape->is_list()) { 80 _s = shape->GetValue(); 81 } else { 82 MS_LOG(EXCEPTION) << "Can not set list shape to value shape"; 83 } 84 } 85 86 private: ConvertShapeToStr(std::ostringstream * oss)87 void ConvertShapeToStr(std::ostringstream *oss) override { 88 *oss << "["; 89 for (size_t i = 0; i < _s.size(); ++i) { 90 *oss << _s[i]; 91 if (i != _s.size() - 1) { 92 *oss << ", "; 93 } 94 } 95 *oss << "]"; 96 } 97 std::vector<int64_t> _s; 98 }; 99 100 class ShapeList : public ShapeBase { 101 public: ShapeList(std::vector<ShapeBasePtr> s_list)102 explicit ShapeList(std::vector<ShapeBasePtr> s_list) : ShapeBase(true), _s_list(std::move(s_list)) {} 103 ~ShapeList() override = default; empty()104 bool empty() const override { return _s_list.empty(); } GetBatchValue()105 int64_t GetBatchValue() override { 106 MS_LOG(EXCEPTION) << "Can not get batch value from ShapeList"; 107 return 0; 108 } size()109 size_t size() override { return _s_list.size(); } GetValue()110 std::vector<int64_t> GetValue() override { 111 MS_LOG(EXCEPTION) << "Can not get value from ShapeList, please use GetElement"; 112 return {}; 113 } GetElement(int64_t idx)114 ShapeBasePtr GetElement(int64_t idx) override { 115 if (idx < 0 || LongToSize(idx) >= _s_list.size()) { 116 MS_LOG(EXCEPTION) << "Index " << idx << " out of range " << _s_list.size(); 117 } 118 return _s_list[LongToSize(idx)]; 119 } GetAllElements()120 std::vector<std::vector<int64_t>> GetAllElements() override { 121 std::vector<std::vector<int64_t>> all_elements; 122 for (auto &s : _s_list) { 123 auto elements = s->GetAllElements(); 124 all_elements.insert(all_elements.end(), elements.begin(), elements.end()); 125 } 126 return all_elements; 127 } set_shape(const std::shared_ptr<ShapeBase> shape)128 void set_shape(const std::shared_ptr<ShapeBase> shape) override { 129 if (shape->is_list()) { 130 std::vector<ShapeBasePtr> new_list; 131 for (size_t i = 0; i < shape->size(); ++i) { 132 new_list.push_back(shape->GetElement(SizeToLong(i))); 133 } 134 _s_list = new_list; 135 } else { 136 MS_LOG(EXCEPTION) << "Can not set value shape to list shape"; 137 } 138 } 139 140 private: ConvertShapeToStr(std::ostringstream * oss)141 void ConvertShapeToStr(std::ostringstream *oss) override { 142 *oss << "["; 143 for (size_t i = 0; i < _s_list.size(); ++i) { 144 _s_list[i]->ConvertShapeToStr(oss); 145 if (i != _s_list.size() - 1) { 146 *oss << ", "; 147 } 148 } 149 *oss << "]"; 150 } 151 std::vector<ShapeBasePtr> _s_list; 152 }; 153 154 class DeviceMatrix { 155 public: 156 DeviceMatrix(int64_t rank, RankList dev_list, Shape dev_shape); 157 DeviceMatrix() = default; 158 ~DeviceMatrix() = default; group_list()159 std::vector<RankList> group_list() const { return group_list_; } 160 Status CreateGroupList(); 161 Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); 162 Status GetDevicesAlongDim(const uint64_t &dim, RankList *devices); 163 Status GetDevicesAlongMultiDim(const std::vector<int64_t> &dims, RankList *devices); 164 165 private: 166 int64_t rank_ = -1; 167 RankList dev_list_; 168 // From low dim to high dim. eg: [D0 D1 D2 D3] 169 Shape dev_shape_; 170 std::vector<RankList> group_list_; 171 }; 172 173 std::string ShapeToString(const Shape &shape); 174 std::string ShapesToString(const Shapes &shapes); 175 std::string ListToString(const RankList &list); 176 } // namespace parallel 177 } // namespace mindspore 178 179 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MATRIX_H_ 180