• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/forward/do_cast.h"
18 #include <memory>
19 #include <utility>
20 #include <algorithm>
21 #include "mindspore/core/ops/array_ops.h"
22 #include "pipeline/pynative/pynative_utils.h"
23 #include "include/common/profiler.h"
24 
25 namespace mindspore {
26 namespace pynative {
DoCast(const FrontendOpRunInfoPtr & op_run_info)27 void CastOperation::DoCast(const FrontendOpRunInfoPtr &op_run_info) {
28   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeCast,
29                                      op_run_info->base_op_run_info.op_name, true);
30   // Mixed precision conversion tensors which has cast dtype
31   SetTensorMixPrecisionCast(op_run_info);
32   // Implicit transform
33   SetImplicitCast(op_run_info);
34 }
35 
ClearRes()36 void CastOperation::ClearRes() {
37   implicit_cast_map_.clear();
38   type_prim_cache_.clear();
39 }
40 
IsValueTypeInvalid(const ValuePtr & v) const41 bool CastOperation::IsValueTypeInvalid(const ValuePtr &v) const {
42   MS_EXCEPTION_IF_NULL(v);
43   return !v->isa<tensor::BaseTensor>() && !v->isa<tensor::CSRTensor>() && !v->isa<IntegerImm>() &&
44          !v->isa<FloatImm>() && !v->isa<BoolImm>();
45 }
46 
// Casts `v` to `type_id`, storing the result in `cast_run_info->real_out`.
// Tries two cheap paths (scalar conversion, tensor already of target dtype)
// before dispatching a real Cast op through the frontend executor.
ValuePtr CastOperation::DoNormalCast(const FrontendOpRunInfoPtr &cast_run_info, const ValuePtr &v,
                                     const TypeId &type_id) const {
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(cast_run_info);
  // Step 1: Cast scalar value to another scalar value with destination data type.
  // It is used to avoid to call `cast infer value function` or launch cast op to backend.
  ValuePtr dst_value = ScalarToDstDtypeValue(v, std::make_pair(type_id, true));
  if (dst_value != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << dst_value->ToString();
    cast_run_info->real_out = dst_value;
    return dst_value;
  }

  // A tensor that already holds the destination dtype needs no cast at all.
  if (v->isa<tensor::BaseTensor>()) {
    auto tensor = v->cast<tensor::BaseTensorPtr>();
    if (type_id == tensor->data_type()) {
      cast_run_info->real_out = v;
      return cast_run_info->real_out;
    }
  }

  // Fall back to launching a Cast op. Its two inputs are (value, dst_type_id).
  constexpr auto input_size = 2;
  cast_run_info->op_grad_info->op_prim = GetPrimByTypeId(type_id);
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
  // Fold constant inputs into attributes as required by the target backend.
  PyNativeAlgo::Common::GetConstInputToAttr(
    cast_run_info->op_grad_info->op_prim, cast_run_info->base_op_run_info.op_name,
    cast_run_info->base_op_run_info.device_target, false, &cast_run_info->input_to_attr);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(v);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(type_id64);
  cast_run_info->input_size = input_size;
  cast_run_info->op_grad_info->input_value_grad_type.resize(input_size);
  // RunOpFrontend fills cast_run_info->real_out with the op's result.
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->RunOpFrontend(cast_run_info);
  return cast_run_info->real_out;
}
81 
// Casts `v` to the destination dtype described by `dst_type` for the `index`-th
// input of op `op_name`. Cheap paths (scalar conversion, in-place tensor dtype
// rewrite for source-typed inputs) are tried first; otherwise a Cast op is
// built and dispatched through the frontend executor.
// Returns the converted value.
ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
                                   const std::pair<TypeId, bool> &dst_type, const std::string &op_name,
                                   size_t index) const {
  MS_EXCEPTION_IF_NULL(v);
  // Step 1: Cast scalar value to another scalar value with destination data type.
  // It is used to avoid to call `cast infer value function` or launch cast op to backend.
  ValuePtr dst_value = ScalarToDstDtypeValue(v, dst_type);
  if (dst_value != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << dst_value->ToString();
    return dst_value;
  }
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (op_run_info->source_type[index] != ops::OP_DTYPE::DT_BEGIN && v->isa<tensor::BaseTensor>()) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString();
    dst_value = TensorToDstDtypeValue(v, dst_type.first);
    // Fix: guard before dereferencing. The scalar path above checks for nullptr,
    // but this path previously called dst_value->ToString() unconditionally.
    MS_EXCEPTION_IF_NULL(dst_value);
    MS_LOG(DEBUG) << "Cast to value: " << dst_value->ToString() << " without dispatching cast op";
    return dst_value;
  }
  // When step 1 does not work, creating a cast op to get destination data type value.
  constexpr auto input_size = 2;
  const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
  auto cast_prim = GetPrimByTypeId(dst_type.first);
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type.first));
  cast_run_info->requires_grad = op_run_info->requires_grad;
  cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();
  cast_run_info->base_op_run_info.is_mixed_precision_cast = true;
  // Record which op/input this cast feeds, for downstream bookkeeping.
  cast_run_info->base_op_run_info.next_op_name = op_name;
  cast_run_info->base_op_run_info.next_input_index = index;
  cast_run_info->base_op_run_info.use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
  cast_run_info->cell_obj_id = op_run_info->cell_obj_id;
  cast_run_info->base_op_run_info.device_target =
    PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->GetCurrentDeviceTarget(cast_prim);
  // NOTE(review): has_dynamic_output is read from the freshly created
  // cast_run_info (never copied from op_run_info), so only
  // use_dynamic_shape_process can make this true — confirm that is intended.
  bool is_dynamic_shape =
    cast_run_info->base_op_run_info.has_dynamic_output || cast_run_info->base_op_run_info.use_dynamic_shape_process;
  PyNativeAlgo::Common::GetConstInputToAttr(cast_prim, cast_run_info->base_op_run_info.op_name,
                                            cast_run_info->base_op_run_info.device_target, is_dynamic_shape,
                                            &cast_run_info->input_to_attr);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(v);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(type_id64);
  cast_run_info->input_size = input_size;
  cast_run_info->op_grad_info->input_value_grad_type.resize(input_size);
  cast_run_info->op_grad_info->op_prim = cast_prim;
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->RunOpFrontend(cast_run_info);
  return cast_run_info->real_out;
}
127 
// Applies the cell's mixed-precision policy to a single tensor input.
// Returns the (possibly cast) value; sets *is_cast when a cast was dispatched.
ValuePtr CastOperation::DoParamMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, bool *is_cast,
                                                const ValuePtr &v, const std::string &op_name, size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(is_cast);
  MS_EXCEPTION_IF_NULL(v);
  if (op_run_info->mix_type == kNotSet) {
    // No mixed-precision policy configured: hand the value back untouched.
    return v;
  }
  // Map the mix-precision mode to its target floating dtype; FP16 is the default.
  auto target_dtype = kFloat16;
  if (op_run_info->mix_type == kFP32) {
    target_dtype = kFloat32;
  } else if (op_run_info->mix_type == kBF16) {
    target_dtype = kBFloat16;
  }
  const auto &tensor = v->cast<tensor::BaseTensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &src_dtype = tensor->Dtype();
  // Only float-family tensors participate, and only when the dtype differs.
  const bool is_float_like =
    src_dtype != nullptr && (IsSubType(src_dtype, kFloat) || IsSubType(src_dtype, kBFloat));
  if (!is_float_like || *src_dtype == *target_dtype) {
    return v;
  }
  MS_LOG(DEBUG) << "MixPrecision cast for " << op_run_info->base_op_run_info.op_name << " " << index
                << "th input, and to type " << target_dtype->ToString();
  *is_cast = true;
  return DoAutoCast(op_run_info, tensor, std::make_pair(target_dtype->type_id(), true), op_name, index);
}
153 
DoParamMixPrecisionCastTuple(const FrontendOpRunInfoPtr & op_run_info,bool * is_cast,const ValueSequencePtr & value_seq,const std::string & op_name,size_t index) const154 ValuePtr CastOperation::DoParamMixPrecisionCastTuple(const FrontendOpRunInfoPtr &op_run_info, bool *is_cast,
155                                                      const ValueSequencePtr &value_seq, const std::string &op_name,
156                                                      size_t index) const {
157   MS_EXCEPTION_IF_NULL(op_run_info);
158   MS_EXCEPTION_IF_NULL(is_cast);
159   MS_EXCEPTION_IF_NULL(value_seq);
160   size_t tuple_size = value_seq->size();
161   const auto &value_tuple = value_seq->value();
162   ValuePtrList result(tuple_size, nullptr);
163   for (size_t i = 0; i < tuple_size; i++) {
164     if (value_tuple[i]->isa<tensor::MetaTensor>()) {
165       MS_LOG(DEBUG) << "Call cast for item " << i;
166       result[i] = DoParamMixPrecisionCast(op_run_info, is_cast, value_tuple[i], op_name, index);
167     } else if (value_tuple[i]->isa<ValueSequence>()) {
168       result[i] =
169         DoParamMixPrecisionCastTuple(op_run_info, is_cast, value_tuple[i]->cast<ValueSequencePtr>(), op_name, index);
170     } else {
171       result[i] = value_tuple[i];
172     }
173   }
174   if (value_seq->isa<ValueList>()) {
175     return std::make_shared<ValueList>(result);
176   } else {
177     return std::make_shared<ValueTuple>(result);
178   }
179 }
180 
DoSignatureCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes) const181 void CastOperation::DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info,
182                                     const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type,
183                                     const std::vector<SignatureEnumDType> &dtypes) const {
184   MS_EXCEPTION_IF_NULL(op_run_info);
185   MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
186   const auto &signature = op_run_info->signatures;
187   auto &input_args = op_run_info->op_grad_info->input_value;
188   size_t input_args_size = input_args.size();
189   if (dtypes.size() > input_args_size) {
190     MS_LOG(EXCEPTION) << "Signature dtypes size[" << dtypes << "] is greater than input_args_size[" << input_args_size
191                       << "].";
192   }
193   for (size_t i = 0; i < dtypes.size(); ++i) {
194     // No need to implicit cast if no dtype.
195     if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
196       MS_LOG(DEBUG) << "Get kDTypeEmptyDefaultValue";
197       continue;
198     }
199     auto it = dst_type.find(dtypes[i]);
200     if (it == dst_type.end() || it->second.first == kTypeUnknown) {
201       MS_LOG(DEBUG) << "Can not find dtype " << (it == dst_type.end()) << ", or type is unknown "
202                     << (it->second.first == kTypeUnknown);
203       continue;
204     }
205     const auto &v = input_args[i];
206     auto sig = SignatureEnumRW::kRWDefault;
207     if (!signature.empty()) {
208       if (i >= signature.size()) {
209         MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
210                                  << ", index " << i;
211       }
212       sig = signature[i].rw;
213     }
214     TypeId arg_type_id = kTypeUnknown;
215     if (v->isa<tensor::MetaTensor>()) {
216       const auto &arg = v->cast<tensor::MetaTensorPtr>();
217       arg_type_id = arg->data_type();
218     }
219     // Implicit cast
220     bool is_same_type = false;
221     if (arg_type_id != kTypeUnknown) {
222       is_same_type = (arg_type_id == it->second.first);
223     }
224     if (sig == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
225       prim::RaiseExceptionForConvertRefDtype(op_run_info->op_grad_info->op_prim, TypeIdToString(arg_type_id),
226                                              TypeIdToString(it->second.first), i);
227     }
228     if (is_same_type) {
229       MS_LOG(DEBUG) << "Get same dtype";
230       continue;
231     }
232 
233     if (IsValueTypeInvalid(v)) {
234       std::string type_str = v->type() == nullptr ? "None, value is \"" + v->ToString() + "\"" : v->type()->ToString();
235       MS_EXCEPTION(TypeError) << "For '" << op_run_info->op_grad_info->op_prim->name() << "', the " << (i + 1)
236                               << "th input " << signature[i].name << " can not be implicitly converted. "
237                               << "Its type is " << type_str << ". Only support Tensor or Scalar.";
238     }
239     MS_LOG(DEBUG) << "Implicit cast for " << op_run_info->base_op_run_info.op_name << " " << i << "th input, from type "
240                   << (v->type() == nullptr ? v->ToString() : v->type()->ToString()) << " to type "
241                   << TypeIdToType(it->second.first)->ToString();
242     input_args[i] = DoAutoCast(op_run_info, v, it->second, op_run_info->base_op_run_info.op_name, i);
243   }
244 }
245 
// Walks the op's non-initializer inputs and applies mixed-precision casting to
// tensors and (recursively) to tuple/list inputs, writing converted values
// back into op_grad_info->input_value. No-op when mix precision is disabled.
void CastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (op_run_info->async_status.disable_mix_precision) {
    // Pure function running, mix precision cast is disable, or cell not set mix precision
    MS_LOG(DEBUG) << "No mix precision for " << op_run_info->base_op_run_info.op_name;
    return;
  }
  MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
  const auto &signature = op_run_info->signatures;
  for (size_t i = 0; i < op_run_info->none_init_inputs_num; i++) {
    const auto &v = op_run_info->op_grad_info->input_value[i];
    // Default to kRWDefault when the op declares no signatures at all.
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Signature size is not equal to index, signature size " << signature.size()
                                 << ", index " << i;
      }
      sig = signature[i].rw;
    }
    // mix precision for non param
    bool is_cast = false;
    ValuePtr cast_output = nullptr;
    if (v->isa<tensor::MetaTensor>()) {
      auto meta_tensor = v->cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        // If parameter write(not kRWRead), no need cast
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      cast_output = DoParamMixPrecisionCast(op_run_info, &is_cast, v, op_run_info->op_grad_info->op_prim->name(), i);
    } else if (v->isa<ValueSequence>()) {
      // mix precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(op_run_info, &is_cast, v->cast<ValueSequencePtr>(),
                                                 op_run_info->op_grad_info->op_prim->name(), i);
    }
    // Only overwrite the input when a cast was actually performed.
    if (is_cast) {
      MS_EXCEPTION_IF_NULL(cast_output);
      op_run_info->op_grad_info->input_value[i] = cast_output;
    }
  }
}
288 
289 namespace {
GetTypeInfo(const FrontendOpRunInfoPtr & op_run_info)290 std::pair<std::vector<TypeId>, std::vector<bool>> GetTypeInfo(const FrontendOpRunInfoPtr &op_run_info) {
291   MS_EXCEPTION_IF_NULL(op_run_info);
292   std::vector<TypeId> args_type_id;
293   std::vector<bool> args_has_tensor;
294   args_type_id.resize(op_run_info->input_size);
295   args_has_tensor.resize(op_run_info->input_size, false);
296 
297   const auto &input_value = op_run_info->op_grad_info->input_value;
298   for (size_t i = 0; i < op_run_info->input_size; ++i) {
299     if (input_value[i]->isa<tensor::BaseTensor>()) {
300       args_type_id[i] = input_value[i]->cast<tensor::BaseTensorPtr>()->data_type();
301       if (op_run_info->source_type[i] == ops::OP_DTYPE::DT_BEGIN) {
302         args_has_tensor[i] = true;
303       }
304     } else if (input_value[i]->isa<Scalar>()) {
305       const auto type = input_value[i]->cast<ScalarPtr>()->type();
306       MS_EXCEPTION_IF_NULL(type);
307       args_type_id[i] = type->type_id();
308     } else {
309       MS_LOG(DEBUG) << "Get input value " << input_value[i]->ToString();
310       args_type_id[i] = kTypeUnknown;
311     }
312   }
313   return {args_type_id, args_has_tensor};
314 }
315 }  // namespace
316 
// Performs the signature-based implicit cast for `op_run_info`, caching each
// primitive's dtype-signature analysis in implicit_cast_map_ keyed by prim
// name so repeated ops skip the signature scan.
void CastOperation::SetImplicitCast(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &it = implicit_cast_map_.find(prim->name());
  if (it == implicit_cast_map_.end()) {
    // First time we see this primitive: analyse its signatures.
    std::vector<SignatureEnumDType> dtypes;
    bool has_dtype_sig = GetSignatureType(op_run_info->signatures, &dtypes);
    if (!has_dtype_sig) {
      // Cache the negative result so subsequent calls return early.
      PrimSignature sig_value{has_dtype_sig, {}};
      implicit_cast_map_[prim->name()] = sig_value;
      MS_LOG(DEBUG) << "Op " << prim->name() << " has no signature";
      return;
    }
    const auto &signature = op_run_info->signatures;
    auto sig_size = signature.size();
    // Ignore monad signature
    for (const auto &sig : signature) {
      if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
        --sig_size;
      }
    }
    if (sig_size > 0 && sig_size != op_run_info->none_init_inputs_num) {
      MS_EXCEPTION(ValueError) << op_run_info->base_op_run_info.op_name << " inputs number "
                               << op_run_info->none_init_inputs_num << " does not match the requires "
                               << "signature size " << sig_size;
    }

    // Compute destination dtypes from the current inputs and apply the cast.
    auto [args_type_id, args_has_tensor] = GetTypeInfo(op_run_info);
    auto dst_type = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor);
    DoSignatureCast(op_run_info, dst_type, dtypes);
    // Cache the analysed signature for subsequent executions of this prim.
    PrimSignature sig_value{has_dtype_sig, dtypes};
    implicit_cast_map_[prim->name()] = sig_value;
  } else {
    if (!it->second.has_dtype_sig) {
      MS_LOG(DEBUG) << op_run_info->base_op_run_info.op_name << " have no dtype sig";
      return;
    }
    MS_LOG(DEBUG) << "Do signature for " << op_run_info->base_op_run_info.op_name << " with cache";
    // Cached path: reuse the stored dtype groups, but recompute the dst types
    // because they depend on this call's concrete input dtypes.
    auto [args_type_id, args_has_tensor] = GetTypeInfo(op_run_info);
    auto dst_type = GetSignatureTypeMap(it->second.dtypes, args_type_id, args_has_tensor);
    DoSignatureCast(op_run_info, dst_type, it->second.dtypes);
  }
}
361 }  // namespace pynative
362 }  // namespace mindspore
363