/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/forward/do_cast.h"
#include <memory>
#include <utility>
#include <algorithm>
#include "mindspore/core/ops/array_ops.h"
#include "pipeline/pynative/pynative_utils.h"
#include "include/common/profiler.h"

namespace mindspore {
namespace pynative {
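// Entry point for all cast handling in the PyNative forward pass: first applies the
// mixed-precision cast configured for the cell/op, then the implicit type promotion
// required by the op signature.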
void CastOperation::DoCast(const FrontendOpRunInfoPtr &op_run_info) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeCast,
                                     op_run_info->base_op_run_info.op_name, true);
  // Mixed-precision conversion for tensors that carry a cast dtype
  SetTensorMixPrecisionCast(op_run_info);
  // Implicit type conversion required by the op signature
  SetImplicitCast(op_run_info);
}

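// Clears the per-primitive implicit-cast cache and the type-to-Cast-primitive cache.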
void CastOperation::ClearRes() {
  implicit_cast_map_.clear();
  type_prim_cache_.clear();
}

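// A value can be implicitly converted only if it is a tensor (dense or CSR) or a
// numeric/bool scalar; everything else is rejected by DoSignatureCast.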
bool CastOperation::IsValueTypeInvalid(const ValuePtr &v) const {
  MS_EXCEPTION_IF_NULL(v);
  return !v->isa<tensor::BaseTensor>() && !v->isa<tensor::CSRTensor>() && !v->isa<IntegerImm>() &&
         !v->isa<FloatImm>() && !v->isa<BoolImm>();
}

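// Casts a single value to `type_id`. Scalars are converted in place; a tensor that
// already has the destination dtype is returned as-is; otherwise a Cast op is
// dispatched through the frontend executor.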
ValuePtr CastOperation::DoNormalCast(const FrontendOpRunInfoPtr &cast_run_info, const ValuePtr &v,
                                     const TypeId &type_id) const {
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(cast_run_info);
  // Step 1: Cast a scalar value to another scalar value with the destination data type.
  // This avoids calling the cast infer-value function or launching a Cast op on the backend.
  ValuePtr dst_value = ScalarToDstDtypeValue(v, std::make_pair(type_id, true));
  if (dst_value != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << dst_value->ToString();
    cast_run_info->real_out = dst_value;
    return dst_value;
  }

  if (v->isa<tensor::BaseTensor>()) {
    auto tensor = v->cast<tensor::BaseTensorPtr>();
    if (type_id == tensor->data_type()) {
      cast_run_info->real_out = v;
      return cast_run_info->real_out;
    }
  }

  constexpr auto input_size = 2;
  cast_run_info->op_grad_info->op_prim = GetPrimByTypeId(type_id);
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(type_id));
  PyNativeAlgo::Common::GetConstInputToAttr(
    cast_run_info->op_grad_info->op_prim, cast_run_info->base_op_run_info.op_name,
    cast_run_info->base_op_run_info.device_target, false, &cast_run_info->input_to_attr);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(v);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(type_id64);
  cast_run_info->input_size = input_size;
  cast_run_info->op_grad_info->input_value_grad_type.resize(input_size);
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->RunOpFrontend(cast_run_info);
  return cast_run_info->real_out;
}

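// Casts one input of `op_name` (at position `index`) to `dst_type`. Scalars, and
// tensors that were converted from non-tensor Python inputs (source_type is not
// DT_BEGIN), are converted directly; other tensors go through a real Cast op so the
// conversion is recorded for autograd.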
ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
                                   const std::pair<TypeId, bool> &dst_type, const std::string &op_name,
                                   size_t index) const {
  MS_EXCEPTION_IF_NULL(v);
  // Step 1: Cast a scalar value to another scalar value with the destination data type.
  // This avoids calling the cast infer-value function or launching a Cast op on the backend.
  ValuePtr dst_value = ScalarToDstDtypeValue(v, dst_type);
  if (dst_value != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << dst_value->ToString();
    return dst_value;
  }
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (op_run_info->source_type[index] != ops::OP_DTYPE::DT_BEGIN && v->isa<tensor::BaseTensor>()) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString();
    dst_value = TensorToDstDtypeValue(v, dst_type.first);
    MS_LOG(DEBUG) << "Cast to value: " << dst_value->ToString() << " without dispatching a cast op";
    return dst_value;
  }
  // When step 1 does not apply, create a Cast op to produce a value with the destination data type.
  constexpr auto input_size = 2;
  const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
  auto cast_prim = GetPrimByTypeId(dst_type.first);
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type.first));
  cast_run_info->requires_grad = op_run_info->requires_grad;
  cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();
  cast_run_info->base_op_run_info.is_mixed_precision_cast = true;
  cast_run_info->base_op_run_info.next_op_name = op_name;
  cast_run_info->base_op_run_info.next_input_index = index;
  cast_run_info->base_op_run_info.use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
  cast_run_info->cell_obj_id = op_run_info->cell_obj_id;
  cast_run_info->base_op_run_info.device_target =
    PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->GetCurrentDeviceTarget(cast_prim);
  bool is_dynamic_shape =
    cast_run_info->base_op_run_info.has_dynamic_output || cast_run_info->base_op_run_info.use_dynamic_shape_process;
  PyNativeAlgo::Common::GetConstInputToAttr(cast_prim, cast_run_info->base_op_run_info.op_name,
                                            cast_run_info->base_op_run_info.device_target, is_dynamic_shape,
                                            &cast_run_info->input_to_attr);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(v);
  (void)cast_run_info->op_grad_info->input_value.emplace_back(type_id64);
  cast_run_info->input_size = input_size;
  cast_run_info->op_grad_info->input_value_grad_type.resize(input_size);
  cast_run_info->op_grad_info->op_prim = cast_prim;
  PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->RunOpFrontend(cast_run_info);
  return cast_run_info->real_out;
}

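// Applies the cell's mixed-precision policy (fp16/fp32/bf16) to a single tensor
// input. Only float-family tensors whose dtype differs from the target are cast;
// `*is_cast` reports whether a conversion actually happened.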
ValuePtr CastOperation::DoParamMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, bool *is_cast,
                                                const ValuePtr &v, const std::string &op_name, size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(is_cast);
  MS_EXCEPTION_IF_NULL(v);
  if (op_run_info->mix_type != kNotSet) {
    auto dst_dtype = kFloat16;
    if (op_run_info->mix_type == kFP32) {
      dst_dtype = kFloat32;
    } else if (op_run_info->mix_type == kBF16) {
      dst_dtype = kBFloat16;
    }
    const auto &tensor = v->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    auto source_dtype = tensor->Dtype();
    if (source_dtype != nullptr && (IsSubType(source_dtype, kFloat) || IsSubType(source_dtype, kBFloat)) &&
        *source_dtype != *dst_dtype) {
      MS_LOG(DEBUG) << "MixPrecision cast for " << op_run_info->base_op_run_info.op_name << " " << index
                    << "th input to type " << dst_dtype->ToString();
      *is_cast = true;
      return DoAutoCast(op_run_info, tensor, std::make_pair(dst_dtype->type_id(), true), op_name, index);
    }
  }
  return v;
}

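// Recursively applies the mixed-precision cast to every tensor inside a tuple or
// list input, preserving the original sequence kind in the result.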
ValuePtr CastOperation::DoParamMixPrecisionCastTuple(const FrontendOpRunInfoPtr &op_run_info, bool *is_cast,
                                                     const ValueSequencePtr &value_seq, const std::string &op_name,
                                                     size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(is_cast);
  MS_EXCEPTION_IF_NULL(value_seq);
  size_t tuple_size = value_seq->size();
  const auto &value_tuple = value_seq->value();
  ValuePtrList result(tuple_size, nullptr);
  for (size_t i = 0; i < tuple_size; i++) {
    if (value_tuple[i]->isa<tensor::MetaTensor>()) {
      MS_LOG(DEBUG) << "Call cast for item " << i;
      result[i] = DoParamMixPrecisionCast(op_run_info, is_cast, value_tuple[i], op_name, index);
    } else if (value_tuple[i]->isa<ValueSequence>()) {
      result[i] =
        DoParamMixPrecisionCastTuple(op_run_info, is_cast, value_tuple[i]->cast<ValueSequencePtr>(), op_name, index);
    } else {
      result[i] = value_tuple[i];
    }
  }
  if (value_seq->isa<ValueList>()) {
    return std::make_shared<ValueList>(result);
  } else {
    return std::make_shared<ValueTuple>(result);
  }
}

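// Casts each input to the destination dtype resolved for its signature group.
// Inputs whose dtype already matches are skipped; writable (kRWWrite) inputs must
// already match and raise instead of being converted; non-tensor, non-scalar values
// are rejected with a TypeError.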
void CastOperation::DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info,
                                    const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type,
                                    const std::vector<SignatureEnumDType> &dtypes) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
  const auto &signature = op_run_info->signatures;
  auto &input_args = op_run_info->op_grad_info->input_value;
  size_t input_args_size = input_args.size();
  if (dtypes.size() > input_args_size) {
    MS_LOG(EXCEPTION) << "Signature dtypes size[" << dtypes.size() << "] is greater than input_args_size["
                      << input_args_size << "].";
  }
  for (size_t i = 0; i < dtypes.size(); ++i) {
    // No need to implicitly cast if there is no dtype.
    if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
      MS_LOG(DEBUG) << "Get kDTypeEmptyDefaultValue";
      continue;
    }
    auto it = dst_type.find(dtypes[i]);
    if (it == dst_type.end()) {
      MS_LOG(DEBUG) << "Can not find dst dtype for input " << i;
      continue;
    }
    if (it->second.first == kTypeUnknown) {
      MS_LOG(DEBUG) << "Dst dtype of input " << i << " is unknown";
      continue;
    }
    const auto &v = input_args[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Index " << i << " exceeds signature size " << signature.size();
      }
      sig = signature[i].rw;
    }
    TypeId arg_type_id = kTypeUnknown;
    if (v->isa<tensor::MetaTensor>()) {
      const auto &arg = v->cast<tensor::MetaTensorPtr>();
      arg_type_id = arg->data_type();
    }
    // Implicit cast
    bool is_same_type = false;
    if (arg_type_id != kTypeUnknown) {
      is_same_type = (arg_type_id == it->second.first);
    }
    if (sig == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
      prim::RaiseExceptionForConvertRefDtype(op_run_info->op_grad_info->op_prim, TypeIdToString(arg_type_id),
                                             TypeIdToString(it->second.first), i);
    }
    if (is_same_type) {
      MS_LOG(DEBUG) << "Get same dtype";
      continue;
    }

    if (IsValueTypeInvalid(v)) {
      std::string type_str = v->type() == nullptr ? "None, value is \"" + v->ToString() + "\"" : v->type()->ToString();
      MS_EXCEPTION(TypeError) << "For '" << op_run_info->op_grad_info->op_prim->name() << "', the " << (i + 1)
                              << "th input " << signature[i].name << " cannot be implicitly converted. "
                              << "Its type is " << type_str << ". Only Tensor or Scalar is supported.";
    }
    MS_LOG(DEBUG) << "Implicit cast for " << op_run_info->base_op_run_info.op_name << " " << i << "th input, from type "
                  << (v->type() == nullptr ? v->ToString() : v->type()->ToString()) << " to type "
                  << TypeIdToType(it->second.first)->ToString();
    input_args[i] = DoAutoCast(op_run_info, v, it->second, op_run_info->base_op_run_info.op_name, i);
  }
}

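// Walks the first `none_init_inputs_num` inputs and applies the cell's
// mixed-precision policy. Parameters that the op writes to are left untouched;
// tensor inputs and (nested) tuple/list inputs are cast via the helpers above.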
void CastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (op_run_info->async_status.disable_mix_precision) {
    // Pure function running, mixed-precision cast is disabled, or the cell did not set mixed precision
    MS_LOG(DEBUG) << "No mix precision for " << op_run_info->base_op_run_info.op_name;
    return;
  }
  MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info->op_prim);
  const auto &signature = op_run_info->signatures;
  for (size_t i = 0; i < op_run_info->none_init_inputs_num; i++) {
    const auto &v = op_run_info->op_grad_info->input_value[i];
    auto sig = SignatureEnumRW::kRWDefault;
    if (!signature.empty()) {
      if (i >= signature.size()) {
        MS_EXCEPTION(ValueError) << "Index " << i << " exceeds signature size " << signature.size();
      }
      sig = signature[i].rw;
    }
    // Mixed precision for non-parameter inputs
    bool is_cast = false;
    ValuePtr cast_output = nullptr;
    if (v->isa<tensor::MetaTensor>()) {
      auto meta_tensor = v->cast<tensor::MetaTensorPtr>();
      if (meta_tensor && meta_tensor->is_parameter()) {
        // If the parameter is written (not kRWRead), no cast is needed
        if (sig != SignatureEnumRW::kRWRead) {
          continue;
        }
      }
      cast_output = DoParamMixPrecisionCast(op_run_info, &is_cast, v, op_run_info->op_grad_info->op_prim->name(), i);
    } else if (v->isa<ValueSequence>()) {
      // Mixed precision for tuple inputs
      cast_output = DoParamMixPrecisionCastTuple(op_run_info, &is_cast, v->cast<ValueSequencePtr>(),
                                                 op_run_info->op_grad_info->op_prim->name(), i);
    }
    if (is_cast) {
      MS_EXCEPTION_IF_NULL(cast_output);
      op_run_info->op_grad_info->input_value[i] = cast_output;
    }
  }
}

namespace {
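// Collects, for every input, its TypeId (kTypeUnknown for non-tensor, non-scalar
// values) and whether it arrived as a real tensor (source_type is DT_BEGIN, i.e. no
// frontend type conversion), for use in signature type resolution.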
std::pair<std::vector<TypeId>, std::vector<bool>> GetTypeInfo(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  std::vector<TypeId> args_type_id;
  std::vector<bool> args_has_tensor;
  args_type_id.resize(op_run_info->input_size);
  args_has_tensor.resize(op_run_info->input_size, false);

  const auto &input_value = op_run_info->op_grad_info->input_value;
  for (size_t i = 0; i < op_run_info->input_size; ++i) {
    if (input_value[i]->isa<tensor::BaseTensor>()) {
      args_type_id[i] = input_value[i]->cast<tensor::BaseTensorPtr>()->data_type();
      if (op_run_info->source_type[i] == ops::OP_DTYPE::DT_BEGIN) {
        args_has_tensor[i] = true;
      }
    } else if (input_value[i]->isa<Scalar>()) {
      const auto type = input_value[i]->cast<ScalarPtr>()->type();
      MS_EXCEPTION_IF_NULL(type);
      args_type_id[i] = type->type_id();
    } else {
      MS_LOG(DEBUG) << "Get input value " << input_value[i]->ToString();
      args_type_id[i] = kTypeUnknown;
    }
  }
  return {args_type_id, args_has_tensor};
}
}  // namespace

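// Resolves the op's dtype signature (cached per primitive name in
// implicit_cast_map_) and performs the implicit casts it requires. Ops without a
// dtype signature are recorded so later calls return early.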
void CastOperation::SetImplicitCast(const FrontendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &prim = op_run_info->op_grad_info->op_prim;
  MS_EXCEPTION_IF_NULL(prim);
  const auto &it = implicit_cast_map_.find(prim->name());
  if (it == implicit_cast_map_.end()) {
    std::vector<SignatureEnumDType> dtypes;
    bool has_dtype_sig = GetSignatureType(op_run_info->signatures, &dtypes);
    if (!has_dtype_sig) {
      PrimSignature sig_value{has_dtype_sig, {}};
      implicit_cast_map_[prim->name()] = sig_value;
      MS_LOG(DEBUG) << "Op " << prim->name() << " has no signature";
      return;
    }
    const auto &signature = op_run_info->signatures;
    auto sig_size = signature.size();
    // Ignore monad signatures
    for (const auto &sig : signature) {
      if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
        --sig_size;
      }
    }
    if (sig_size > 0 && sig_size != op_run_info->none_init_inputs_num) {
      MS_EXCEPTION(ValueError) << op_run_info->base_op_run_info.op_name << " inputs number "
                               << op_run_info->none_init_inputs_num << " does not match the required "
                               << "signature size " << sig_size;
    }

    auto [args_type_id, args_has_tensor] = GetTypeInfo(op_run_info);
    auto dst_type = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor);
    DoSignatureCast(op_run_info, dst_type, dtypes);
    PrimSignature sig_value{has_dtype_sig, dtypes};
    implicit_cast_map_[prim->name()] = sig_value;
  } else {
    if (!it->second.has_dtype_sig) {
      MS_LOG(DEBUG) << op_run_info->base_op_run_info.op_name << " has no dtype sig";
      return;
    }
    MS_LOG(DEBUG) << "Do signature for " << op_run_info->base_op_run_info.op_name << " with cache";
    auto [args_type_id, args_has_tensor] = GetTypeInfo(op_run_info);
    auto dst_type = GetSignatureTypeMap(it->second.dtypes, args_type_id, args_has_tensor);
    DoSignatureCast(op_run_info, dst_type, it->second.dtypes);
  }
}
}  // namespace pynative
}  // namespace mindspore