1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/do_signature.h"
18 #include <algorithm>
19 #include <utility>
20
21 #include "abstract/abstract_value.h"
22 #include "abstract/dshape.h"
23 #include "abstract/param_validator.h"
24 #include "frontend/operator/cc_implementations.h"
25 #include "frontend/optimizer/opt.h"
26 #include "include/common/utils/convert_utils.h"
27 #include "include/common/pybind_api/api_register.h"
28 #include "ir/anf.h"
29 #include "ir/dtype.h"
30 #include "ops/op_def.h"
31 #include "mindspore/core/utils/flags.h"
32 #include "mindspore/core/ops/arithmetic_ops.h"
33 #include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
34
35 namespace mindspore {
36 // namespace to support composite operators definition
37 namespace prim {
38 const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3},
39 {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6},
40 {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}};
41 namespace {
GetSignature(const ValuePtr & function)42 const std::vector<Signature> &GetSignature(const ValuePtr &function) {
43 static const auto empty = std::vector<Signature>();
44 if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) {
45 return function->cast<PrimitivePtr>()->signatures();
46 } else if (function->isa<MetaFuncGraph>()) {
47 return function->cast<MetaFuncGraphPtr>()->signatures();
48 }
49 return empty;
50 }
51
ProcessDefault(const std::string & func_name,size_t actual_param_number,const std::vector<Signature> & signature,bool has_var,std::vector<AnfNodePtr> * op_inputs)52 void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
53 bool has_var, std::vector<AnfNodePtr> *op_inputs) {
54 std::size_t sig_size = signature.size();
55 auto positional_size = sig_size;
56 if (has_var) {
57 positional_size = sig_size - 1;
58 }
59 if (actual_param_number < positional_size) {
60 for (size_t i = actual_param_number; i < sig_size; ++i) {
61 auto default_value = signature[i].default_value;
62 if (default_value == nullptr) {
63 MS_LOG(EXCEPTION) << "For '" << func_name << "', the size of input should be " << sig_size << ", but got "
64 << actual_param_number << ". Please check inputs of the operator.";
65 } else {
66 (*op_inputs).push_back(NewValueNode(default_value));
67 }
68 }
69 }
70 }
71
GetTypeInfo(const std::vector<TypePtr> & input_types,std::vector<TypeId> * args_type_id,std::vector<bool> * args_has_tensor)72 void GetTypeInfo(const std::vector<TypePtr> &input_types, std::vector<TypeId> *args_type_id,
73 std::vector<bool> *args_has_tensor) {
74 for (const auto &arg_type : input_types) {
75 if (arg_type->isa<Number>()) {
76 (void)args_type_id->emplace_back(arg_type->cast<NumberPtr>()->type_id());
77 (void)args_has_tensor->emplace_back(false);
78 } else if (arg_type->isa<TensorType>()) {
79 auto elem_type = arg_type->cast<TensorTypePtr>()->element();
80 MS_EXCEPTION_IF_NULL(elem_type);
81 (void)args_type_id->emplace_back(elem_type->type_id());
82 (void)args_has_tensor->emplace_back(true);
83 } else {
84 (void)args_type_id->emplace_back(kTypeUnknown);
85 (void)args_has_tensor->emplace_back(false);
86 }
87 }
88 }
89
DoAutoCast(const std::vector<Signature> & signature,const std::vector<TypePtr> & input_types,const FuncGraphPtr & graph,const std::pair<ValuePtr,std::set<size_t>> & write_indices_pair,std::vector<AnfNodePtr> * op_inputs)90 void DoAutoCast(const std::vector<Signature> &signature, const std::vector<TypePtr> &input_types,
91 const FuncGraphPtr &graph, const std::pair<ValuePtr, std::set<size_t>> &write_indices_pair,
92 std::vector<AnfNodePtr> *op_inputs) {
93 MS_EXCEPTION_IF_NULL(graph);
94 std::vector<SignatureEnumDType> dtypes;
95 (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
96 [](const Signature &sig) { return sig.dtype; });
97 int64_t empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
98 if (dtypes.empty() || static_cast<int64_t>(dtypes.size()) == empty_dtype_count) {
99 return;
100 }
101 auto args_size = signature.size();
102 if (args_size > input_types.size() || args_size > op_inputs->size()) {
103 // It is possible that op_inputs size is larger than signatures size in vmap.
104 MS_LOG(INTERNAL_EXCEPTION) << "For auto type cast, the number of args should be greater than or equal to "
105 << args_size << ", but got input_types size: " << input_types.size()
106 << ", op_inputs size: " << op_inputs->size();
107 }
108 auto func = write_indices_pair.first;
109 auto write_indices = write_indices_pair.second;
110
111 std::vector<TypeId> args_type_id;
112 std::vector<bool> args_has_tensor;
113 GetTypeInfo(input_types, &args_type_id, &args_has_tensor);
114 auto sig_type_map = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor, write_indices);
115 for (size_t i = 0; i < args_size; ++i) {
116 auto it = sig_type_map.find(dtypes[i]);
117 if (it == sig_type_map.end()) {
118 continue;
119 }
120 TypeId current_type_id = args_type_id[i];
121 TypeId target_type_id = (it->second).first;
122 if (current_type_id == kTypeUnknown || target_type_id == kTypeUnknown) {
123 continue;
124 }
125 if (write_indices.find(i) != write_indices.end() && current_type_id != target_type_id) {
126 RaiseExceptionForConvertRefDtype(func, TypeIdToString(current_type_id), TypeIdToString(target_type_id), i);
127 }
128 bool arg_is_tensor = args_has_tensor[i];
129 bool contain_tensor = (it->second).second;
130 bool need_scalar_to_tensor = !arg_is_tensor && contain_tensor;
131 auto param = (*op_inputs)[i];
132 auto target_type_node = NewValueNode(static_cast<int64_t>(target_type_id));
133 if (need_scalar_to_tensor) {
134 auto current_type_node = NewValueNode(static_cast<int64_t>(current_type_id));
135 param = graph->NewCNodeAfter(param, {NewValueNode(prim::kPrimScalarToTensor), param, current_type_node});
136 (*op_inputs)[i] = graph->NewCNodeAfter(param, {NewValueNode(prim::kPrimCast), param, target_type_node});
137 } else if (current_type_id != target_type_id) {
138 PrimitivePtr cast_op = contain_tensor ? prim::kPrimCast : prim::kPrimScalarCast;
139 (*op_inputs)[i] = graph->NewCNodeAfter(param, {NewValueNode(cast_op), param, target_type_node});
140 }
141 }
142 }
143
CheckSigSize(const ValuePtr & function,const size_t & sig_size,const bool & has_var,const AbstractBasePtrList & args_abs_list,const std::string & func_name)144 void CheckSigSize(const ValuePtr &function, const size_t &sig_size, const bool &has_var,
145 const AbstractBasePtrList &args_abs_list, const std::string &func_name) {
146 if (sig_size > 0) {
147 if (has_var) {
148 if (sig_size - 1 > args_abs_list.size()) {
149 MS_LOG(EXCEPTION) << "Function " << func_name
150 << "'s input length less than PositionalKeyword Signature length.";
151 }
152 return;
153 }
154 // Consider the case where there are monads in primitive's args_abs_list.
155 size_t args_size = args_abs_list.size();
156 if (function->isa<Primitive>()) {
157 auto prim = function->cast<PrimitivePtr>();
158 if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
159 args_size -= GetAbstractMonadNum(args_abs_list);
160 }
161 }
162 if (args_size > sig_size) {
163 MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
164 }
165 }
166 }
167
GetSignatureEnumRW(size_t index,const std::vector<Signature> & signature,bool has_var)168 SignatureEnumRW GetSignatureEnumRW(size_t index, const std::vector<Signature> &signature, bool has_var) {
169 SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
170 // If sig_size is 0 use default.
171 std::size_t sig_size = signature.size();
172 if (index < sig_size) {
173 sig = signature[index].rw;
174 } else if (has_var && index >= sig_size) {
175 sig = signature[sig_size - 1].rw;
176 }
177 return sig;
178 }
179
GetMixedPrecisionTargetType(const FuncGraphPtr & func_graph)180 TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
181 MS_EXCEPTION_IF_NULL(func_graph);
182 if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
183 return kFloat32;
184 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
185 return kFloat16;
186 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
187 return kBFloat16;
188 } else {
189 return nullptr;
190 }
191 }
192 } // namespace
193
GetNewInputsBySignatures(const FuncGraphPtr & func_graph,const std::string & func_name,const ValuePtr & function,const AbstractBasePtrList & args_abs_list,const std::vector<AnfNodePtr> & params_list)194 std::vector<AnfNodePtr> GetNewInputsBySignatures(const FuncGraphPtr &func_graph, const std::string &func_name,
195 const ValuePtr &function, const AbstractBasePtrList &args_abs_list,
196 const std::vector<AnfNodePtr> ¶ms_list) {
197 // args: original inputs
198 auto &signature = GetSignature(function);
199 std::size_t sig_size = signature.size();
200 auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
201 CheckSigSize(function, sig_size, has_var, args_abs_list, func_name);
202 std::vector<AnfNodePtr> op_inputs;
203 std::set<size_t> write_indices;
204 std::vector<TypePtr> input_types;
205 auto cast_type = GetMixedPrecisionTargetType(func_graph);
206 // Assume, the write input of op is always the first input. We check if any write op,
207 // and add cast op on other inputs to keep the same type with assigned parameter.
208 for (size_t i = 0; i < args_abs_list.size(); ++i) {
209 AnfNodePtr param = params_list[i];
210 if (args_abs_list[i] == nullptr) {
211 op_inputs.push_back(param);
212 continue;
213 }
214
215 SignatureEnumRW sig = GetSignatureEnumRW(i, signature, has_var);
216 TypePtr type = args_abs_list[i]->BuildType();
217 if (type && type->isa<RefType>()) {
218 if (sig == SignatureEnumRW::kRWRead) {
219 auto source_tensor_type = type->cast<TensorTypePtr>();
220 if (source_tensor_type != nullptr) {
221 auto source_element = source_tensor_type->element();
222 if (cast_type != nullptr && (IsSubType(source_element, kFloat) || IsSubType(source_element, kBFloat)) &&
223 *source_element != *cast_type) {
224 auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
225 param = func_graph->NewCNodeAfter(param, {NewValueNode(cast), param, NewValueNode(cast_type)});
226 type = cast_type->type_id() == kNumberTypeFloat16
227 ? kTensorTypeFP16
228 : (cast_type->type_id() == kNumberTypeBFloat16 ? kTensorTypeBF16 : kTensorTypeFP32);
229 }
230 }
231 } else if (sig == SignatureEnumRW::kRWWrite) {
232 write_indices.insert(i);
233 }
234 // If sig is SignatureEnumRW::kRWRef, not do anything.
235 } else if (sig == SignatureEnumRW::kRWWrite &&
236 !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
237 RaiseExceptionForCheckParameter(func_name, i, type->ToString());
238 }
239 MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
240 << args_abs_list[i]->ToString() << " type " << type->ToString() << ".";
241 input_types.push_back(type);
242 op_inputs.push_back(param);
243 }
244 // process default
245 ProcessDefault(func_name, args_abs_list.size(), signature, has_var, &op_inputs);
246 auto write_indices_pair = std::make_pair(function, write_indices);
247 DoAutoCast(signature, input_types, func_graph, write_indices_pair, &op_inputs);
248 return op_inputs;
249 }
250
GenerateCNode(const FuncGraphPtr & func_graph,const std::string & func_name,const ValuePtr & function,const AbstractBasePtrList & args_abs_list,const AnfNodePtrList & old_node_inputs)251 AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
252 const AbstractBasePtrList &args_abs_list, const AnfNodePtrList &old_node_inputs) {
253 auto new_inputs = GetNewInputsBySignatures(func_graph, func_name, function, args_abs_list, old_node_inputs);
254 AnfNodePtrList op_inputs{NewValueNode(function)};
255 (void)std::copy(new_inputs.begin(), new_inputs.end(), std::back_inserter(op_inputs));
256 return func_graph->NewCNodeInOrder(op_inputs);
257 }
258
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)259 FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
260 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
261
262 for (size_t i = 0; i < args_abs_list.size(); ++i) {
263 (void)func_graph->add_parameter();
264 }
265 auto new_cnode = GenerateCNode(func_graph, name_, function_, args_abs_list, func_graph->parameters());
266 func_graph->set_output(new_cnode);
267 func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
268 return func_graph;
269 }
270
RaiseExceptionForConvertRefDtype(const ValuePtr & func,const std::string & ref_type,const std::string & target_type,size_t index)271 void RaiseExceptionForConvertRefDtype(const ValuePtr &func, const std::string &ref_type, const std::string &target_type,
272 size_t index) {
273 std::ostringstream buffer;
274 if (func->isa<Primitive>()) {
275 auto prim = func->cast<PrimitivePtr>();
276 auto args_names_value = prim->GetAttr("input_names");
277 if (args_names_value != nullptr) {
278 auto args_names = GetValue<std::vector<std::string>>(args_names_value);
279 if (index < args_names.size()) {
280 buffer << " the argument[" << args_names[index] << "]'s data type of primitive[" << prim->name() << "] is ";
281 }
282 }
283 }
284 if (buffer.str().empty()) {
285 buffer << " so data type ";
286 }
287 MS_EXCEPTION(TypeError) << "Data type conversion of 'Parameter' is not supported," << buffer.str() << ref_type
288 << ", which cannot be converted to data type " << target_type << " automatically.\n";
289 }
290
RaiseExceptionForCheckParameter(const std::string & func_name,size_t i,const std::string & source_type)291 void RaiseExceptionForCheckParameter(const std::string &func_name, size_t i, const std::string &source_type) {
292 MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
293 << source_type << ".";
294 }
295 } // namespace prim
296 } // namespace mindspore
297