• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/forward/do_pyboost_cast.h"
18 #include "pipeline/pynative/pynative_utils.h"
19 #include "kernel/pyboost/auto_generate/cast.h"
20 #include "include/common/utils/stub_tensor.h"
21 
22 namespace mindspore {
23 namespace pynative {
// Cast a generic value to the destination dtype. Scalars are converted in place
// without dispatching an op; tensors are delegated to the tensor overload; any
// other value kind passes through unchanged.
ValuePtr PyBoostCastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info,
                                          const std::pair<TypeId, bool> &dst_type, size_t index,
                                          const ValuePtr &v) const {
  MS_EXCEPTION_IF_NULL(v);
  // Fast path: a plain scalar can be converted directly.
  if (const auto &scalar_cast = ScalarToDstDtypeValue(v, dst_type); scalar_cast != nullptr) {
    MS_LOG(DEBUG) << "Source value: " << v->ToString() << " cast to value: " << scalar_cast->ToString();
    return scalar_cast;
  }
  if (v->isa<tensor::BaseTensor>()) {
    // Tensor inputs go through the tensor-specific cast path.
    return DoAutoCast(op_run_info, dst_type, index, v->cast<tensor::BaseTensorPtr>());
  }
  // Neither a castable scalar nor a tensor: return as-is.
  return v;
}
38 
// Cast a tensor to the destination dtype.
// If the input's source type was recorded (not DT_BEGIN), the tensor's dtype is
// rewritten directly without dispatching a cast op; otherwise a pyboost Cast op
// is dispatched (and grad is recorded when requires_grad is set).
// Fix: added null guards — op_run_info is indexed, and t / dst_tensor /
// real_out were dereferenced without checks.
tensor::BaseTensorPtr PyBoostCastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info,
                                                       const std::pair<TypeId, bool> &dst_type, size_t index,
                                                       const tensor::BaseTensorPtr &t) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(t);
  if (op_run_info->source_type[index] != ops::OP_DTYPE::DT_BEGIN) {
    MS_LOG(DEBUG) << "Try cast Source tensor: " << t->ToString();
    auto dst_tensor = TensorToDstDtypeValue(t, dst_type.first);
    // Guard before dereferencing in the debug log below.
    MS_EXCEPTION_IF_NULL(dst_tensor);
    MS_LOG(DEBUG) << "Cast to dst tensor: " << dst_tensor->ToString() << " without dispatching cast op";
    return dst_tensor;
  }
  auto type_id64 = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type.first));
  const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
  auto cast_prim = GetPrimByTypeId(dst_type.first);
  // Use pyboost op call
  cast_run_info->base_op_run_info.device_target =
    PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->GetCurrentDeviceTarget(cast_prim);
  auto cast_op = CREATE_PYBOOST_OP(Cast, cast_run_info->base_op_run_info.device_target);
  (void)cast_op->Call(t, type_id64);
  cast_run_info->requires_grad = op_run_info->requires_grad;
  PyNativeAlgo::PyBoost::UpdateOpRunInfo(cast_op, cast_run_info);
  if (op_run_info->requires_grad) {
    // Cast takes two inputs: the tensor and the target type id.
    constexpr auto input_size = 2;
    cast_run_info->input_size = input_size;
    cast_run_info->base_op_run_info.op_name = kCast;
    cast_run_info->op_grad_info->op_prim = cast_prim;
    PyNativeAlgo::PyBoost::DoGrad(cast_op, cast_run_info, {t, type_id64});
  }
  MS_EXCEPTION_IF_NULL(cast_run_info->real_out);
  return cast_run_info->real_out->cast<tensor::BaseTensorPtr>();
}
67 
// Mix-precision entry point for a generic value: only tensors are candidates
// for casting; every other value kind is returned untouched.
ValuePtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v,
                                                         size_t index) const {
  MS_EXCEPTION_IF_NULL(v);
  if (!v->isa<tensor::BaseTensor>()) {
    return v;
  }
  return SetTensorMixPrecisionCast(op_run_info, v->cast<tensor::BaseTensorPtr>(), index);
}
76 
// Apply mix-precision casting to a single tensor input. The tensor is cast only
// when a mix-precision mode is configured, its dtype belongs to the float/bfloat
// family, and it differs from the target dtype; otherwise it is returned as-is.
tensor::BaseTensorPtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                                      const tensor::BaseTensorPtr &t,
                                                                      size_t index) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(t);
  if (op_run_info->mix_type == kNotSet) {
    return t;
  }
  // Select the target dtype from the configured mode; fp16 is the default.
  auto target_dtype = (op_run_info->mix_type == kFP32)   ? kFloat32
                      : (op_run_info->mix_type == kBF16) ? kBFloat16
                                                         : kFloat16;
  auto src_dtype = t->Dtype();
  const bool float_family =
    src_dtype != nullptr && (IsSubType(src_dtype, kFloat) || IsSubType(src_dtype, kBFloat));
  // Skip non-float tensors and tensors already at the target dtype.
  if (!float_family || *src_dtype == *target_dtype) {
    return t;
  }
  MS_LOG(DEBUG) << "MixPrecision cast for " << op_run_info->base_op_run_info.op_name << " " << index
                << "th input, and to type " << target_dtype->ToString();
  return DoAutoCast(op_run_info, std::make_pair(target_dtype->type_id(), true), index, t);
}
101 
SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const std::optional<tensor::BaseTensorPtr> & t,size_t index) const102 std::optional<tensor::BaseTensorPtr> PyBoostCastOperation::SetTensorMixPrecisionCast(
103   const FrontendOpRunInfoPtr &op_run_info, const std::optional<tensor::BaseTensorPtr> &t, size_t index) const {
104   MS_EXCEPTION_IF_NULL(op_run_info);
105   if (!t.has_value()) {
106     return std::nullopt;
107   }
108   return std::make_optional(SetTensorMixPrecisionCast(op_run_info, t.value(), index));
109 }
110 
// Tuple variant: cast every element of the sequence, then repack the results
// as a new ValueTuple.
ValueTuplePtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                              const ValueTuplePtr &v_tuple, size_t index) const {
  auto casted_elements = SetSeqMixPrecisionCast(op_run_info, v_tuple, index);
  return std::make_shared<ValueTuple>(std::move(casted_elements));
}
115 
// List variant: cast every element of the sequence, then repack the results
// as a new ValueList.
ValueListPtr PyBoostCastOperation::SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
                                                             const ValueListPtr &v_list, size_t index) const {
  auto casted_elements = SetSeqMixPrecisionCast(op_run_info, v_list, index);
  return std::make_shared<ValueList>(std::move(casted_elements));
}
120 
SetSeqMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const ValueSequencePtr & v_seq,size_t index) const121 ValuePtrList PyBoostCastOperation::SetSeqMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
122                                                           const ValueSequencePtr &v_seq, size_t index) const {
123   MS_EXCEPTION_IF_NULL(op_run_info);
124   MS_EXCEPTION_IF_NULL(v_seq);
125   size_t tuple_size = v_seq->size();
126   const auto &value_tuple = v_seq->value();
127   ValuePtrList result(tuple_size, nullptr);
128   for (size_t i = 0; i < tuple_size; i++) {
129     if (value_tuple[i]->isa<tensor::MetaTensor>()) {
130       MS_LOG(DEBUG) << "Call cast for " << i << "th input";
131       result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i], index);
132     } else if (value_tuple[i]->isa<ValueTuple>()) {
133       result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i]->cast<ValueTuplePtr>(), index);
134     } else if (value_tuple[i]->isa<ValueList>()) {
135       result[i] = SetTensorMixPrecisionCast(op_run_info, value_tuple[i]->cast<ValueListPtr>(), index);
136     } else {
137       result[i] = value_tuple[i];
138     }
139   }
140   return result;
141 }
142 }  // namespace pynative
143 }  // namespace mindspore
144