/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/operator/composite/do_signature.h"
#include <algorithm>
#include <utility>

#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"
#include "frontend/optimizer/opt.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/pybind_api/api_register.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ops/op_def.h"
#include "mindspore/core/utils/flags.h"
#include "mindspore/core/ops/arithmetic_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt8, 2},    {kNumberTypeUInt8, 3},
                                           {kNumberTypeInt16, 4},   {kNumberTypeInt32, 5},   {kNumberTypeInt64, 6},
                                           {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}};
namespace {
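// Returns the signature list attached to a Primitive or MetaFuncGraph;
// returns a shared empty list when the callable carries no signatures.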
const std::vector<Signature> &GetSignature(const ValuePtr &function) {
  static const auto empty = std::vector<Signature>();
  if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) {
    return function->cast<PrimitivePtr>()->signatures();
  } else if (function->isa<MetaFuncGraph>()) {
    return function->cast<MetaFuncGraphPtr>()->signatures();
  }
  return empty;
}

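// Appends default values to *op_inputs when fewer actual arguments than
// positional parameters were supplied; raises an exception if a missing
// parameter has no default value.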
void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
                    bool has_var, std::vector<AnfNodePtr> *op_inputs) {
  std::size_t sig_size = signature.size();
  auto positional_size = sig_size;
  if (has_var) {
    positional_size = sig_size - 1;
  }
  if (actual_param_number < positional_size) {
    for (size_t i = actual_param_number; i < sig_size; ++i) {
      auto default_value = signature[i].default_value;
      if (default_value == nullptr) {
        MS_LOG(EXCEPTION) << "For '" << func_name << "', the size of input should be " << sig_size << ", but got "
                          << actual_param_number << ". Please check inputs of the operator.";
      } else {
        (*op_inputs).push_back(NewValueNode(default_value));
      }
    }
  }
}

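// Collects, for each input type, its element TypeId and whether the input
// is a tensor. Inputs that are neither Number nor TensorType are recorded
// as kTypeUnknown.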
void GetTypeInfo(const std::vector<TypePtr> &input_types, std::vector<TypeId> *args_type_id,
                 std::vector<bool> *args_has_tensor) {
  for (const auto &arg_type : input_types) {
    if (arg_type->isa<Number>()) {
      (void)args_type_id->emplace_back(arg_type->cast<NumberPtr>()->type_id());
      (void)args_has_tensor->emplace_back(false);
    } else if (arg_type->isa<TensorType>()) {
      auto elem_type = arg_type->cast<TensorTypePtr>()->element();
      MS_EXCEPTION_IF_NULL(elem_type);
      (void)args_type_id->emplace_back(elem_type->type_id());
      (void)args_has_tensor->emplace_back(true);
    } else {
      (void)args_type_id->emplace_back(kTypeUnknown);
      (void)args_has_tensor->emplace_back(false);
    }
  }
}

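// Performs implicit type promotion for inputs that share a signature dtype
// group. Each group is resolved to a target type id; a scalar feeding a
// group that contains tensors is first lifted with ScalarToTensor and then
// Cast, while other mismatched inputs get a Cast (tensor group) or
// ScalarCast (scalar group) node. Writable (ref) inputs are never cast
// implicitly; a dtype mismatch there raises a TypeError instead.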
void DoAutoCast(const std::vector<Signature> &signature, const std::vector<TypePtr> &input_types,
                const FuncGraphPtr &graph, const std::pair<ValuePtr, std::set<size_t>> &write_indices_pair,
                std::vector<AnfNodePtr> *op_inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int64_t empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  if (dtypes.empty() || static_cast<int64_t>(dtypes.size()) == empty_dtype_count) {
    return;
  }
  auto args_size = signature.size();
  if (args_size > input_types.size() || args_size > op_inputs->size()) {
    // It is possible that op_inputs size is larger than signatures size in vmap.
    MS_LOG(INTERNAL_EXCEPTION) << "For auto type cast, the number of args should be greater than or equal to "
                               << args_size << ", but got input_types size: " << input_types.size()
                               << ", op_inputs size: " << op_inputs->size();
  }
  auto func = write_indices_pair.first;
  auto write_indices = write_indices_pair.second;

  std::vector<TypeId> args_type_id;
  std::vector<bool> args_has_tensor;
  GetTypeInfo(input_types, &args_type_id, &args_has_tensor);
  auto sig_type_map = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor, write_indices);
  for (size_t i = 0; i < args_size; ++i) {
    auto it = sig_type_map.find(dtypes[i]);
    if (it == sig_type_map.end()) {
      continue;
    }
    TypeId current_type_id = args_type_id[i];
    TypeId target_type_id = (it->second).first;
    if (current_type_id == kTypeUnknown || target_type_id == kTypeUnknown) {
      continue;
    }
    if (write_indices.find(i) != write_indices.end() && current_type_id != target_type_id) {
      RaiseExceptionForConvertRefDtype(func, TypeIdToString(current_type_id), TypeIdToString(target_type_id), i);
    }
    bool arg_is_tensor = args_has_tensor[i];
    bool contain_tensor = (it->second).second;
    bool need_scalar_to_tensor = !arg_is_tensor && contain_tensor;
    auto param = (*op_inputs)[i];
    auto target_type_node = NewValueNode(static_cast<int64_t>(target_type_id));
    if (need_scalar_to_tensor) {
      auto current_type_node = NewValueNode(static_cast<int64_t>(current_type_id));
      param = graph->NewCNodeAfter(param, {NewValueNode(prim::kPrimScalarToTensor), param, current_type_node});
      (*op_inputs)[i] = graph->NewCNodeAfter(param, {NewValueNode(prim::kPrimCast), param, target_type_node});
    } else if (current_type_id != target_type_id) {
      PrimitivePtr cast_op = contain_tensor ? prim::kPrimCast : prim::kPrimScalarCast;
      (*op_inputs)[i] = graph->NewCNodeAfter(param, {NewValueNode(cast_op), param, target_type_node});
    }
  }
}

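// Validates that the number of actual arguments is compatible with the
// signature length, tolerating trailing monad arguments for primitives
// with memory or IO side effects.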
void CheckSigSize(const ValuePtr &function, const size_t &sig_size, const bool &has_var,
                  const AbstractBasePtrList &args_abs_list, const std::string &func_name) {
  if (sig_size > 0) {
    if (has_var) {
      if (sig_size - 1 > args_abs_list.size()) {
        MS_LOG(EXCEPTION) << "Function " << func_name
                          << "'s input length is less than PositionalKeyword Signature length.";
      }
      return;
    }
    // Consider the case where there are monads in primitive's args_abs_list.
    size_t args_size = args_abs_list.size();
    if (function->isa<Primitive>()) {
      auto prim = function->cast<PrimitivePtr>();
      if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
        args_size -= GetAbstractMonadNum(args_abs_list);
      }
    }
    if (args_size > sig_size) {
      MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
    }
  }
}

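// Returns the read/write attribute of the parameter at `index`; indices
// beyond a variadic signature inherit the attribute of the last entry.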
SignatureEnumRW GetSignatureEnumRW(size_t index, const std::vector<Signature> &signature, bool has_var) {
  SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
  // If sig_size is 0, use the default.
  std::size_t sig_size = signature.size();
  if (index < sig_size) {
    sig = signature[index].rw;
  } else if (has_var && index >= sig_size) {
    sig = signature[sig_size - 1].rw;
  }
  return sig;
}

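// Maps the graph's mixed-precision flag to its target float type, or
// returns nullptr when no mixed-precision flag is set.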
TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
    return kFloat32;
  } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
    return kFloat16;
  } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
    return kBFloat16;
  } else {
    return nullptr;
  }
}
}  // namespace

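// Rewrites the original inputs according to the callable's signatures:
// applies mixed-precision casts to readable ref (Parameter) inputs,
// records writable input indices, fills in default values for missing
// arguments, and finally performs implicit dtype promotion via DoAutoCast.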
std::vector<AnfNodePtr> GetNewInputsBySignatures(const FuncGraphPtr &func_graph, const std::string &func_name,
                                                 const ValuePtr &function, const AbstractBasePtrList &args_abs_list,
                                                 const std::vector<AnfNodePtr> &params_list) {
  // args: original inputs
  auto &signature = GetSignature(function);
  std::size_t sig_size = signature.size();
  auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
  CheckSigSize(function, sig_size, has_var, args_abs_list, func_name);
  std::vector<AnfNodePtr> op_inputs;
  std::set<size_t> write_indices;
  std::vector<TypePtr> input_types;
  auto cast_type = GetMixedPrecisionTargetType(func_graph);
  // Assume the writable input of an op is always its first input. If any input is writable,
  // cast the other inputs so that they keep the same type as the assigned parameter.
  for (size_t i = 0; i < args_abs_list.size(); ++i) {
    AnfNodePtr param = params_list[i];
    if (args_abs_list[i] == nullptr) {
      op_inputs.push_back(param);
      continue;
    }

    SignatureEnumRW sig = GetSignatureEnumRW(i, signature, has_var);
    TypePtr type = args_abs_list[i]->BuildType();
    if (type && type->isa<RefType>()) {
      if (sig == SignatureEnumRW::kRWRead) {
        auto source_tensor_type = type->cast<TensorTypePtr>();
        if (source_tensor_type != nullptr) {
          auto source_element = source_tensor_type->element();
          if (cast_type != nullptr && (IsSubType(source_element, kFloat) || IsSubType(source_element, kBFloat)) &&
              *source_element != *cast_type) {
            auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
            param = func_graph->NewCNodeAfter(param, {NewValueNode(cast), param, NewValueNode(cast_type)});
            type = cast_type->type_id() == kNumberTypeFloat16
                     ? kTensorTypeFP16
                     : (cast_type->type_id() == kNumberTypeBFloat16 ? kTensorTypeBF16 : kTensorTypeFP32);
          }
        }
      } else if (sig == SignatureEnumRW::kRWWrite) {
        write_indices.insert(i);
      }
      // If sig is SignatureEnumRW::kRWRef, do nothing.
    } else if (sig == SignatureEnumRW::kRWWrite &&
               !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
      RaiseExceptionForCheckParameter(func_name, i, type->ToString());
    }
    MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
                  << args_abs_list[i]->ToString() << " type " << type->ToString() << ".";
    input_types.push_back(type);
    op_inputs.push_back(param);
  }
  // Process default values.
  ProcessDefault(func_name, args_abs_list.size(), signature, has_var, &op_inputs);
  auto write_indices_pair = std::make_pair(function, write_indices);
  DoAutoCast(signature, input_types, func_graph, write_indices_pair, &op_inputs);
  return op_inputs;
}

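// Builds a CNode that calls `function` with inputs rewritten to honor
// the callable's signatures.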
AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
                         const AbstractBasePtrList &args_abs_list, const AnfNodePtrList &old_node_inputs) {
  auto new_inputs = GetNewInputsBySignatures(func_graph, func_name, function, args_abs_list, old_node_inputs);
  AnfNodePtrList op_inputs{NewValueNode(function)};
  (void)std::copy(new_inputs.begin(), new_inputs.end(), std::back_inserter(op_inputs));
  return func_graph->NewCNodeInOrder(op_inputs);
}

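// Creates a wrapper FuncGraph with one parameter per argument whose
// output is the signature-adjusted call to the wrapped function.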
FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();

  for (size_t i = 0; i < args_abs_list.size(); ++i) {
    (void)func_graph->add_parameter();
  }
  auto new_cnode = GenerateCNode(func_graph, name_, function_, args_abs_list, func_graph->parameters());
  func_graph->set_output(new_cnode);
  func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  return func_graph;
}

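// Raises a TypeError when implicit conversion would change the dtype of a
// Parameter (ref) input, naming the offending argument when the primitive
// records its input names.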
void RaiseExceptionForConvertRefDtype(const ValuePtr &func, const std::string &ref_type, const std::string &target_type,
                                      size_t index) {
  std::ostringstream buffer;
  if (func->isa<Primitive>()) {
    auto prim = func->cast<PrimitivePtr>();
    auto args_names_value = prim->GetAttr("input_names");
    if (args_names_value != nullptr) {
      auto args_names = GetValue<std::vector<std::string>>(args_names_value);
      if (index < args_names.size()) {
        buffer << " the argument[" << args_names[index] << "]'s data type of primitive[" << prim->name() << "] is ";
      }
    }
  }
  if (buffer.str().empty()) {
    buffer << " so data type ";
  }
  MS_EXCEPTION(TypeError) << "Data type conversion of 'Parameter' is not supported," << buffer.str() << ref_type
                          << ", which cannot be converted to data type " << target_type << " automatically.\n";
}

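// Raises a TypeError when a writable input is not a Parameter.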
void RaiseExceptionForCheckParameter(const std::string &func_name, size_t i, const std::string &source_type) {
  MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but got "
                          << source_type << ".";
}
}  // namespace prim
}  // namespace mindspore