1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/do_signature.h"
18 #include <algorithm>
19 #include <utility>
20
21 #include "abstract/abstract_value.h"
22 #include "ir/anf.h"
23 #include "ir/dtype.h"
24 #include "abstract/dshape.h"
25 #include "abstract/param_validator.h"
26 #include "frontend/operator/cc_implementations.h"
27 #include "frontend/optimizer/opt.h"
28 #include "pybind_api/api_register.h"
29
30 namespace mindspore {
31 // namespace to support composite operators definition
32 namespace prim {
33 const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3},
34 {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6},
35 {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}};
36 namespace {
GetSignature(const ValuePtr & function)37 const std::vector<Signature> &GetSignature(const ValuePtr &function) {
38 static const auto empty = std::vector<Signature>();
39 if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) {
40 return function->cast<PrimitivePyPtr>()->signatures();
41 } else if (function->isa<MetaFuncGraph>()) {
42 return function->cast<MetaFuncGraphPtr>()->signatures();
43 }
44 return empty;
45 }
46
ProcessDefault(const std::string & func_name,size_t actual_param_number,const std::vector<Signature> & signature,bool has_var,std::vector<AnfNodePtr> * const op_inputs)47 void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
48 bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
49 std::size_t sig_size = signature.size();
50 auto positional_size = sig_size;
51 if (has_var) {
52 positional_size = sig_size - 1;
53 }
54 if (actual_param_number < positional_size) {
55 for (size_t i = actual_param_number; i < sig_size; ++i) {
56 auto default_value = signature[i].default_value;
57 if (default_value == nullptr) {
58 MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
59 } else {
60 (*op_inputs).push_back(NewValueNode(default_value));
61 }
62 }
63 }
64 }
65
SetMaxType(TypeId * max_type_id,size_t * max_type_number,const TypeId type_id,const size_t type_number)66 void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) {
67 *max_type_id = type_id;
68 *max_type_number = type_number;
69 }
70
GetTensorOrScalarTypeInfo(const TypePtr & arg_type_origin,TypeId * arg_type_id,TypeId * arg_type=nullptr)71 bool GetTensorOrScalarTypeInfo(const TypePtr &arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
72 if (arg_type_origin->isa<TensorType>()) {
73 auto tensor = arg_type_origin->cast<TensorTypePtr>();
74 auto tensor_type = tensor->element();
75 MS_EXCEPTION_IF_NULL(tensor_type);
76 *arg_type_id = tensor_type->type_id();
77 if (arg_type != nullptr) {
78 *arg_type = kObjectTypeTensorType;
79 }
80 return true;
81 }
82 if (arg_type_origin->isa<Number>()) {
83 auto scalar_type = arg_type_origin->cast<NumberPtr>();
84 MS_EXCEPTION_IF_NULL(scalar_type);
85 *arg_type_id = scalar_type->type_id();
86 if (arg_type != nullptr) {
87 *arg_type = kObjectTypeNumber;
88 }
89 return true;
90 }
91 return false;
92 }
93
GetMaxTypeId(const std::vector<TypePtr> & input_types,const std::vector<size_t> & indices)94 TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<size_t> &indices) {
95 TypeId max_type_id = kTypeUnknown;
96 size_t max_type_number = 0;
97 bool has_int8 = false;
98 bool has_scalar_int64 = false;
99 bool has_scalar_float32 = false;
100 for (const auto &index : indices) {
101 TypeId arg_type_id = kTypeUnknown;
102 TypeId arg_type = kTypeUnknown;
103 if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
104 continue;
105 }
106 if (arg_type != kObjectTypeTensorType) {
107 if (arg_type_id == kNumberTypeInt64) {
108 has_scalar_int64 = true;
109 } else if (arg_type_id == kNumberTypeFloat32) {
110 has_scalar_float32 = true;
111 }
112 continue;
113 }
114 auto it = type_map.find(arg_type_id);
115 if (it == type_map.end()) {
116 continue;
117 }
118 if (arg_type_id == kNumberTypeInt8) {
119 has_int8 = true;
120 }
121 if (max_type_id == kTypeUnknown) {
122 SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
123 continue;
124 }
125 if (it->second > max_type_number) {
126 SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
127 }
128 }
129
130 if (max_type_id == kNumberTypeUInt8 && has_int8) {
131 max_type_id = kNumberTypeInt16;
132 }
133 // if bool is the max type, see if there is scalar input
134 // if so, it means that max is bool tensor, use scalar type instead.
135 // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2])
136 if (max_type_id == kNumberTypeBool) {
137 if (has_scalar_int64) {
138 max_type_id = kNumberTypeInt64;
139 }
140 if (has_scalar_float32) {
141 max_type_id = kNumberTypeFloat32;
142 }
143 }
144 if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
145 max_type_id != kTypeUnknown && has_scalar_float32) {
146 max_type_id = kNumberTypeFloat32;
147 }
148 return max_type_id;
149 }
150
151 // Get the largest type of index in the same SignatureEnumDType of arguments.
152 using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
GetMaxDtype(const std::vector<SignatureEnumDType> & dtypes,const std::vector<TypePtr> & input_types)153 MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types) {
154 // record index for signature.dtypes of the same type
155 // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
156 std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
157 for (size_t i = 0; i < dtypes.size(); ++i) {
158 auto it = type_indices.find(dtypes[i]);
159 if (it == type_indices.end()) {
160 (void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
161 } else {
162 it->second.push_back(i);
163 }
164 }
165 std::map<SignatureEnumDType, TypeId> dst_type;
166 for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
167 auto type = it->first;
168 auto indices = it->second;
169 // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
170 if (indices.size() < 2) {
171 continue;
172 }
173 bool has_tensor = false;
174 for (const auto &index : indices) {
175 auto arg_value = input_types[index];
176 if (arg_value->isa<TensorType>()) {
177 has_tensor = true;
178 break;
179 }
180 }
181 if (!has_tensor) {
182 (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
183 continue;
184 }
185 (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices)));
186 }
187 return dst_type;
188 }
189
/**
 * Wraps `param` in a Cast node that converts it to `type_id`.
 *
 * The Python `Cast` class is fetched from `mindspore.ops.operations` and wrapped in a CNode with that class
 * as its only input. Presumably evaluating that node instantiates the primitive (TODO confirm). That node
 * is then applied to (param, dtype).
 *
 * @param param Node to cast.
 * @param type_id Target type id.
 * @param graph Graph in which the new nodes are created; must not be null.
 * @return The new cast CNode, created with graph->NewCNodeAfter(param, ...).
 */
AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations");
  MS_EXCEPTION_IF_NULL(prim_cast_class);
  auto dtype_node = NewValueNode(TypeIdToType(type_id));
  auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph);
  return graph->NewCNodeAfter(param, {cast_node, param, dtype_node});
}
198
DoAutoCast(const std::string & func_name,const std::vector<Signature> & signature,const std::vector<TypePtr> & input_types,const FuncGraphPtr & graph,const std::set<size_t> & write_indices,std::vector<AnfNodePtr> * const op_inputs)199 void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
200 const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
201 const std::set<size_t> &write_indices, std::vector<AnfNodePtr> *const op_inputs) {
202 std::vector<SignatureEnumDType> dtypes;
203 (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
204 [](const Signature &sig) { return sig.dtype; });
205 int64_t empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
206 if (dtypes.empty() || static_cast<int64_t>(dtypes.size()) == empty_dtype_count) {
207 return;
208 }
209 // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
210 std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types);
211 // Identify which arg requires auto cast
212 for (size_t i = 0; i < input_types.size(); ++i) {
213 auto it = dst_type.find(dtypes[i]);
214 if (it == dst_type.end() || it->second == kTypeUnknown) {
215 continue;
216 }
217 auto rw_it = write_indices.find(i);
218 auto is_write = (rw_it != write_indices.end());
219
220 TypeId arg_type_id = kTypeUnknown;
221 auto arg_value = input_types[i];
222 (void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
223 auto it_map = type_name_map.find(arg_type_id);
224 if (it_map == type_name_map.end()) {
225 continue;
226 }
227 if (is_write) {
228 if (arg_type_id != it->second) {
229 auto it_name_map = type_name_map.find(it->second);
230 if (it_name_map == type_name_map.end()) {
231 continue;
232 }
233 RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second);
234 }
235 continue;
236 }
237 if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
238 continue;
239 }
240 MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
241 << " to " << it->second;
242 (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
243 }
244 }
245
CheckSigSize(const size_t & sig_size,const bool & has_var,const AbstractBasePtrList & args_spec_list,const std::string & func_name)246 void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
247 const std::string &func_name) {
248 if (sig_size > 0) {
249 if (has_var) {
250 if (sig_size - 1 > args_spec_list.size()) {
251 MS_LOG(EXCEPTION) << "Function " << func_name
252 << "'s input length less than PositionalKeyword Signature length.";
253 }
254 } else if (args_spec_list.size() > sig_size) {
255 MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
256 }
257 }
258 }
259
BuildNewCNode(const FuncGraphPtr & func_graph,const std::string & func_name,const ValuePtr & function,const AbstractBasePtrList & args_spec_list,const std::vector<AnfNodePtr> & params_list)260 AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
261 const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> ¶ms_list) {
262 // args: original inputs
263 auto &signature = GetSignature(function);
264 std::size_t sig_size = signature.size();
265 auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
266 CheckSigSize(sig_size, has_var, args_spec_list, func_name);
267 std::vector<AnfNodePtr> op_inputs;
268 std::set<size_t> write_indices;
269 std::vector<TypePtr> input_types;
270 op_inputs.push_back(NewValueNode(function));
271 auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
272 // Assume, the write input of op is always the first input. We check if any write op,
273 // and add cast op on other inputs to keep the same type with assigned parameter.
274 for (size_t i = 0; i < args_spec_list.size(); ++i) {
275 AnfNodePtr param = params_list[i];
276 if (args_spec_list[i] == nullptr) {
277 op_inputs.push_back(param);
278 continue;
279 }
280 SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
281 // If sig_size is 0 use default.
282 if (sig_size > 0 && i < sig_size) {
283 sig = signature[i].rw;
284 } else if (has_var && i >= sig_size) {
285 sig = signature[sig_size - 1].rw;
286 }
287
288 TypePtr type = args_spec_list[i]->BuildType();
289 if (type && type->isa<RefType>()) {
290 if (sig == SignatureEnumRW::kRWRead) {
291 auto source_tensor_type = type->cast<TensorTypePtr>();
292 if (source_tensor_type != nullptr) {
293 auto source_element = source_tensor_type->element();
294 if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
295 auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
296 param = func_graph->NewCNodeAfter(param, {NewValueNode(cast), param, NewValueNode(cast_type)});
297 type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
298 }
299 }
300 } else if (sig == SignatureEnumRW::kRWWrite) {
301 write_indices.insert(i);
302 }
303 // If sig is SignatureEnumRW::kRWRef, not do anything.
304 } else if (sig == SignatureEnumRW::kRWWrite &&
305 !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
306 RaiseExceptionForCheckParameter(func_name, i, type->ToString());
307 }
308 MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
309 << args_spec_list[i]->ToString() << " type " << type->ToString();
310 input_types.push_back(type);
311 op_inputs.push_back(param);
312 }
313 // process default
314 ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
315 DoAutoCast(func_name, signature, input_types, func_graph, write_indices, &op_inputs);
316 return func_graph->NewCNodeInOrder(op_inputs);
317 }
318 } // namespace
319
GenerateCNode(const FuncGraphPtr & func_graph,const std::string & func_name,const ValuePtr & function,const AbstractBasePtrList & args_spec_list,const AnfNodePtrList & old_node_inputs)320 AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
321 const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
322 auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
323 return new_cnode;
324 }
325
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)326 FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
327 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
328
329 for (size_t i = 0; i < args_spec_list.size(); ++i) {
330 (void)func_graph->add_parameter();
331 }
332 auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
333 func_graph->set_output(new_cnode);
334 func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
335 return func_graph;
336 }
337
RaiseExceptionForConvertRefDtype(const std::string & func_name,const std::string & ref_type,const std::string & target_type)338 void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type,
339 const std::string &target_type) {
340 MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
341 << "the type of writable argument is '" << ref_type << "', "
342 << "but the largest type in the same SignatureEumDtype is '" << target_type
343 << "'. The writable arg type is not equal to the largest type, "
344 << "so can not cast automatically.";
345 }
RaiseExceptionForCheckParameter(const std::string & func_name,size_t i,const std::string & source_type)346 void RaiseExceptionForCheckParameter(const std::string &func_name, size_t i, const std::string &source_type) {
347 MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
348 << source_type << ".";
349 }
350 } // namespace prim
351 } // namespace mindspore
352