/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/forward/do_pyboost_cast.h"
#include "pipeline/pynative/pynative_utils.h"
#include "kernel/pyboost/auto_generate/cast.h"
#include "include/common/utils/stub_tensor.h"

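// Implicit-cast helpers for the PyBoost path of the PyNative forward executor: scalar/tensor dtype
// casting and mixed-precision casting of operator inputs.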
namespace mindspore {
namespace pynative {
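// Cast a generic input value to the destination dtype: scalars are converted directly, tensors are
// delegated to the tensor overload below, and any other value is returned unchanged.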
ValuePtr PyBoostCastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info,
                                          const std::pair<TypeId, bool> &dst_type, size_t index,
                                          const ValuePtr &v) const {
  MS_EXCEPTION_IF_NULL(v);
  ValuePtr dst_value = ScalarToDstDtypeValue(v, dst_type);
  if (dst_value != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << dst_value->ToString();
    return dst_value;
  }
  if (!v->isa<tensor::BaseTensor>()) {
    return v;
  }
  return DoAutoCast(op_run_info, dst_type, index, v->cast<tensor::BaseTensorPtr>());
}

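// Cast a tensor to the destination dtype. Inputs with a recorded source_type (not DT_BEGIN) are
// converted via TensorToDstDtypeValue without dispatching a Cast op; otherwise the PyBoost Cast op
// is called, and the cast is recorded for grad when requires_grad is set.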
tensor::BaseTensorPtr PyBoostCastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info,
                                                       const std::pair<TypeId, bool> &dst_type, size_t index,
                                                       const tensor::BaseTensorPtr &t) const {
  if (op_run_info->source_type[index] != ops::OP_DTYPE::DT_BEGIN) {
    MS_LOG(DEBUG) << "Try to cast source tensor: " << t->ToString();
    auto dst_tensor = TensorToDstDtypeValue(t, dst_type.first);
    MS_LOG(DEBUG) << "Cast to dst tensor: " << dst_tensor->ToString() << " without dispatching cast op";
    return dst_tensor;
  }
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type.first));
  const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
  auto cast_prim = GetPrimByTypeId(dst_type.first);
  // Use pyboost op call
  cast_run_info->base_op_run_info.device_target =
    PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->GetCurrentDeviceTarget(cast_prim);
  auto cast_op = CREATE_PYBOOST_OP(Cast, cast_run_info->base_op_run_info.device_target);
  (void)cast_op->Call(t, type_id64);
  cast_run_info->requires_grad = op_run_info->requires_grad;
  PyNativeAlgo::PyBoost::UpdateOpRunInfo(cast_op, cast_run_info);
  if (op_run_info->requires_grad) {
    constexpr auto input_size = 2;
    cast_run_info->input_size = input_size;
    cast_run_info->base_op_run_info.op_name = kCast;
    cast_run_info->op_grad_info->op_prim = cast_prim;
    PyNativeAlgo::PyBoost::DoGrad(cast_op, cast_run_info, {t, type_id64});
  }
  return cast_run_info->real_out->cast<tensor::BaseTensorPtr>();
}

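// Apply the mixed-precision cast to a generic value: only tensors are cast, other values pass through.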
ValuePtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
                                                         size_t index) const {
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<tensor::BaseTensor>()) {
    return SetTensorMixPrecisionCast(op_run_info, v->cast<tensor::BaseTensorPtr>(), index);
  }
  return v;
}

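// Apply the mixed-precision cast to a tensor. The target dtype follows mix_type (float16 by default,
// float32 for kFP32, bfloat16 for kBF16); only float/bfloat tensors whose dtype differs from the
// target are actually cast.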
tensor::BaseTensorPtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                                      const tensor::BaseTensorPtr &t,
                                                                      size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(t);
  if (op_run_info->mix_type != kNotSet) {
    auto dst_dtype = kFloat16;
    if (op_run_info->mix_type == kFP32) {
      dst_dtype = kFloat32;
    } else if (op_run_info->mix_type == kBF16) {
      dst_dtype = kBFloat16;
    }

    auto source_dtype = t->Dtype();
    if (source_dtype != nullptr && (IsSubType(source_dtype, kFloat) || IsSubType(source_dtype, kBFloat)) &&
        *source_dtype != *dst_dtype) {
      MS_LOG(DEBUG) << "MixPrecision cast for " << op_run_info->base_op_run_info.op_name << " " << index
                    << "th input to type " << dst_dtype->ToString();
      auto cast_t = DoAutoCast(op_run_info, std::make_pair(dst_dtype->type_id(), true), index, t);
      return cast_t;
    }
  }
  return t;
}

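// Optional-tensor overload: an empty optional passes through as std::nullopt.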
std::optional<tensor::BaseTensorPtr> PyBoostCastOperation::SetTensorMixPrecisionCast(
  const FrontendOpRunInfoPtr &op_run_info, const std::optional<tensor::BaseTensorPtr> &t, size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  if (!t.has_value()) {
    return std::nullopt;
  }
  return std::make_optional(SetTensorMixPrecisionCast(op_run_info, t.value(), index));
}

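// Apply the mixed-precision cast element-wise to a ValueTuple.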
ValueTuplePtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                              const ValueTuplePtr &v_tuple, size_t index) const {
  return std::make_shared<ValueTuple>(SetSeqMixPrecisionCast(op_run_info, v_tuple, index));
}

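// Apply the mixed-precision cast element-wise to a ValueList.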
ValueListPtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                             const ValueListPtr &v_list, size_t index) const {
  return std::make_shared<ValueList>(SetSeqMixPrecisionCast(op_run_info, v_list, index));
}

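// Walk a value sequence: cast tensor elements, recurse into nested tuples/lists, and keep
// everything else as-is.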
ValuePtrList PyBoostCastOperation::SetSeqMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                          const ValueSequencePtr &v_seq, size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(v_seq);
  size_t tuple_size = v_seq->size();
  const auto &value_tuple = v_seq->value();
  ValuePtrList result(tuple_size, nullptr);
  for (size_t i = 0; i < tuple_size; i++) {
    if (value_tuple[i]->isa<tensor::MetaTensor>()) {
      MS_LOG(DEBUG) << "Call cast for " << i << "th input";
      result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i], index);
    } else if (value_tuple[i]->isa<ValueTuple>()) {
      result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i]->cast<ValueTuplePtr>(), index);
    } else if (value_tuple[i]->isa<ValueList>()) {
      result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i]->cast<ValueListPtr>(), index);
    } else {
      result[i] = value_tuple[i];
    }
  }
  return result;
}
}  // namespace pynative
}  // namespace mindspore