1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/map.h"
18 #include <algorithm>
19 #include <utility>
20 #include "frontend/parallel/status.h"
21 #include "frontend/parallel/tensor_layout/shape_util.h"
22 #include "include/common/utils/convert_utils.h"
23 #include "utils/log_adapter.h"
24
25 namespace mindspore {
26 namespace parallel {
Init(const Shape & array)27 Status Map::Init(const Shape &array) {
28 Status status = Array::Init(array);
29 if (status != Status::SUCCESS) {
30 return Status::FAILED;
31 }
32 if (!IsValidMap()) {
33 MS_LOG(ERROR) << "invalid map " << this->ToString();
34 return Status::FAILED;
35 }
36 return Status::SUCCESS;
37 }
38
IsValidMap()39 bool Map::IsValidMap() {
40 if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
41 return false;
42 }
43 // check that all none -1 value in array_ is different
44 Shape sorted_array = array_;
45 std::sort(sorted_array.begin(), sorted_array.end());
46 int64_t value = MAP_NONE;
47 for (auto &element : sorted_array) {
48 if (element == MAP_NONE) {
49 continue;
50 }
51 if (element == value) {
52 return false;
53 }
54 value = element;
55 }
56 return true;
57 }
58
GetMaxItem() const59 int64_t Map::GetMaxItem() const {
60 if (!array_.empty()) {
61 return *std::max_element(array_.begin(), array_.end());
62 } else {
63 return MAP_NONE;
64 }
65 }
66
GetIndexByValue(int64_t value) const67 int64_t Map::GetIndexByValue(int64_t value) const {
68 auto iter = find(array_.begin(), array_.end(), value);
69 if (iter != array_.end()) {
70 return static_cast<int64_t>(std::distance(array_.begin(), iter));
71 } else {
72 return MAP_NONE;
73 }
74 }
75
76 /*
77 * expand.size() should be equal to array_.size()
78 */
ExpandMapByNone(const Arrangement & expand_num_list) const79 std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
80 if (expand_num_list.GetDimSize() != GetDimSize()) {
81 return nullptr;
82 }
83 Shape new_shape;
84 for (size_t i = 0; i != GetDimSize(); i++) {
85 if (GetDimByIdx(i) == MAP_NONE) {
86 for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
87 new_shape.push_back(MAP_NONE);
88 }
89 } else {
90 new_shape.push_back(GetDimByIdx(i));
91 int64_t j = 1;
92 while (j < expand_num_list.GetDimByIdx(i)) {
93 new_shape.push_back(MAP_NONE);
94 j++;
95 }
96 }
97 }
98 auto map_new = std::make_shared<Map>();
99 (void)map_new->Init(new_shape);
100 return map_new;
101 }
102
103 /*
104 * expand.size() should be equal to array_.size()
105 */
ExpandMapByDecreaseNumber(const Arrangement & expand_num_list) const106 std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
107 if (GetMaxItem() >= static_cast<int64_t>(expand_num_list.GetDimSize())) {
108 return nullptr;
109 }
110 Shape new_shape;
111 for (size_t i = 0; i < GetDimSize(); i++) {
112 if (GetDimByIdx(i) == MAP_NONE) {
113 new_shape.push_back(MAP_NONE);
114 } else {
115 int64_t start_map =
116 expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<size_t>(GetDimByIdx(i))];
117 for (int64_t k = expand_num_list.GetDimByReverseIdx(static_cast<size_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
118 new_shape.push_back(k + start_map);
119 }
120 }
121 }
122 auto map_new = std::make_shared<Map>();
123 (void)map_new->Init(new_shape);
124 return map_new;
125 }
126
ReMapVector(const std::vector<Arrangement> & input_vector) const127 std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
128 if (GetMaxItem() >= static_cast<int64_t>(input_vector.size())) {
129 return nullptr;
130 }
131 std::vector<Arrangement> out;
132 Arrangement empty_arrangement;
133 for (size_t i = 0; i < GetDimSize(); i++) {
134 if (GetDimByIdx(i) == MAP_NONE) {
135 out.push_back(empty_arrangement);
136 } else {
137 out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]);
138 }
139 }
140 return std::make_shared<std::vector<Arrangement>>(out);
141 }
142
CheckNoneByIdxList(const std::vector<size_t> & idx_list) const143 bool Map::CheckNoneByIdxList(const std::vector<size_t> &idx_list) const {
144 return std::all_of(idx_list.begin(), idx_list.end(), [this](size_t value) { return GetDimByIdx(value) == MAP_NONE; });
145 }
146
SqueezeMapByIdxList(const std::vector<size_t> & idx_list) const147 Map Map::SqueezeMapByIdxList(const std::vector<size_t> &idx_list) const {
148 Shape out_shape;
149 for (size_t i = 0; i < GetDimSize(); i++) {
150 auto it = std::find(idx_list.begin(), idx_list.end(), i);
151 if (it == idx_list.end()) {
152 out_shape.push_back(GetDimByIdx(i));
153 }
154 }
155 if (out_shape.empty()) {
156 MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
157 out_shape.push_back(MAP_NONE);
158 }
159 Map out;
160 (void)out.Init(out_shape);
161 return out;
162 }
163 } // namespace parallel
164 } // namespace mindspore
165