1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/expander/base/utils.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21
22 #include "utils/check_convert_utils.h"
23
24 namespace mindspore::graphkernel::expander {
FormatDefaultNchwSame(const std::string & f0,const std::string & f1)25 bool FormatDefaultNchwSame(const std::string &f0, const std::string &f1) {
26 return f0 == f1 || (f0 == kOpFormat_DEFAULT && f1 == kOpFormat_NCHW) ||
27 (f0 == kOpFormat_NCHW && f1 == kOpFormat_DEFAULT);
28 }
29
CheckAllFormatsSame(const DefaultIrBuilder * ib,const std::function<bool (const std::string &,const std::string &)> & check)30 bool CheckAllFormatsSame(const DefaultIrBuilder *ib,
31 const std::function<bool(const std::string &, const std::string &)> &check) {
32 auto inputs = ib->inputs();
33 if (inputs.empty()) {
34 return true;
35 }
36 const auto &fmt_0 = inputs[0]->GetFormat();
37 for (size_t i = 1; i < inputs.size(); i++) {
38 MS_LOG(INFO) << i << "th format: " << inputs[i]->GetFormat();
39 bool is_same = check == nullptr ? (inputs[i]->GetFormat() == fmt_0) : check(inputs[i]->GetFormat(), fmt_0);
40 if (!is_same) {
41 MS_LOG(INFO) << "The " << i << "th format: " << inputs[i]->GetFormat() << " is not same as 0th format: " << fmt_0
42 << " of op " << ib->name();
43 return false;
44 }
45 }
46 return true;
47 }
48
CheckAttrs(const DefaultIrBuilder * ib,const std::vector<std::string> & attrs)49 bool CheckAttrs(const DefaultIrBuilder *ib, const std::vector<std::string> &attrs) {
50 for (auto &a : attrs) {
51 if (ib->attrs().count(a) == 0) {
52 MS_LOG(INFO) << "attr " << a << " dose not exist. Op: " << ib->name();
53 return false;
54 }
55 }
56 return true;
57 }
58
CheckSupportFormat(const DefaultIrBuilder * ib,const std::vector<std::vector<std::string>> & formats_list)59 bool CheckSupportFormat(const DefaultIrBuilder *ib, const std::vector<std::vector<std::string>> &formats_list) {
60 for (auto &formats : formats_list) {
61 if (formats.size() != ib->inputs().size()) {
62 continue;
63 }
64 bool match = true;
65 for (size_t i = 0; i < formats.size(); i++) {
66 if (ib->inputs()[i]->GetFormat() != formats[i]) {
67 match = false;
68 break;
69 }
70 }
71 if (match) {
72 return true;
73 }
74 }
75 MS_LOG(INFO) << "unsupported format for op " << ib->name();
76 return false;
77 }
78
/**
 * Infers the ExpandDims output shape by inserting a size-1 dimension for each entry of `axis`.
 *
 * Entries are applied in order. Each one is interpreted against the shape produced by the previous
 * insertions, so the valid range grows by one after every insertion. A non-negative axis `a` inserts
 * before position `a`. A negative axis maps to position `a + rank + 1`, so -1 appends at the back.
 *
 * @param shape Input shape.
 * @param axis Axes at which to insert size-1 dimensions.
 * @return The expanded shape.
 * An axis outside [-rank - 1, rank] for the current rank is reported through MS_LOG(EXCEPTION).
 */
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
  ShapeVector out_shape(shape);
  for (const int64_t ax : axis) {
    const auto cur_rank = static_cast<int64_t>(out_shape.size());
    const int64_t lower_bound = -cur_rank - 1;
    if (ax < lower_bound || ax > cur_rank) {
      MS_LOG(EXCEPTION) << "ExpandDims attr 'axis' value " << ax << " is out of range of [" << lower_bound << ", "
                        << cur_rank << "]";
    }
    const int64_t insert_pos = (ax < 0) ? (ax + cur_rank + 1) : ax;
    (void)out_shape.insert(out_shape.cbegin() + insert_pos, 1LL);
  }
  return out_shape;
}
95
GetAxisList(const ValuePtr & value)96 std::vector<int64_t> GetAxisList(const ValuePtr &value) {
97 std::vector<int64_t> result;
98 auto get_int_value = [](const ValuePtr &value) -> int64_t {
99 return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
100 };
101 if (value->isa<ValueSequence>()) {
102 const auto &vals = value->cast<ValueSequencePtr>()->value();
103 (void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
104 } else if (value->isa<tensor::Tensor>()) {
105 result = CheckAndConvertUtils::CheckTensorIntValue("axes value", value, "GetAxisList");
106 } else {
107 result.push_back(get_int_value(value));
108 }
109 return result;
110 }
111 } // namespace mindspore::graphkernel::expander
112