1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_ 19 20 #include <vector> 21 #include <string> 22 #include <memory> 23 #include <tuple> 24 #include <map> 25 #include <utility> 26 #include "pipeline/pynative/forward/cast_base.h" 27 #include "frontend/operator/composite/do_signature.h" 28 #include "include/common/utils/convert_utils.h" 29 #include "ir/cell.h" 30 31 namespace mindspore { 32 namespace pynative { 33 static constexpr auto kCast = "Cast"; 34 35 class PyBoostCastOperation : public CastBaseOperation { 36 public: 37 PyBoostCastOperation() = default; 38 ~PyBoostCastOperation() = default; 39 40 template <typename... InputArgs, std::size_t... Index> SetTensorMixPrecisionCastHelper(const FrontendOpRunInfoPtr & op_run_info,std::index_sequence<Index...>,const InputArgs &...input_args)41 auto SetTensorMixPrecisionCastHelper(const FrontendOpRunInfoPtr &op_run_info, std::index_sequence<Index...>, 42 const InputArgs &... input_args) { 43 return std::make_tuple(SetTensorMixPrecisionCast(op_run_info, input_args, Index)...); 44 } 45 46 template <typename... InputArgs> DoMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const InputArgs &...input_args)47 auto DoMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const InputArgs &... input_args) { 48 // Mixed precision conversion tensors which has cast dtype 49 if (op_run_info->async_status.disable_mix_precision) { 50 return std::make_tuple(input_args...); 51 } 52 return SetTensorMixPrecisionCastHelper(op_run_info, std::make_index_sequence<sizeof...(InputArgs)>(), 53 input_args...); 54 } 55 56 template <typename T> SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const T & t,size_t index)57 T SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const T &t, size_t index) const { 58 MS_EXCEPTION_IF_NULL(t); 59 MS_LOG(DEBUG) << "Get input type " << typeid(t).name(); 60 return t; 61 } 62 63 template <typename Item> GetTypeIdInfo(const FrontendOpRunInfoPtr & op_run_info,std::vector<TypeId> * args_type_id,std::vector<bool> * args_has_tensor,size_t i,const Item & v)64 void GetTypeIdInfo(const FrontendOpRunInfoPtr &op_run_info, std::vector<TypeId> *args_type_id, 65 std::vector<bool> *args_has_tensor, size_t i, const Item &v) { 66 MS_EXCEPTION_IF_NULL(v); 67 if (v->template isa<tensor::BaseTensor>()) { 68 (*args_type_id)[i] = v->template cast<tensor::BaseTensorPtr>()->data_type(); 69 // Indicate have do type cast 70 if (op_run_info->source_type[i] == ops::OP_DTYPE::DT_BEGIN) { 71 (*args_has_tensor)[i] = true; 72 } 73 } else if (v->template isa<Scalar>()) { 74 const auto type = v->template cast<ScalarPtr>()->type(); 75 MS_EXCEPTION_IF_NULL(type); 76 (*args_type_id)[i] = type->type_id(); 77 } else { 78 MS_LOG(DEBUG) << "Get value " << v->ToString(); 79 } 80 } 81 82 template <typename Item> GetTypeIdInfo(const FrontendOpRunInfoPtr & op_run_info,std::vector<TypeId> * args_type_id,std::vector<bool> * args_has_tensor,size_t i,const std::optional<Item> & t)83 void GetTypeIdInfo(const FrontendOpRunInfoPtr &op_run_info, std::vector<TypeId> *args_type_id, 84 std::vector<bool> *args_has_tensor, size_t i, const std::optional<Item> &t) { 85 if (!t.has_value()) { 86 return; 87 } 88 } 89 90 template <typename TupleInput, size_t... Index> GetTypeInfo(const FrontendOpRunInfoPtr & op_run_info,const TupleInput & tuple_input,std::index_sequence<Index...>)91 std::pair<std::vector<TypeId>, std::vector<bool>> GetTypeInfo(const FrontendOpRunInfoPtr &op_run_info, 92 const TupleInput &tuple_input, 93 std::index_sequence<Index...>) { 94 std::vector<TypeId> args_type_id; 95 std::vector<bool> args_has_tensor; 96 args_type_id.resize(op_run_info->input_size, kTypeUnknown); 97 args_has_tensor.resize(op_run_info->input_size, false); 98 99 (GetTypeIdInfo(op_run_info, &args_type_id, &args_has_tensor, Index, std::get<Index>(tuple_input)), ...); 100 return {args_type_id, args_has_tensor}; 101 } 102 103 // Implicit transform 104 template <size_t N, typename... InputArgs> DoImplicitCast(const FrontendOpRunInfoPtr & op_run_info,const std::vector<std::vector<size_t>> & same_type_table,const std::tuple<InputArgs...> & input_args)105 auto DoImplicitCast(const FrontendOpRunInfoPtr &op_run_info, const std::vector<std::vector<size_t>> &same_type_table, 106 const std::tuple<InputArgs...> &input_args) { 107 MS_EXCEPTION_IF_NULL(op_run_info); 108 MS_LOG(DEBUG) << "Get signature " << same_type_table; 109 const auto &it = implicit_cast_map_.find(op_run_info->base_op_run_info.op_name); 110 if (it == implicit_cast_map_.end()) { 111 std::vector<SignatureEnumDType> dtypes; 112 // Get current inputs signatures 113 bool has_dtype_sig = GetSignatureType(op_run_info->signatures, &dtypes); 114 if (dtypes.size() > op_run_info->input_size) { 115 MS_LOG(EXCEPTION) << "Signature dtypes size[" << dtypes << "] is greater than input_args_size[" 116 << op_run_info->input_size << "]."; 117 } 118 if (!has_dtype_sig) { 119 PrimSignature sig_value{has_dtype_sig, {}}; 120 implicit_cast_map_[op_run_info->base_op_run_info.op_name] = sig_value; 121 MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name << " has no signature"; 122 return input_args; 123 } 124 PrimSignature sig_value{has_dtype_sig, dtypes}; 125 implicit_cast_map_[op_run_info->base_op_run_info.op_name] = sig_value; 126 127 auto [args_type_id, args_has_tensor] = 128 GetTypeInfo(op_run_info, input_args, std::make_index_sequence<sizeof...(InputArgs)>{}); 129 auto dst_type = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor); 130 return SetImplicitCast(op_run_info, dst_type, dtypes, input_args, 131 std::make_index_sequence<sizeof...(InputArgs)>{}); 132 } else { 133 if (!it->second.has_dtype_sig) { 134 MS_LOG(DEBUG) << op_run_info->base_op_run_info.op_name << " have no dtype sig"; 135 return input_args; 136 } 137 MS_LOG(DEBUG) << "Do signature for " << op_run_info->base_op_run_info.op_name << " with cache"; 138 auto [args_type_id, args_has_tensor] = 139 GetTypeInfo(op_run_info, input_args, std::make_index_sequence<sizeof...(InputArgs)>{}); 140 auto dst_type = GetSignatureTypeMap(it->second.dtypes, args_type_id, args_has_tensor); 141 return SetImplicitCast(op_run_info, dst_type, it->second.dtypes, input_args, 142 std::make_index_sequence<sizeof...(InputArgs)>{}); 143 } 144 } 145 146 private: 147 template <typename TupleInput, size_t... N> SetImplicitCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,const TupleInput & input_args,std::index_sequence<N...>)148 auto SetImplicitCast(const FrontendOpRunInfoPtr &op_run_info, 149 const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type, 150 const std::vector<SignatureEnumDType> &dtypes, const TupleInput &input_args, 151 std::index_sequence<N...>) const { 152 MS_EXCEPTION_IF_NULL(op_run_info); 153 return std::make_tuple(DoSignatureCast(op_run_info, dst_type, dtypes, N, std::get<N>(input_args))...); 154 } 155 156 template <typename Item> DoSignatureCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,size_t index,const Item & t)157 Item DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info, 158 const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type, 159 const std::vector<SignatureEnumDType> &dtypes, size_t index, const Item &t) const { 160 // No need to implicit cast if no dtype. 161 const auto &signature = op_run_info->signatures; 162 if (dtypes.empty() || index >= dtypes.size() || dtypes[index] == SignatureEnumDType::kDTypeEmptyDefaultValue) { 163 MS_LOG(DEBUG) << "Get kDTypeEmptyDefaultValue, or index " << index << " larger than dtype size " << dtypes.size(); 164 return t; 165 } 166 auto it = dst_type.find(dtypes[index]); 167 if (it == dst_type.end() || it->second.first == kTypeUnknown) { 168 MS_LOG(DEBUG) << "Can not find dtype " << (it == dst_type.end()) << ", or type is unknown " 169 << (it->second.first == kTypeUnknown); 170 return t; 171 } 172 173 TypeId arg_type_id = kTypeUnknown; 174 if (t->template isa<tensor::BaseTensor>()) { 175 const auto &arg = t->template cast<tensor::BaseTensorPtr>(); 176 arg_type_id = arg->data_type(); 177 } 178 // Implicit cast 179 bool is_same_type = false; 180 if (arg_type_id != kTypeUnknown) { 181 is_same_type = (arg_type_id == it->second.first); 182 } 183 if (signature[index].rw == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) { 184 prim::RaiseExceptionForConvertRefDtype(op_run_info->op_grad_info->op_prim, TypeIdToString(arg_type_id), 185 TypeIdToString(it->second.first), index); 186 } 187 if (is_same_type) { 188 MS_LOG(DEBUG) << "Get same dtype"; 189 return t; 190 } 191 192 if (IsValueTypeInvalid(t)) { 193 std::string type_str = t->type() == nullptr ? "None, value is \"" + t->ToString() + "\"" : t->type()->ToString(); 194 MS_EXCEPTION(TypeError) << "For '" << op_run_info->op_grad_info->op_prim->name() << "', the " << (index + 1) 195 << "th input " << signature[index].name << " can not be implicitly converted. " 196 << "Its type is " << type_str << ". Only support Tensor or Scalar."; 197 } 198 MS_LOG(DEBUG) << "Implicit cast for " << op_run_info->base_op_run_info.op_name << " " << index 199 << "th input, from type " << (t->type() == nullptr ? t->ToString() : t->type()->ToString()) 200 << " to type " << TypeIdToType(it->second.first)->ToString(); 201 // Has tensor input 202 return DoAutoCast(op_run_info, it->second, index, t); 203 } 204 205 template <typename Item> DoSignatureCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,size_t index,const std::optional<Item> & t)206 std::optional<Item> DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info, 207 const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type, 208 const std::vector<SignatureEnumDType> &dtypes, size_t index, 209 const std::optional<Item> &t) const { 210 if (!t.has_value()) { 211 return std::nullopt; 212 } 213 return std::make_optional(DoSignatureCast(op_run_info, dst_type, dtypes, index, t.value())); 214 } 215 216 template <class Item> IsValueTypeInvalid(const Item & v)217 bool IsValueTypeInvalid(const Item &v) const { 218 MS_EXCEPTION_IF_NULL(v); 219 return !v->template isa<tensor::BaseTensor>() && !v->template isa<tensor::CSRTensor>() && 220 !v->template isa<IntegerImm>() && !v->template isa<FloatImm>() && !v->template isa<BoolImm>(); 221 } 222 223 // template <class Item, class = typename std::enable_if<std::is_same<Item, tensor::Tensor>::value, Item>::type> 224 template <class Item> DoAutoCast(const FrontendOpRunInfoPtr & op_run_info,const std::pair<TypeId,bool> & dst_type,size_t index,const Item & t)225 Item DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type, size_t index, 226 const Item &t) const { 227 MS_EXCEPTION_IF_NULL(t); 228 MS_LOG(DEBUG) << "Get input type " << typeid(t).name(); 229 ValuePtr v = t->template cast<ValuePtr>(); 230 auto ret = DoAutoCast(op_run_info, dst_type, index, v)->template cast<Item>(); 231 MS_EXCEPTION_IF_NULL(ret); 232 return ret; 233 } 234 235 ValuePtr DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type, size_t index, 236 const ValuePtr &v) const; 237 tensor::BaseTensorPtr DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type, 238 size_t index, const tensor::BaseTensorPtr &t) const; 239 ValuePtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v, size_t index) const; 240 tensor::BaseTensorPtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, 241 const tensor::BaseTensorPtr &t, size_t index) const; 242 std::optional<tensor::BaseTensorPtr> SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, 243 const std::optional<tensor::BaseTensorPtr> &t, 244 size_t index) const; 245 ValueTuplePtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueTuplePtr &v_tuple, 246 size_t index) const; 247 ValueListPtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueListPtr &v_list, 248 size_t index) const; 249 ValuePtrList SetSeqMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &v_seq, 250 size_t index) const; 251 }; 252 using PyBoostCastOperationPtr = std::shared_ptr<PyBoostCastOperation>; 253 254 } // namespace pynative 255 } // namespace mindspore 256 #endif // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_ 257