• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_
19 
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "pipeline/pynative/forward/cast_base.h"
#include "frontend/operator/composite/do_signature.h"
#include "include/common/utils/convert_utils.h"
#include "ir/cell.h"
30 
31 namespace mindspore {
32 namespace pynative {
// Name of the Cast primitive inserted for implicit / mixed-precision conversion.
// C++17 `inline` yields a single shared entity across translation units instead
// of a per-TU copy from header-scope `static`.
inline constexpr auto kCast = "Cast";
34 
35 class PyBoostCastOperation : public CastBaseOperation {
36  public:
37   PyBoostCastOperation() = default;
38   ~PyBoostCastOperation() = default;
39 
40   template <typename... InputArgs, std::size_t... Index>
SetTensorMixPrecisionCastHelper(const FrontendOpRunInfoPtr & op_run_info,std::index_sequence<Index...>,const InputArgs &...input_args)41   auto SetTensorMixPrecisionCastHelper(const FrontendOpRunInfoPtr &op_run_info, std::index_sequence<Index...>,
42                                        const InputArgs &... input_args) {
43     return std::make_tuple(SetTensorMixPrecisionCast(op_run_info, input_args, Index)...);
44   }
45 
46   template <typename... InputArgs>
DoMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const InputArgs &...input_args)47   auto DoMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const InputArgs &... input_args) {
48     // Mixed precision conversion tensors which has cast dtype
49     if (op_run_info->async_status.disable_mix_precision) {
50       return std::make_tuple(input_args...);
51     }
52     return SetTensorMixPrecisionCastHelper(op_run_info, std::make_index_sequence<sizeof...(InputArgs)>(),
53                                            input_args...);
54   }
55 
56   template <typename T>
SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr & op_run_info,const T & t,size_t index)57   T SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const T &t, size_t index) const {
58     MS_EXCEPTION_IF_NULL(t);
59     MS_LOG(DEBUG) << "Get input type " << typeid(t).name();
60     return t;
61   }
62 
63   template <typename Item>
GetTypeIdInfo(const FrontendOpRunInfoPtr & op_run_info,std::vector<TypeId> * args_type_id,std::vector<bool> * args_has_tensor,size_t i,const Item & v)64   void GetTypeIdInfo(const FrontendOpRunInfoPtr &op_run_info, std::vector<TypeId> *args_type_id,
65                      std::vector<bool> *args_has_tensor, size_t i, const Item &v) {
66     MS_EXCEPTION_IF_NULL(v);
67     if (v->template isa<tensor::BaseTensor>()) {
68       (*args_type_id)[i] = v->template cast<tensor::BaseTensorPtr>()->data_type();
69       // Indicate have do type cast
70       if (op_run_info->source_type[i] == ops::OP_DTYPE::DT_BEGIN) {
71         (*args_has_tensor)[i] = true;
72       }
73     } else if (v->template isa<Scalar>()) {
74       const auto type = v->template cast<ScalarPtr>()->type();
75       MS_EXCEPTION_IF_NULL(type);
76       (*args_type_id)[i] = type->type_id();
77     } else {
78       MS_LOG(DEBUG) << "Get value " << v->ToString();
79     }
80   }
81 
82   template <typename Item>
GetTypeIdInfo(const FrontendOpRunInfoPtr & op_run_info,std::vector<TypeId> * args_type_id,std::vector<bool> * args_has_tensor,size_t i,const std::optional<Item> & t)83   void GetTypeIdInfo(const FrontendOpRunInfoPtr &op_run_info, std::vector<TypeId> *args_type_id,
84                      std::vector<bool> *args_has_tensor, size_t i, const std::optional<Item> &t) {
85     if (!t.has_value()) {
86       return;
87     }
88   }
89 
90   template <typename TupleInput, size_t... Index>
GetTypeInfo(const FrontendOpRunInfoPtr & op_run_info,const TupleInput & tuple_input,std::index_sequence<Index...>)91   std::pair<std::vector<TypeId>, std::vector<bool>> GetTypeInfo(const FrontendOpRunInfoPtr &op_run_info,
92                                                                 const TupleInput &tuple_input,
93                                                                 std::index_sequence<Index...>) {
94     std::vector<TypeId> args_type_id;
95     std::vector<bool> args_has_tensor;
96     args_type_id.resize(op_run_info->input_size, kTypeUnknown);
97     args_has_tensor.resize(op_run_info->input_size, false);
98 
99     (GetTypeIdInfo(op_run_info, &args_type_id, &args_has_tensor, Index, std::get<Index>(tuple_input)), ...);
100     return {args_type_id, args_has_tensor};
101   }
102 
103   // Implicit transform
104   template <size_t N, typename... InputArgs>
DoImplicitCast(const FrontendOpRunInfoPtr & op_run_info,const std::vector<std::vector<size_t>> & same_type_table,const std::tuple<InputArgs...> & input_args)105   auto DoImplicitCast(const FrontendOpRunInfoPtr &op_run_info, const std::vector<std::vector<size_t>> &same_type_table,
106                       const std::tuple<InputArgs...> &input_args) {
107     MS_EXCEPTION_IF_NULL(op_run_info);
108     MS_LOG(DEBUG) << "Get signature " << same_type_table;
109     const auto &it = implicit_cast_map_.find(op_run_info->base_op_run_info.op_name);
110     if (it == implicit_cast_map_.end()) {
111       std::vector<SignatureEnumDType> dtypes;
112       // Get current inputs signatures
113       bool has_dtype_sig = GetSignatureType(op_run_info->signatures, &dtypes);
114       if (dtypes.size() > op_run_info->input_size) {
115         MS_LOG(EXCEPTION) << "Signature dtypes size[" << dtypes << "] is greater than input_args_size["
116                           << op_run_info->input_size << "].";
117       }
118       if (!has_dtype_sig) {
119         PrimSignature sig_value{has_dtype_sig, {}};
120         implicit_cast_map_[op_run_info->base_op_run_info.op_name] = sig_value;
121         MS_LOG(DEBUG) << "Op " << op_run_info->base_op_run_info.op_name << " has no signature";
122         return input_args;
123       }
124       PrimSignature sig_value{has_dtype_sig, dtypes};
125       implicit_cast_map_[op_run_info->base_op_run_info.op_name] = sig_value;
126 
127       auto [args_type_id, args_has_tensor] =
128         GetTypeInfo(op_run_info, input_args, std::make_index_sequence<sizeof...(InputArgs)>{});
129       auto dst_type = GetSignatureTypeMap(dtypes, args_type_id, args_has_tensor);
130       return SetImplicitCast(op_run_info, dst_type, dtypes, input_args,
131                              std::make_index_sequence<sizeof...(InputArgs)>{});
132     } else {
133       if (!it->second.has_dtype_sig) {
134         MS_LOG(DEBUG) << op_run_info->base_op_run_info.op_name << " have no dtype sig";
135         return input_args;
136       }
137       MS_LOG(DEBUG) << "Do signature for " << op_run_info->base_op_run_info.op_name << " with cache";
138       auto [args_type_id, args_has_tensor] =
139         GetTypeInfo(op_run_info, input_args, std::make_index_sequence<sizeof...(InputArgs)>{});
140       auto dst_type = GetSignatureTypeMap(it->second.dtypes, args_type_id, args_has_tensor);
141       return SetImplicitCast(op_run_info, dst_type, it->second.dtypes, input_args,
142                              std::make_index_sequence<sizeof...(InputArgs)>{});
143     }
144   }
145 
146  private:
147   template <typename TupleInput, size_t... N>
SetImplicitCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,const TupleInput & input_args,std::index_sequence<N...>)148   auto SetImplicitCast(const FrontendOpRunInfoPtr &op_run_info,
149                        const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type,
150                        const std::vector<SignatureEnumDType> &dtypes, const TupleInput &input_args,
151                        std::index_sequence<N...>) const {
152     MS_EXCEPTION_IF_NULL(op_run_info);
153     return std::make_tuple(DoSignatureCast(op_run_info, dst_type, dtypes, N, std::get<N>(input_args))...);
154   }
155 
156   template <typename Item>
DoSignatureCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,size_t index,const Item & t)157   Item DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info,
158                        const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type,
159                        const std::vector<SignatureEnumDType> &dtypes, size_t index, const Item &t) const {
160     // No need to implicit cast if no dtype.
161     const auto &signature = op_run_info->signatures;
162     if (dtypes.empty() || index >= dtypes.size() || dtypes[index] == SignatureEnumDType::kDTypeEmptyDefaultValue) {
163       MS_LOG(DEBUG) << "Get kDTypeEmptyDefaultValue, or index " << index << " larger than dtype size " << dtypes.size();
164       return t;
165     }
166     auto it = dst_type.find(dtypes[index]);
167     if (it == dst_type.end() || it->second.first == kTypeUnknown) {
168       MS_LOG(DEBUG) << "Can not find dtype " << (it == dst_type.end()) << ", or type is unknown "
169                     << (it->second.first == kTypeUnknown);
170       return t;
171     }
172 
173     TypeId arg_type_id = kTypeUnknown;
174     if (t->template isa<tensor::BaseTensor>()) {
175       const auto &arg = t->template cast<tensor::BaseTensorPtr>();
176       arg_type_id = arg->data_type();
177     }
178     // Implicit cast
179     bool is_same_type = false;
180     if (arg_type_id != kTypeUnknown) {
181       is_same_type = (arg_type_id == it->second.first);
182     }
183     if (signature[index].rw == SignatureEnumRW::kRWWrite && arg_type_id != kTypeUnknown && !is_same_type) {
184       prim::RaiseExceptionForConvertRefDtype(op_run_info->op_grad_info->op_prim, TypeIdToString(arg_type_id),
185                                              TypeIdToString(it->second.first), index);
186     }
187     if (is_same_type) {
188       MS_LOG(DEBUG) << "Get same dtype";
189       return t;
190     }
191 
192     if (IsValueTypeInvalid(t)) {
193       std::string type_str = t->type() == nullptr ? "None, value is \"" + t->ToString() + "\"" : t->type()->ToString();
194       MS_EXCEPTION(TypeError) << "For '" << op_run_info->op_grad_info->op_prim->name() << "', the " << (index + 1)
195                               << "th input " << signature[index].name << " can not be implicitly converted. "
196                               << "Its type is " << type_str << ". Only support Tensor or Scalar.";
197     }
198     MS_LOG(DEBUG) << "Implicit cast for " << op_run_info->base_op_run_info.op_name << " " << index
199                   << "th input, from type " << (t->type() == nullptr ? t->ToString() : t->type()->ToString())
200                   << " to type " << TypeIdToType(it->second.first)->ToString();
201     // Has tensor input
202     return DoAutoCast(op_run_info, it->second, index, t);
203   }
204 
205   template <typename Item>
DoSignatureCast(const FrontendOpRunInfoPtr & op_run_info,const std::map<SignatureEnumDType,std::pair<TypeId,bool>> & dst_type,const std::vector<SignatureEnumDType> & dtypes,size_t index,const std::optional<Item> & t)206   std::optional<Item> DoSignatureCast(const FrontendOpRunInfoPtr &op_run_info,
207                                       const std::map<SignatureEnumDType, std::pair<TypeId, bool>> &dst_type,
208                                       const std::vector<SignatureEnumDType> &dtypes, size_t index,
209                                       const std::optional<Item> &t) const {
210     if (!t.has_value()) {
211       return std::nullopt;
212     }
213     return std::make_optional(DoSignatureCast(op_run_info, dst_type, dtypes, index, t.value()));
214   }
215 
216   template <class Item>
IsValueTypeInvalid(const Item & v)217   bool IsValueTypeInvalid(const Item &v) const {
218     MS_EXCEPTION_IF_NULL(v);
219     return !v->template isa<tensor::BaseTensor>() && !v->template isa<tensor::CSRTensor>() &&
220            !v->template isa<IntegerImm>() && !v->template isa<FloatImm>() && !v->template isa<BoolImm>();
221   }
222 
223   //  template <class Item, class = typename std::enable_if<std::is_same<Item, tensor::Tensor>::value, Item>::type>
224   template <class Item>
DoAutoCast(const FrontendOpRunInfoPtr & op_run_info,const std::pair<TypeId,bool> & dst_type,size_t index,const Item & t)225   Item DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type, size_t index,
226                   const Item &t) const {
227     MS_EXCEPTION_IF_NULL(t);
228     MS_LOG(DEBUG) << "Get input type " << typeid(t).name();
229     ValuePtr v = t->template cast<ValuePtr>();
230     auto ret = DoAutoCast(op_run_info, dst_type, index, v)->template cast<Item>();
231     MS_EXCEPTION_IF_NULL(ret);
232     return ret;
233   }
234 
235   ValuePtr DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type, size_t index,
236                       const ValuePtr &v) const;
237   tensor::BaseTensorPtr DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, const std::pair<TypeId, bool> &dst_type,
238                                    size_t index, const tensor::BaseTensorPtr &t) const;
239   ValuePtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValuePtr &v, size_t index) const;
240   tensor::BaseTensorPtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
241                                                   const tensor::BaseTensorPtr &t, size_t index) const;
242   std::optional<tensor::BaseTensorPtr> SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info,
243                                                                  const std::optional<tensor::BaseTensorPtr> &t,
244                                                                  size_t index) const;
245   ValueTuplePtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueTuplePtr &v_tuple,
246                                           size_t index) const;
247   ValueListPtr SetTensorMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueListPtr &v_list,
248                                          size_t index) const;
249   ValuePtrList SetSeqMixPrecisionCast(const FrontendOpRunInfoPtr &op_run_info, const ValueSequencePtr &v_seq,
250                                       size_t index) const;
251 };
252 using PyBoostCastOperationPtr = std::shared_ptr<PyBoostCastOperation>;
253 
254 }  // namespace pynative
255 }  // namespace mindspore
256 #endif  // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_DO_CAST_PYBOOST_H_
257