/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/map.h"
#include <algorithm>
#include <utility>
#include "utils/ms_utils.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
Status Map::Init(const Shape &array) {
  Status status = Array::Init(array);
  if (status != Status::SUCCESS) {
    return Status::FAILED;
  }
  if (!IsValidMap()) {
    MS_LOG(ERROR) << "invalid map " << this->ToString();
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

bool Map::IsValidMap() {
  if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
    return false;
  }
  // check that all values in array_ other than MAP_NONE are distinct
  Shape sorted_array = array_;
  std::sort(sorted_array.begin(), sorted_array.end());
  int64_t value = MAP_NONE;
  for (auto &element : sorted_array) {
    if (element == MAP_NONE) {
      continue;
    }
    if (element == value) {
      return false;
    }
    value = element;
  }
  return true;
}
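/*
 * Illustrative note (hypothetical values, not taken from the original source):
 * {2, MAP_NONE, 0, MAP_NONE} is a valid map, while {2, 0, 2} (duplicate non-MAP_NONE value)
 * and {-2, 0} (negative value other than MAP_NONE) are both rejected by IsValidMap().
 */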

int64_t Map::GetMaxItem() const {
  if (!array_.empty()) {
    return *std::max_element(array_.begin(), array_.end());
  } else {
    return MAP_NONE;
  }
}

int64_t Map::GetIndexByValue(int64_t value) const {
  auto iter = find(array_.begin(), array_.end(), value);
  if (iter != array_.end()) {
    return static_cast<int64_t>(std::distance(array_.begin(), iter));
  } else {
    return MAP_NONE;
  }
}
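/*
 * Illustrative example (hypothetical values): for array_ = {2, MAP_NONE, 0},
 * GetIndexByValue(0) returns 2, and GetIndexByValue(5) returns MAP_NONE.
 */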

/*
 * expand_num_list.GetDimSize() should be equal to array_.size()
 */
std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
  if (expand_num_list.GetDimSize() != GetDimSize()) {
    return nullptr;
  }
  Shape new_shape;
  for (size_t i = 0; i != GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
        new_shape.push_back(MAP_NONE);
      }
    } else {
      new_shape.push_back(GetDimByIdx(i));
      int64_t j = 1;
      while (j < expand_num_list.GetDimByIdx(i)) {
        new_shape.push_back(MAP_NONE);
        j++;
      }
    }
  }
  auto map_new = std::make_shared<Map>();
  (void)map_new->Init(new_shape);
  return map_new;
}
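/*
 * Illustrative example (hypothetical values): for array_ = {1, MAP_NONE, 0} and
 * expand_num_list = {2, 1, 3}, ExpandMapByNone returns {1, MAP_NONE, MAP_NONE, 0, MAP_NONE, MAP_NONE}:
 * each dimension keeps its mapped value in the first expanded slot and the remaining
 * expanded slots are filled with MAP_NONE.
 */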

/*
 * the maximum value in array_ should be smaller than expand_num_list.GetDimSize()
 */
std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
  if (GetMaxItem() >= static_cast<int64_t>(expand_num_list.GetDimSize())) {
    return nullptr;
  }
  Shape new_shape;
  for (size_t i = 0; i < GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      new_shape.push_back(MAP_NONE);
    } else {
      int64_t start_map =
        expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<size_t>(GetDimByIdx(i))];
      for (int64_t k = expand_num_list.GetDimByReverseIdx(static_cast<size_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
        new_shape.push_back(k + start_map);
      }
    }
  }
  auto map_new = std::make_shared<Map>();
  (void)map_new->Init(new_shape);
  return map_new;
}
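/*
 * Illustrative sketch (hypothetical values; assumes ComputeReverseAccumulateSumInReverseOrder()
 * yields, for each reverse index, the sum of the expand counts to its right): with
 * array_ = {1, MAP_NONE, 0} and expand_num_list = {2, 3}, map value 0 expands to {2, 1, 0} and
 * map value 1 expands to {4, 3}, giving {4, 3, MAP_NONE, 2, 1, 0}.
 */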

std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
  if (GetMaxItem() >= static_cast<int64_t>(input_vector.size())) {
    return nullptr;
  }
  std::vector<Arrangement> out;
  Arrangement empty_arrangement;
  for (size_t i = 0; i < GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      out.push_back(empty_arrangement);
    } else {
      out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]);
    }
  }
  return std::make_shared<std::vector<Arrangement>>(out);
}
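/*
 * Illustrative example (hypothetical values): map values index input_vector from the back, so for
 * array_ = {2, MAP_NONE, 0} and input_vector = {a, b, c}, ReMapVector returns {a, <empty>, c}.
 */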

bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
  for (auto &value : idx_list) {
    if (GetDimByIdx(value) != MAP_NONE) {
      return false;
    }
  }
  return true;
}
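/*
 * Illustrative example (hypothetical values): for array_ = {2, MAP_NONE, 0, MAP_NONE},
 * CheckNoneByIdxList({1, 3}) returns true and CheckNoneByIdxList({0, 1}) returns false.
 */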

Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const {
  Shape out_shape;
  for (size_t i = 0; i < GetDimSize(); i++) {
    auto it = std::find(idx_list.begin(), idx_list.end(), i);
    if (it == idx_list.end()) {
      out_shape.push_back(GetDimByIdx(i));
    }
  }
  if (out_shape.empty()) {
    MS_LOG(ERROR) << "out_shape is empty, which is not expected in the current situation";
    out_shape.push_back(MAP_NONE);
  }
  Map out;
  (void)out.Init(out_shape);
  return out;
}
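/*
 * Illustrative example (hypothetical values): for array_ = {2, MAP_NONE, 0, MAP_NONE},
 * SqueezeMapByIdxList({1, 3}) returns a map over {2, 0}; squeezing every index falls back to {MAP_NONE}.
 */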
}  // namespace parallel
}  // namespace mindspore