1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/parallel/tensor_layout/prime_generator.h"
17
18 namespace mindspore::parallel {
// Upper limit of the prime search range (100,001). It is not referenced in this part of the file.
// Presumably it is passed as `arr_size` to get_prime_table; verify against callers.
// Written as an integer literal: the former `1e5 + 1` was a double implicitly converted to int.
const int MAX_PRIME_RANGE = 100001;
get_prime_table(Shape * prime_arr,const size_t arr_size)20 void get_prime_table(Shape *prime_arr, const size_t arr_size) {
21 std::vector<bool> is_composite_num(arr_size, false);
22 for (size_t i = 2; i <= arr_size; i++) {
23 if (!is_composite_num[i]) {
24 prime_arr->emplace_back(i);
25 }
26 for (size_t j = 0;; j++) {
27 if (j >= prime_arr->size() || LongToSize(prime_arr->at(j)) * i > arr_size) {
28 break;
29 }
30 is_composite_num[LongToSize(prime_arr->at(j)) * i] = true;
31 if (i % LongToSize(prime_arr->at(j)) == 0) {
32 break;
33 }
34 }
35 }
36 prime_arr->resize(prime_arr->size());
37 }
38
PrimeGenerator()39 PrimeGenerator::PrimeGenerator() {
40 this->prime_table_ = {3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
41 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
42 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227,
43 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293};
44 }
45
GetCoprimeNum(const Shape & tensor_shape)46 int64_t PrimeGenerator::GetCoprimeNum(const Shape &tensor_shape) {
47 const int64_t unknown_val = -1;
48 if (tensor_shape.empty()) {
49 return this->prime_table_[0];
50 }
51 std::set<int64_t> input_flag;
52 for (int64_t i : tensor_shape) {
53 input_flag.insert(i);
54 }
55 const int64_t two = 2;
56 for (int64_t prime_num : this->prime_table_) {
57 if (prime_num == two) {
58 // skip prime 2.
59 continue;
60 }
61 if (input_flag.find(prime_num) != input_flag.end()) {
62 continue;
63 }
64 bool is_coprime = std::all_of(tensor_shape.begin(), tensor_shape.end(),
65 [prime_num](int64_t v) { return std::gcd(prime_num, v) == 1; });
66 if (is_coprime) {
67 return prime_num;
68 }
69 }
70 MS_LOG(ERROR) << "Cannot find a coprime number for shape " << tensor_shape;
71 return unknown_val;
72 }
73 } // namespace mindspore::parallel
74