/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/tensor_layout/map.h"
#include <algorithm>
#include <utility>
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "include/common/utils/convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
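// Initializes the map from 'array' and rejects it when it does not form a valid map
// (see IsValidMap below).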
Status Map::Init(const Shape &array) {
  Status status = Array::Init(array);
  if (status != Status::SUCCESS) {
    return Status::FAILED;
  }
  if (!IsValidMap()) {
    MS_LOG(ERROR) << "invalid map " << this->ToString();
    return Status::FAILED;
  }
  return Status::SUCCESS;
}

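// A map is valid when every element is either MAP_NONE (-1) or a non-negative index,
// and all non-MAP_NONE elements are distinct.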
bool Map::IsValidMap() {
  if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
    return false;
  }
  // check that all non-MAP_NONE values in array_ are distinct
  Shape sorted_array = array_;
  std::sort(sorted_array.begin(), sorted_array.end());
  int64_t value = MAP_NONE;
  for (auto &element : sorted_array) {
    if (element == MAP_NONE) {
      continue;
    }
    if (element == value) {
      return false;
    }
    value = element;
  }
  return true;
}

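// Returns the largest element of the map, or MAP_NONE when the map is empty.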
int64_t Map::GetMaxItem() const {
  if (!array_.empty()) {
    return *std::max_element(array_.begin(), array_.end());
  } else {
    return MAP_NONE;
  }
}

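// Returns the index of the first occurrence of 'value', or MAP_NONE when the value is
// not present, e.g. for a map [2, -1, 0], GetIndexByValue(0) returns 2.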
int64_t Map::GetIndexByValue(int64_t value) const {
  auto iter = std::find(array_.begin(), array_.end(), value);
  if (iter != array_.end()) {
    return static_cast<int64_t>(std::distance(array_.begin(), iter));
  } else {
    return MAP_NONE;
  }
}

/*
 * expand_num_list.GetDimSize() must be equal to GetDimSize(); otherwise nullptr is returned.
 */
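/*
 * Each dimension i is expanded into expand_num_list.GetDimByIdx(i) dimensions: the original
 * value is kept in the first expanded dimension and the rest are filled with MAP_NONE.
 * Illustrative example: a map [1, MAP_NONE] expanded by [2, 2] becomes
 * [1, MAP_NONE, MAP_NONE, MAP_NONE].
 */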
std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
  if (expand_num_list.GetDimSize() != GetDimSize()) {
    return nullptr;
  }
  Shape new_shape;
  for (size_t i = 0; i != GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
        new_shape.push_back(MAP_NONE);
      }
    } else {
      new_shape.push_back(GetDimByIdx(i));
      int64_t j = 1;
      while (j < expand_num_list.GetDimByIdx(i)) {
        new_shape.push_back(MAP_NONE);
        j++;
      }
    }
  }
  auto map_new = std::make_shared<Map>();
  (void)map_new->Init(new_shape);
  return map_new;
}

/*
 * GetMaxItem() must be less than expand_num_list.GetDimSize(); otherwise nullptr is returned.
 */
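/*
 * Each non-MAP_NONE value v is replaced by the run of consecutive indices that device
 * dimension v occupies after expand_num_list splits every device dimension into several,
 * pushed in decreasing order; MAP_NONE dimensions are kept as a single MAP_NONE.
 */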
std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
  if (GetMaxItem() >= static_cast<int64_t>(expand_num_list.GetDimSize())) {
    return nullptr;
  }
  Shape new_shape;
  for (size_t i = 0; i < GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      new_shape.push_back(MAP_NONE);
    } else {
      int64_t start_map =
        expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<size_t>(GetDimByIdx(i))];
      for (int64_t k = expand_num_list.GetDimByReverseIdx(static_cast<size_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
        new_shape.push_back(k + start_map);
      }
    }
  }
  auto map_new = std::make_shared<Map>();
  (void)map_new->Init(new_shape);
  return map_new;
}

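// For each dimension, picks the Arrangement that the map value selects from input_vector,
// counting from the end (value 0 selects the last element); MAP_NONE dimensions get an
// empty Arrangement. Returns nullptr when some value is out of range for input_vector.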
std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
  if (GetMaxItem() >= static_cast<int64_t>(input_vector.size())) {
    return nullptr;
  }
  std::vector<Arrangement> out;
  Arrangement empty_arrangement;
  for (size_t i = 0; i < GetDimSize(); i++) {
    if (GetDimByIdx(i) == MAP_NONE) {
      out.push_back(empty_arrangement);
    } else {
      out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]);
    }
  }
  return std::make_shared<std::vector<Arrangement>>(out);
}

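// Returns true when every index in idx_list maps to MAP_NONE.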
bool Map::CheckNoneByIdxList(const std::vector<size_t> &idx_list) const {
  return std::all_of(idx_list.begin(), idx_list.end(), [this](size_t value) { return GetDimByIdx(value) == MAP_NONE; });
}

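// Removes the dimensions whose indices appear in idx_list; if every dimension is removed,
// the result degenerates to a single MAP_NONE dimension.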
Map Map::SqueezeMapByIdxList(const std::vector<size_t> &idx_list) const {
  Shape out_shape;
  for (size_t i = 0; i < GetDimSize(); i++) {
    auto it = std::find(idx_list.begin(), idx_list.end(), i);
    if (it == idx_list.end()) {
      out_shape.push_back(GetDimByIdx(i));
    }
  }
  if (out_shape.empty()) {
    MS_LOG(ERROR) << "out_shape size is 0; this should not happen in the current situation";
    out_shape.push_back(MAP_NONE);
  }
  Map out;
  (void)out.Init(out_shape);
  return out;
}
}  // namespace parallel
}  // namespace mindspore