• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/parallel/tensor_layout/prime_generator.h"
17 
18 namespace mindspore::parallel {
// 100,001. NOTE(review): not referenced by the code in this file as shown; judging by the name it is
// presumably meant as the sieve bound (arr_size) for get_prime_table — confirm before relying on it.
// Spelled as an integer literal so no implicit double->int conversion is involved.
constexpr int MAX_PRIME_RANGE = 100'001;
get_prime_table(Shape * prime_arr,const size_t arr_size)20 void get_prime_table(Shape *prime_arr, const size_t arr_size) {
21   std::vector<bool> is_composite_num(arr_size, false);
22   for (size_t i = 2; i <= arr_size; i++) {
23     if (!is_composite_num[i]) {
24       prime_arr->emplace_back(i);
25     }
26     for (size_t j = 0;; j++) {
27       if (j >= prime_arr->size() || LongToSize(prime_arr->at(j)) * i > arr_size) {
28         break;
29       }
30       is_composite_num[LongToSize(prime_arr->at(j)) * i] = true;
31       if (i % LongToSize(prime_arr->at(j)) == 0) {
32         break;
33       }
34     }
35   }
36   prime_arr->resize(prime_arr->size());
37 }
38 
PrimeGenerator()39 PrimeGenerator::PrimeGenerator() {
40   this->prime_table_ = {3,   5,   7,   11,  13,  17,  19,  23,  29,  31,  37,  41,  43,  47,  53,  59,
41                         61,  67,  71,  73,  79,  83,  89,  97,  101, 103, 107, 109, 113, 127, 131, 137,
42                         139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227,
43                         229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293};
44 }
45 
GetCoprimeNum(const Shape & tensor_shape)46 int64_t PrimeGenerator::GetCoprimeNum(const Shape &tensor_shape) {
47   const int64_t unknown_val = -1;
48   if (tensor_shape.empty()) {
49     return this->prime_table_[0];
50   }
51   std::set<int64_t> input_flag;
52   for (int64_t i : tensor_shape) {
53     input_flag.insert(i);
54   }
55   const int64_t two = 2;
56   for (int64_t prime_num : this->prime_table_) {
57     if (prime_num == two) {
58       // skip prime 2.
59       continue;
60     }
61     if (input_flag.find(prime_num) != input_flag.end()) {
62       continue;
63     }
64     bool is_coprime = std::all_of(tensor_shape.begin(), tensor_shape.end(),
65                                   [prime_num](int64_t v) { return std::gcd(prime_num, v) == 1; });
66     if (is_coprime) {
67       return prime_num;
68     }
69   }
70   MS_LOG(ERROR) << "Cannot find a coprime number for shape " << tensor_shape;
71   return unknown_val;
72 }
73 }  // namespace mindspore::parallel
74