/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/expanders/utils.h"

#include <algorithm>
#include <string>
#include <vector>
#include <unordered_set>

#include "backend/common/graph_kernel/model/lite_graph.h"
#include "backend/common/graph_kernel/model/node.h"
#include "ir/value.h"
#include "utils/check_convert_utils.h"

namespace mindspore::graphkernel::expanders {
// Offsets for addressing the last four dims of a FRACTAL_NZ shape.
constexpr int OFFSET1 = 1;
constexpr int OFFSET2 = 2;
constexpr int OFFSET3 = 3;
constexpr int OFFSET4 = 4;
inner::LiteGraphPtr OpDesc::Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs,
                                const std::string &processor) {
  // Cache the basic node's signature so the validators, Init() and Expand()
  // can query it through the member fields.
  inputs_info_ = inputs;
  outputs_info_ = outputs;
  attrs_ = attrs;
  processor_ = processor;
  // Every registered validator must accept this op; otherwise expanding is skipped.
  const bool all_valid = std::all_of(validators_.begin(), validators_.end(),
                                     [this](const std::unique_ptr<Validator> &v) { return v->Check(*this); });
  if (!all_valid) {
    return nullptr;
  }
  Init();
  if (!CheckInputs()) {
    return nullptr;
  }
  // Mirror the original node's inputs as parameters of the lite graph.
  for (const auto &input : inputs) {
    (void)gb.Parameter(input);
  }
  gb.SetOutputs(this->Expand(gb.Get()->inputs()));
  // The expanded graph is only usable if it reproduces the original outputs.
  return CheckOutputs() ? gb.Get() : nullptr;
}
57 
CheckOutputs()58 bool OpDesc::CheckOutputs() {
59   // check the output shape/type/format are same as the original basic node's output.
60   const NodePtrList &outputs = gb.Get()->GetOutputs();
61   if (outputs.size() != this->outputs_info_.size()) {
62     MS_LOG(INFO) << "the output num was not equal to the original output num : " << outputs.size() << " vs "
63                  << outputs_info_.size();
64     return false;
65   }
66   for (size_t i = 0; i < outputs.size(); i++) {
67     if (outputs[i]->shape != outputs_info_[i].shape) {
68       std::ostringstream oss;
69       oss << "Op " << this->name_ << "'s output shape [";
70       for (auto s : outputs[i]->shape) {
71         oss << s << ",";
72       }
73       oss << "] is wrong. expect: [";
74       for (auto s : outputs_info_[i].shape) {
75         oss << s << ",";
76       }
77       oss << "]";
78       MS_LOG(INFO) << oss.str();
79       return false;
80     }
81     if (outputs[i]->type != outputs_info_[i].type) {
82       MS_LOG(INFO) << "Op " << this->name_ << "'s output type [" << outputs[i]->type << "] is wrong, expect: ["
83                    << outputs_info_[i].type << "]";
84       return false;
85     }
86 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
87     bool format_check_condition =
88       (outputs[i]->format != kOpFormat_DEFAULT && outputs_info_[i].format != kOpFormat_DEFAULT) &&
89       outputs[i]->format != outputs_info_[i].format;
90 #else
91     bool format_check_condition = outputs[i]->format != outputs_info_[i].format;
92     if ((outputs[i]->format == kOpFormat_DEFAULT && outputs_info_[i].format == kOpFormat_NCHW) ||
93         (outputs[i]->format == kOpFormat_NCHW && outputs_info_[i].format == kOpFormat_DEFAULT)) {
94       format_check_condition = false;
95     }
96 #endif
97     if (format_check_condition) {
98       MS_LOG(INFO) << "Op " << this->name_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: ["
99                    << outputs_info_[i].format << "]";
100       return false;
101     }
102   }
103   return true;
104 }
105 
GetAxisList(const ValuePtr & value)106 std::vector<int64_t> GetAxisList(const ValuePtr &value) {
107   std::vector<int64_t> result;
108   auto get_int_value = [](const ValuePtr &value) -> int64_t {
109     return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
110   };
111   if (value->isa<ValueSequence>()) {
112     const auto &vals = value->cast<ValueSequencePtr>()->value();
113     (void)std::transform(vals.begin(), vals.end(), std::back_inserter(result), get_int_value);
114   } else if (value->isa<tensor::Tensor>()) {
115     result = CheckAndConvertUtils::CheckTensorIntValue("axes value", value, "GetAxisList");
116   } else {
117     result.push_back(get_int_value(value));
118   }
119   return result;
120 }
121 
InferShapeFromFractalnz(const std::vector<int64_t> & fractal)122 std::vector<int64_t> InferShapeFromFractalnz(const std::vector<int64_t> &fractal) {
123   std::vector<int64_t> shape;
124   size_t dims = fractal.size();
125   size_t batch = dims - OFFSET4;
126   for (size_t i = 0; i < batch; i++) {
127     shape.push_back(fractal[i]);
128   }
129   shape.push_back(fractal[dims - OFFSET3] * fractal[dims - OFFSET2]);
130   shape.push_back(fractal[dims - OFFSET4] * fractal[dims - OFFSET1]);
131   return shape;
132 }
133 
GetReducedOriShape(const std::vector<int64_t> & shape,const std::vector<int64_t> & axis)134 std::vector<int64_t> GetReducedOriShape(const std::vector<int64_t> &shape, const std::vector<int64_t> &axis) {
135   std::vector<int64_t> reduced_ori_shape;
136   std::unordered_set<int64_t> axis_set(axis.begin(), axis.end());
137   for (size_t i = 0; i < shape.size(); i++) {
138     if (axis_set.count(SizeToLong(i)) > 0) {
139       reduced_ori_shape.push_back(1);
140     } else {
141       reduced_ori_shape.push_back(shape[i]);
142     }
143   }
144   return reduced_ori_shape;
145 }
146 
ToFracZAxis(const std::vector<int64_t> & ori_shape,const std::vector<int64_t> & ori_axis)147 std::vector<int64_t> ToFracZAxis(const std::vector<int64_t> &ori_shape, const std::vector<int64_t> &ori_axis) {
148   std::vector<int64_t> frac_z_axis = ori_axis;
149   int64_t shape_len = SizeToLong(ori_shape.size());
150   if (shape_len == 0) {
151     MS_LOG(EXCEPTION) << "In ToFracZAxis, divisor is zero";
152   }
153   for (size_t i = 0; i < frac_z_axis.size(); i++) {
154     int64_t axis_index = (frac_z_axis[i] + shape_len) % shape_len;
155     if (axis_index == shape_len - OFFSET1) {
156       frac_z_axis[i] = axis_index - OFFSET1;
157       frac_z_axis.push_back(axis_index + OFFSET2);
158     } else if (axis_index == shape_len - OFFSET2) {
159       frac_z_axis[i] = axis_index + OFFSET1;
160       frac_z_axis.push_back(axis_index + OFFSET2);
161     } else {
162       frac_z_axis[i] = axis_index;
163     }
164   }
165   return frac_z_axis;
166 }
}  // namespace mindspore::graphkernel::expanders