1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/map.h"
18 #include <algorithm>
19 #include <utility>
20 #include "utils/ms_utils.h"
21 #include "frontend/parallel/status.h"
22 #include "frontend/parallel/tensor_layout/shape_util.h"
23 #include "utils/convert_utils.h"
24 #include "utils/log_adapter.h"
25
26 namespace mindspore {
27 namespace parallel {
Init(const Shape & array)28 Status Map::Init(const Shape &array) {
29 Status status = Array::Init(array);
30 if (status != Status::SUCCESS) {
31 return Status::FAILED;
32 }
33 if (!IsValidMap()) {
34 MS_LOG(ERROR) << "invalid map " << this->ToString();
35 return Status::FAILED;
36 }
37 return Status::SUCCESS;
38 }
39
IsValidMap()40 bool Map::IsValidMap() {
41 if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
42 return false;
43 }
44 // check that all none -1 value in array_ is different
45 Shape sorted_array = array_;
46 std::sort(sorted_array.begin(), sorted_array.end());
47 int64_t value = MAP_NONE;
48 for (auto &element : sorted_array) {
49 if (element == MAP_NONE) {
50 continue;
51 }
52 if (element == value) {
53 return false;
54 }
55 value = element;
56 }
57 return true;
58 }
59
GetMaxItem() const60 int64_t Map::GetMaxItem() const {
61 if (!array_.empty()) {
62 return *std::max_element(array_.begin(), array_.end());
63 } else {
64 return MAP_NONE;
65 }
66 }
67
GetIndexByValue(int64_t value) const68 int64_t Map::GetIndexByValue(int64_t value) const {
69 auto iter = find(array_.begin(), array_.end(), value);
70 if (iter != array_.end()) {
71 return static_cast<int64_t>(std::distance(array_.begin(), iter));
72 } else {
73 return MAP_NONE;
74 }
75 }
76
77 /*
78 * expand.size() should be equal to array_.size()
79 */
ExpandMapByNone(const Arrangement & expand_num_list) const80 std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
81 if (expand_num_list.GetDimSize() != GetDimSize()) {
82 return nullptr;
83 }
84 Shape new_shape;
85 for (size_t i = 0; i != GetDimSize(); i++) {
86 if (GetDimByIdx(i) == MAP_NONE) {
87 for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
88 new_shape.push_back(MAP_NONE);
89 }
90 } else {
91 new_shape.push_back(GetDimByIdx(i));
92 int64_t j = 1;
93 while (j < expand_num_list.GetDimByIdx(i)) {
94 new_shape.push_back(MAP_NONE);
95 j++;
96 }
97 }
98 }
99 auto map_new = std::make_shared<Map>();
100 (void)map_new->Init(new_shape);
101 return map_new;
102 }
103
104 /*
105 * expand.size() should be equal to array_.size()
106 */
ExpandMapByDecreaseNumber(const Arrangement & expand_num_list) const107 std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
108 if (GetMaxItem() >= static_cast<int64_t>(expand_num_list.GetDimSize())) {
109 return nullptr;
110 }
111 Shape new_shape;
112 for (size_t i = 0; i < GetDimSize(); i++) {
113 if (GetDimByIdx(i) == MAP_NONE) {
114 new_shape.push_back(MAP_NONE);
115 } else {
116 int64_t start_map =
117 expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<size_t>(GetDimByIdx(i))];
118 for (int64_t k = expand_num_list.GetDimByReverseIdx(static_cast<size_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
119 new_shape.push_back(k + start_map);
120 }
121 }
122 }
123 auto map_new = std::make_shared<Map>();
124 (void)map_new->Init(new_shape);
125 return map_new;
126 }
127
ReMapVector(const std::vector<Arrangement> & input_vector) const128 std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
129 if (GetMaxItem() >= static_cast<int64_t>(input_vector.size())) {
130 return nullptr;
131 }
132 std::vector<Arrangement> out;
133 Arrangement empty_arrangement;
134 for (size_t i = 0; i < GetDimSize(); i++) {
135 if (GetDimByIdx(i) == MAP_NONE) {
136 out.push_back(empty_arrangement);
137 } else {
138 out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]);
139 }
140 }
141 return std::make_shared<std::vector<Arrangement>>(out);
142 }
143
CheckNoneByIdxList(std::vector<size_t> idx_list) const144 bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
145 for (auto &value : idx_list) {
146 if (GetDimByIdx(value) != MAP_NONE) {
147 return false;
148 }
149 }
150 return true;
151 }
152
SqueezeMapByIdxList(std::vector<size_t> idx_list) const153 Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const {
154 Shape out_shape;
155 for (size_t i = 0; i < GetDimSize(); i++) {
156 auto it = std::find(idx_list.begin(), idx_list.end(), i);
157 if (it == idx_list.end()) {
158 out_shape.push_back(GetDimByIdx(i));
159 }
160 }
161 if (out_shape.empty()) {
162 MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
163 out_shape.push_back(MAP_NONE);
164 }
165 Map out;
166 (void)out.Init(out_shape);
167 return out;
168 }
169 } // namespace parallel
170 } // namespace mindspore
171