• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/composite/do_signature.h"
18 #include <algorithm>
19 #include <utility>
20 
21 #include "abstract/abstract_value.h"
22 #include "ir/anf.h"
23 #include "ir/dtype.h"
24 #include "abstract/dshape.h"
25 #include "abstract/param_validator.h"
26 #include "frontend/operator/cc_implementations.h"
27 #include "frontend/optimizer/opt.h"
28 #include "pybind_api/api_register.h"
29 
30 namespace mindspore {
31 // namespace to support composite operators definition
32 namespace prim {
// Promotion rank of the numeric dtypes supported by implicit casting.
// A larger rank wins when unifying the dtypes of arguments that share a
// SignatureEnumDType (compared via `it->second` in GetMaxTypeId below).
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt8, 2},    {kNumberTypeUInt8, 3},
                                           {kNumberTypeInt16, 4},   {kNumberTypeInt32, 5},   {kNumberTypeInt64, 6},
                                           {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}};
36 namespace {
GetSignature(const ValuePtr & function)37 const std::vector<Signature> &GetSignature(const ValuePtr &function) {
38   static const auto empty = std::vector<Signature>();
39   if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) {
40     return function->cast<PrimitivePyPtr>()->signatures();
41   } else if (function->isa<MetaFuncGraph>()) {
42     return function->cast<MetaFuncGraphPtr>()->signatures();
43   }
44   return empty;
45 }
46 
ProcessDefault(const std::string & func_name,size_t actual_param_number,const std::vector<Signature> & signature,bool has_var,std::vector<AnfNodePtr> * const op_inputs)47 void ProcessDefault(const std::string &func_name, size_t actual_param_number, const std::vector<Signature> &signature,
48                     bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
49   std::size_t sig_size = signature.size();
50   auto positional_size = sig_size;
51   if (has_var) {
52     positional_size = sig_size - 1;
53   }
54   if (actual_param_number < positional_size) {
55     for (size_t i = actual_param_number; i < sig_size; ++i) {
56       auto default_value = signature[i].default_value;
57       if (default_value == nullptr) {
58         MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
59       } else {
60         (*op_inputs).push_back(NewValueNode(default_value));
61       }
62     }
63   }
64 }
65 
SetMaxType(TypeId * max_type_id,size_t * max_type_number,const TypeId type_id,const size_t type_number)66 void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) {
67   *max_type_id = type_id;
68   *max_type_number = type_number;
69 }
70 
GetTensorOrScalarTypeInfo(const TypePtr & arg_type_origin,TypeId * arg_type_id,TypeId * arg_type=nullptr)71 bool GetTensorOrScalarTypeInfo(const TypePtr &arg_type_origin, TypeId *arg_type_id, TypeId *arg_type = nullptr) {
72   if (arg_type_origin->isa<TensorType>()) {
73     auto tensor = arg_type_origin->cast<TensorTypePtr>();
74     auto tensor_type = tensor->element();
75     MS_EXCEPTION_IF_NULL(tensor_type);
76     *arg_type_id = tensor_type->type_id();
77     if (arg_type != nullptr) {
78       *arg_type = kObjectTypeTensorType;
79     }
80     return true;
81   }
82   if (arg_type_origin->isa<Number>()) {
83     auto scalar_type = arg_type_origin->cast<NumberPtr>();
84     MS_EXCEPTION_IF_NULL(scalar_type);
85     *arg_type_id = scalar_type->type_id();
86     if (arg_type != nullptr) {
87       *arg_type = kObjectTypeNumber;
88     }
89     return true;
90   }
91   return false;
92 }
93 
// Compute the unified TypeId for the tensor arguments at `indices`, taking
// the highest-ranked entry of `type_map`. Scalar (non-tensor) arguments do
// not compete directly but adjust the result afterwards (see the fix-ups at
// the bottom of the function).
TypeId GetMaxTypeId(const std::vector<TypePtr> &input_types, const std::vector<size_t> &indices) {
  TypeId max_type_id = kTypeUnknown;
  size_t max_type_number = 0;
  bool has_int8 = false;
  bool has_scalar_int64 = false;
  bool has_scalar_float32 = false;
  for (const auto &index : indices) {
    TypeId arg_type_id = kTypeUnknown;
    TypeId arg_type = kTypeUnknown;
    if (!GetTensorOrScalarTypeInfo(input_types[index], &arg_type_id, &arg_type)) {
      continue;
    }
    if (arg_type != kObjectTypeTensorType) {
      // Scalar input: remember its kind for the fallbacks below; it takes no
      // part in the tensor max computation itself.
      if (arg_type_id == kNumberTypeInt64) {
        has_scalar_int64 = true;
      } else if (arg_type_id == kNumberTypeFloat32) {
        has_scalar_float32 = true;
      }
      continue;
    }
    // Tensor dtypes outside type_map (e.g. unranked types) are ignored.
    auto it = type_map.find(arg_type_id);
    if (it == type_map.end()) {
      continue;
    }
    if (arg_type_id == kNumberTypeInt8) {
      has_int8 = true;
    }
    if (max_type_id == kTypeUnknown) {
      SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
      continue;
    }
    if (it->second > max_type_number) {
      SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second);
    }
  }

  // Mixing int8 and uint8 tensors needs a wider signed type to hold both.
  if (max_type_id == kNumberTypeUInt8 && has_int8) {
    max_type_id = kNumberTypeInt16;
  }
  // if bool is the max type, see if there is scalar input
  // if so, it means that max is bool tensor, use scalar type instead.
  // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2])
  if (max_type_id == kNumberTypeBool) {
    if (has_scalar_int64) {
      max_type_id = kNumberTypeInt64;
    }
    if (has_scalar_float32) {
      max_type_id = kNumberTypeFloat32;
    }
  }
  // A float32 scalar promotes a non-float tensor max to float32.
  if (max_type_id != kNumberTypeFloat16 && max_type_id != kNumberTypeFloat32 && max_type_id != kNumberTypeFloat64 &&
      max_type_id != kTypeUnknown && has_scalar_float32) {
    max_type_id = kNumberTypeFloat32;
  }
  return max_type_id;
}
150 
151 // Get the largest type of index in the same SignatureEnumDType of arguments.
152 using MaxTypeMap = std::map<SignatureEnumDType, TypeId>;
GetMaxDtype(const std::vector<SignatureEnumDType> & dtypes,const std::vector<TypePtr> & input_types)153 MaxTypeMap GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, const std::vector<TypePtr> &input_types) {
154   // record index for signature.dtypes of the same type
155   // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
156   std::map<SignatureEnumDType, std::vector<size_t>> type_indices;
157   for (size_t i = 0; i < dtypes.size(); ++i) {
158     auto it = type_indices.find(dtypes[i]);
159     if (it == type_indices.end()) {
160       (void)type_indices.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
161     } else {
162       it->second.push_back(i);
163     }
164   }
165   std::map<SignatureEnumDType, TypeId> dst_type;
166   for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) {
167     auto type = it->first;
168     auto indices = it->second;
169     // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it.
170     if (indices.size() < 2) {
171       continue;
172     }
173     bool has_tensor = false;
174     for (const auto &index : indices) {
175       auto arg_value = input_types[index];
176       if (arg_value->isa<TensorType>()) {
177         has_tensor = true;
178         break;
179       }
180     }
181     if (!has_tensor) {
182       (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
183       continue;
184     }
185     (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(input_types, indices)));
186   }
187   return dst_type;
188 }
189 
// Build a CNode in `graph` that casts `param` to `type_id` via the
// mindspore.ops.operations.Cast primitive, ordered after `param`.
AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto cast_op = prim::GetPythonOps("Cast", "mindspore.ops.operations");
  MS_EXCEPTION_IF_NULL(cast_op);
  auto target_dtype = NewValueNode(TypeIdToType(type_id));
  auto cast_prim_node = NewCNode({NewValueNode(cast_op)}, graph);
  return graph->NewCNodeAfter(param, {cast_prim_node, param, target_dtype});
}
198 
// Insert cast nodes into *op_inputs so that the arguments sharing a
// SignatureEnumDType all carry that group's unified (largest) dtype.
// Writable arguments are never cast: a dtype mismatch there raises instead,
// since implicit conversion would silently change the assigned parameter.
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
                const std::vector<TypePtr> &input_types, const FuncGraphPtr &graph,
                const std::set<size_t> &write_indices, std::vector<AnfNodePtr> *const op_inputs) {
  std::vector<SignatureEnumDType> dtypes;
  (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                       [](const Signature &sig) { return sig.dtype; });
  int64_t empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
  // Nothing to unify when no parameter declares a dtype group.
  if (dtypes.empty() || static_cast<int64_t>(dtypes.size()) == empty_dtype_count) {
    return;
  }
  // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, input_types);
  // Identify which arg requires auto cast
  for (size_t i = 0; i < input_types.size(); ++i) {
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end() || it->second == kTypeUnknown) {
      continue;
    }
    auto rw_it = write_indices.find(i);
    auto is_write = (rw_it != write_indices.end());

    TypeId arg_type_id = kTypeUnknown;
    auto arg_value = input_types[i];
    (void)GetTensorOrScalarTypeInfo(arg_value, &arg_type_id);
    // Types without a printable name are skipped entirely.
    auto it_map = type_name_map.find(arg_type_id);
    if (it_map == type_name_map.end()) {
      continue;
    }
    if (is_write) {
      // Writable argument: refuse (raise) rather than cast when types differ.
      if (arg_type_id != it->second) {
        auto it_name_map = type_name_map.find(it->second);
        if (it_name_map == type_name_map.end()) {
          continue;
        }
        RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second);
      }
      continue;
    }
    // Tensor already at the target dtype: no cast needed. Non-tensor
    // (scalar) arguments with a known target are always cast.
    if ((arg_value->isa<TensorType>()) && arg_type_id == it->second) {
      continue;
    }
    MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
                  << " to " << it->second;
    // op_inputs[0] is the operation value node itself, hence the i + 1 offset.
    (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
  }
}
245 
CheckSigSize(const size_t & sig_size,const bool & has_var,const AbstractBasePtrList & args_spec_list,const std::string & func_name)246 void CheckSigSize(const size_t &sig_size, const bool &has_var, const AbstractBasePtrList &args_spec_list,
247                   const std::string &func_name) {
248   if (sig_size > 0) {
249     if (has_var) {
250       if (sig_size - 1 > args_spec_list.size()) {
251         MS_LOG(EXCEPTION) << "Function " << func_name
252                           << "'s input length less than PositionalKeyword Signature length.";
253       }
254     } else if (args_spec_list.size() > sig_size) {
255       MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length.";
256     }
257   }
258 }
259 
// Core of signature processing: build a CNode that calls `function` with
// `params_list` as inputs, after
//  1. per-parameter read/write handling (mixed-precision cast of read-only
//     float refs, recording writable indices),
//  2. appending default values for missing trailing arguments, and
//  3. inserting implicit dtype casts (DoAutoCast).
AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
                         const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
  // args: original inputs
  auto &signature = GetSignature(function);
  std::size_t sig_size = signature.size();
  // A trailing var-positional parameter allows extra call-site arguments.
  auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
  CheckSigSize(sig_size, has_var, args_spec_list, func_name);
  std::vector<AnfNodePtr> op_inputs;
  std::set<size_t> write_indices;
  std::vector<TypePtr> input_types;
  op_inputs.push_back(NewValueNode(function));
  auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
  // Assume, the write input of op is always the first input. We check if any write op,
  // and add cast op on other inputs to keep the same type with assigned parameter.
  for (size_t i = 0; i < args_spec_list.size(); ++i) {
    AnfNodePtr param = params_list[i];
    if (args_spec_list[i] == nullptr) {
      // No abstract info: pass the input through untouched, recording no type.
      op_inputs.push_back(param);
      continue;
    }
    SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
    // If sig_size is 0 use default.
    if (sig_size > 0 && i < sig_size) {
      sig = signature[i].rw;
    } else if (has_var && i >= sig_size) {
      // Extra arguments inherit the rw mode of the var-positional parameter.
      sig = signature[sig_size - 1].rw;
    }

    TypePtr type = args_spec_list[i]->BuildType();
    if (type && type->isa<RefType>()) {
      if (sig == SignatureEnumRW::kRWRead) {
        auto source_tensor_type = type->cast<TensorTypePtr>();
        if (source_tensor_type != nullptr) {
          auto source_element = source_tensor_type->element();
          // Mixed precision: cast read-only float refs to the configured
          // target float type when the element type differs from it.
          if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
            auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
            param = func_graph->NewCNodeAfter(param, {NewValueNode(cast), param, NewValueNode(cast_type)});
            type = cast_type->type_id() == kNumberTypeFloat16 ? kTensorTypeFP16 : kTensorTypeFP32;
          }
        }
      } else if (sig == SignatureEnumRW::kRWWrite) {
        write_indices.insert(i);
      }
      // If sig is SignatureEnumRW::kRWRef, not do anything.
    } else if (sig == SignatureEnumRW::kRWWrite &&
               !((type->type_id() == kObjectTypeRef) || (type->type_id() == kObjectTypeRefKey))) {
      // NOTE(review): the branch above null-checks `type`, but this branch
      // (and the logging below) dereferences it unconditionally — confirm
      // BuildType() cannot return nullptr here.
      RaiseExceptionForCheckParameter(func_name, i, type->ToString());
    }
    MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
                  << args_spec_list[i]->ToString() << " type " << type->ToString();
    input_types.push_back(type);
    op_inputs.push_back(param);
  }
  // process default
  ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs);
  DoAutoCast(func_name, signature, input_types, func_graph, write_indices, &op_inputs);
  return func_graph->NewCNodeInOrder(op_inputs);
}
318 }  // namespace
319 
GenerateCNode(const FuncGraphPtr & func_graph,const std::string & func_name,const ValuePtr & function,const AbstractBasePtrList & args_spec_list,const AnfNodePtrList & old_node_inputs)320 AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
321                          const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
322   auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
323   return new_cnode;
324 }
325 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)326 FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
327   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
328 
329   for (size_t i = 0; i < args_spec_list.size(); ++i) {
330     (void)func_graph->add_parameter();
331   }
332   auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
333   func_graph->set_output(new_cnode);
334   func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
335   return func_graph;
336 }
337 
RaiseExceptionForConvertRefDtype(const std::string & func_name,const std::string & ref_type,const std::string & target_type)338 void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type,
339                                       const std::string &target_type) {
340   MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
341                     << "the type of writable argument is '" << ref_type << "', "
342                     << "but the largest type in the same SignatureEumDtype is '" << target_type
343                     << "'. The writable arg type is not equal to the largest type, "
344                     << "so can not cast automatically.";
345 }
// Raise a TypeError: input `i` of `func_name` was expected to be a Parameter
// but had type `source_type`.
void RaiseExceptionForCheckParameter(const std::string &func_name, size_t i, const std::string &source_type) {
  MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
                          << source_type << ".";
}
350 }  // namespace prim
351 }  // namespace mindspore
352