• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "backend/common/graph_kernel/expander/base/utils.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <vector>

#include "utils/check_convert_utils.h"
23 
24 namespace mindspore::graphkernel::expander {
FormatDefaultNchwSame(const std::string & f0,const std::string & f1)25 bool FormatDefaultNchwSame(const std::string &f0, const std::string &f1) {
26   return f0 == f1 || (f0 == kOpFormat_DEFAULT && f1 == kOpFormat_NCHW) ||
27          (f0 == kOpFormat_NCHW && f1 == kOpFormat_DEFAULT);
28 }
29 
CheckAllFormatsSame(const DefaultIrBuilder * ib,const std::function<bool (const std::string &,const std::string &)> & check)30 bool CheckAllFormatsSame(const DefaultIrBuilder *ib,
31                          const std::function<bool(const std::string &, const std::string &)> &check) {
32   auto inputs = ib->inputs();
33   if (inputs.empty()) {
34     return true;
35   }
36   const auto &fmt_0 = inputs[0]->GetFormat();
37   for (size_t i = 1; i < inputs.size(); i++) {
38     MS_LOG(INFO) << i << "th format: " << inputs[i]->GetFormat();
39     bool is_same = check == nullptr ? (inputs[i]->GetFormat() == fmt_0) : check(inputs[i]->GetFormat(), fmt_0);
40     if (!is_same) {
41       MS_LOG(INFO) << "The " << i << "th format: " << inputs[i]->GetFormat() << " is not same as 0th format: " << fmt_0
42                    << " of op " << ib->name();
43       return false;
44     }
45   }
46   return true;
47 }
48 
CheckAttrs(const DefaultIrBuilder * ib,const std::vector<std::string> & attrs)49 bool CheckAttrs(const DefaultIrBuilder *ib, const std::vector<std::string> &attrs) {
50   for (auto &a : attrs) {
51     if (ib->attrs().count(a) == 0) {
52       MS_LOG(INFO) << "attr " << a << " dose not exist. Op: " << ib->name();
53       return false;
54     }
55   }
56   return true;
57 }
58 
CheckSupportFormat(const DefaultIrBuilder * ib,const std::vector<std::vector<std::string>> & formats_list)59 bool CheckSupportFormat(const DefaultIrBuilder *ib, const std::vector<std::vector<std::string>> &formats_list) {
60   for (auto &formats : formats_list) {
61     if (formats.size() != ib->inputs().size()) {
62       continue;
63     }
64     bool match = true;
65     for (size_t i = 0; i < formats.size(); i++) {
66       if (ib->inputs()[i]->GetFormat() != formats[i]) {
67         match = false;
68         break;
69       }
70     }
71     if (match) {
72       return true;
73     }
74   }
75   MS_LOG(INFO) << "unsupported format for op " << ib->name();
76   return false;
77 }
78 
// Infers the output shape of ExpandDims by inserting a size-1 dimension for each entry of `axis`.
// Axes are applied one after another. Each axis is validated against the rank of the shape built
// so far, so valid values lie in [-rank - 1, rank]. Out-of-range values raise via MS_LOG(EXCEPTION).
// A negative axis counts from the end of the expanded shape, i.e. position axis + rank + 1.
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis) {
  ShapeVector result(shape);
  for (const int64_t ax : axis) {
    const auto cur_rank = static_cast<int64_t>(result.size());
    const int64_t lower = -cur_rank - 1;
    if (ax < lower || ax > cur_rank) {
      MS_LOG(EXCEPTION) << "ExpandDims attr 'axis' value " << ax << " is out of range of [" << lower << ", "
                        << cur_rank << "]";
    }
    const int64_t pos = (ax < 0) ? (ax + cur_rank + 1) : ax;
    (void)result.insert(result.cbegin() + pos, 1LL);
  }
  return result;
}
95 
GetAxisList(const ValuePtr & value)96 std::vector<int64_t> GetAxisList(const ValuePtr &value) {
97   std::vector<int64_t> result;
98   auto get_int_value = [](const ValuePtr &value) -> int64_t {
99     return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
100   };
101   if (value->isa<ValueSequence>()) {
102     const auto &vals = value->cast<ValueSequencePtr>()->value();
103     (void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
104   } else if (value->isa<tensor::Tensor>()) {
105     result = CheckAndConvertUtils::CheckTensorIntValue("axes value", value, "GetAxisList");
106   } else {
107     result.push_back(get_int_value(value));
108   }
109   return result;
110 }
111 }  // namespace mindspore::graphkernel::expander
112