1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/expanders/utils.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 #include <unordered_set>
22
23 #include "backend/common/graph_kernel/model/lite_graph.h"
24 #include "backend/common/graph_kernel/model/node.h"
25 #include "ir/value.h"
26 #include "utils/check_convert_utils.h"
27
28 namespace mindspore::graphkernel::expanders {
// Small integer offsets used by the helpers below, both to index the trailing
// dimensions of a shape (e.g. `dims - OFFSET1` is the last dimension) and to
// shift axis indices.
constexpr int OFFSET1{1};
constexpr int OFFSET2{2};
constexpr int OFFSET3{3};
constexpr int OFFSET4{4};
// Expands this op into a lite graph.
//
// Stores the given inputs/outputs/attrs/processor on the object, runs every
// registered validator, then calls Init() and CheckInputs(). After that one
// graph parameter is created per input, Expand() is called on those parameters
// and its result becomes the graph outputs, which CheckOutputs() verifies.
//
// Returns the built graph, or nullptr as soon as a validator, CheckInputs() or
// CheckOutputs() reports failure.
inner::LiteGraphPtr OpDesc::Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs,
                                const std::string &processor) {
  inputs_info_ = inputs;
  outputs_info_ = outputs;
  attrs_ = attrs;
  processor_ = processor;

  // Stop at the first validator that rejects this op.
  for (const auto &validator : validators_) {
    if (!validator->Check(*this)) {
      return nullptr;
    }
  }

  Init();
  if (!CheckInputs()) {
    return nullptr;
  }

  for (const auto &input_info : inputs) {
    (void)gb.Parameter(input_info);
  }
  auto expanded_outputs = Expand(gb.Get()->inputs());
  gb.SetOutputs(expanded_outputs);

  if (!CheckOutputs()) {
    return nullptr;
  }
  return gb.Get();
}
57
CheckOutputs()58 bool OpDesc::CheckOutputs() {
59 // check the output shape/type/format are same as the original basic node's output.
60 const NodePtrList &outputs = gb.Get()->GetOutputs();
61 if (outputs.size() != this->outputs_info_.size()) {
62 MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs "
63 << outputs_info_.size();
64 return false;
65 }
66 for (size_t i = 0; i < outputs.size(); i++) {
67 if (outputs[i]->shape != outputs_info_[i].shape) {
68 std::ostringstream oss;
69 oss << "Op " << this->name_ << "'s output shape [";
70 for (auto s : outputs[i]->shape) {
71 oss << s << ",";
72 }
73 oss << "] is wrong. expect: [";
74 for (auto s : outputs_info_[i].shape) {
75 oss << s << ",";
76 }
77 oss << "]";
78 MS_LOG(INFO) << oss.str();
79 return false;
80 }
81 if (outputs[i]->type != outputs_info_[i].type) {
82 MS_LOG(INFO) << "Op " << this->name_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
83 << outputs_info_[i].type << "]";
84 return false;
85 }
86 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
87 bool format_check_condition =
88 (outputs[i]->format != kOpFormat_DEFAULT && outputs_info_[i].format != kOpFormat_DEFAULT) &&
89 outputs[i]->format != outputs_info_[i].format;
90 #else
91 bool format_check_condition = outputs[i]->format != outputs_info_[i].format;
92 if ((outputs[i]->format == kOpFormat_DEFAULT && outputs_info_[i].format == kOpFormat_NCHW) ||
93 (outputs[i]->format == kOpFormat_NCHW && outputs_info_[i].format == kOpFormat_DEFAULT)) {
94 format_check_condition = false;
95 }
96 #endif
97 if (format_check_condition) {
98 MS_LOG(INFO) << "Op " << this->name_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
99 << outputs_info_[i].format << "]";
100 return false;
101 }
102 }
103 return true;
104 }
105
GetAxisList(const ValuePtr & value)106 std::vector<int64_t> GetAxisList(const ValuePtr &value) {
107 std::vector<int64_t> result;
108 auto get_int_value = [](const ValuePtr &value) -> int64_t {
109 return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
110 };
111 if (value->isa<ValueSequence>()) {
112 const auto &vals = value->cast<ValueSequencePtr>()->value();
113 (void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
114 } else if (value->isa<tensor::Tensor>()) {
115 result = CheckAndConvertUtils::CheckTensorIntValue("axes value", value, "GetAxisList");
116 } else {
117 result.push_back(get_int_value(value));
118 }
119 return result;
120 }
121
InferShapeFromFractalnz(const std::vector<int64_t> & fractal)122 std::vector<int64_t> InferShapeFromFractalnz(const std::vector<int64_t> &fractal) {
123 std::vector<int64_t> shape;
124 size_t dims = fractal.size();
125 size_t batch = dims - OFFSET4;
126 for (size_t i = 0; i < batch; i++) {
127 shape.push_back(fractal[i]);
128 }
129 shape.push_back(fractal[dims - OFFSET3] * fractal[dims - OFFSET2]);
130 shape.push_back(fractal[dims - OFFSET4] * fractal[dims - OFFSET1]);
131 return shape;
132 }
133
// Returns a copy of `shape` in which every dimension whose index is listed in
// `axis` is replaced by 1.
//
// Entries of `axis` that are negative or not a valid index of `shape` have no
// effect; duplicates are harmless.
std::vector<int64_t> GetReducedOriShape(const std::vector<int64_t> &shape, const std::vector<int64_t> &axis) {
  std::vector<int64_t> reduced_ori_shape(shape);
  const size_t rank = shape.size();
  for (const int64_t reduce_axis : axis) {
    // Only indices inside [0, rank) identify a dimension of `shape`.
    if (reduce_axis >= 0 && static_cast<size_t>(reduce_axis) < rank) {
      reduced_ori_shape[static_cast<size_t>(reduce_axis)] = 1;
    }
  }
  return reduced_ori_shape;
}
146
ToFracZAxis(const std::vector<int64_t> & ori_shape,const std::vector<int64_t> & ori_axis)147 std::vector<int64_t> ToFracZAxis(const std::vector<int64_t> &ori_shape, const std::vector<int64_t> &ori_axis) {
148 std::vector<int64_t> frac_z_axis = ori_axis;
149 int64_t shape_len = SizeToLong(ori_shape.size());
150 if (shape_len == 0) {
151 MS_LOG(EXCEPTION) << "In ToFracZAxis, divisor is zero";
152 }
153 for (size_t i = 0; i < frac_z_axis.size(); i++) {
154 int64_t axis_index = (frac_z_axis[i] + shape_len) % shape_len;
155 if (axis_index == shape_len - OFFSET1) {
156 frac_z_axis[i] = axis_index - OFFSET1;
157 frac_z_axis.push_back(axis_index + OFFSET2);
158 } else if (axis_index == shape_len - OFFSET2) {
159 frac_z_axis[i] = axis_index + OFFSET1;
160 frac_z_axis.push_back(axis_index + OFFSET2);
161 } else {
162 frac_z_axis[i] = axis_index;
163 }
164 }
165 return frac_z_axis;
166 }
167 } // namespace mindspore::graphkernel::expanders
168