• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/pynative/forward/cast_base.h"
18 #include <memory>
19 #include <algorithm>
20 #include "ops/array_ops.h"
21 #include "frontend/operator/composite/do_signature.h"
22 
23 namespace mindspore {
24 namespace pynative {
25 namespace {
// Python module holding the `cast` object; only read by the ENABLE_TEST branch of GetPrimByTypeId.
constexpr char kOpsFunctionModelName[] = "mindspore.ops.functional";
27 
28 template <typename S>
CastScalarToScalar(S in,const TypeId & type_id)29 ValuePtr CastScalarToScalar(S in, const TypeId &type_id) {
30   switch (type_id) {
31     case kNumberTypeInt32:
32       return MakeValue(static_cast<int>(in));
33     case kNumberTypeFloat16:
34       return MakeValue(static_cast<float16>(in).int_value());
35     case kNumberTypeFloat32:
36       return MakeValue(static_cast<float>(in));
37     case kNumberTypeBool:
38       return MakeValue(static_cast<bool>(in));
39     case kNumberTypeInt64:
40       return MakeValue(static_cast<int64_t>(in));
41     case kNumberTypeFloat64:
42       return MakeValue(static_cast<double>(in));
43     case kNumberTypeInt16:
44       return MakeValue(static_cast<int16_t>(in));
45     case kNumberTypeInt8:
46       return MakeValue(static_cast<int8_t>(in));
47     case kNumberTypeUInt64:
48       return MakeValue(static_cast<uint64_t>(in));
49     case kNumberTypeUInt32:
50       return MakeValue(static_cast<uint32_t>(in));
51     case kNumberTypeUInt16:
52       return MakeValue(static_cast<uint16_t>(in));
53     case kNumberTypeUInt8:
54       return MakeValue(static_cast<uint8_t>(in));
55     case kNumberTypeBFloat16:
56       return MakeValue(static_cast<float16>(in).int_value());
57     default:
58       MS_LOG(DEBUG) << "Not support cast to dst type: " << TypeIdToType(type_id)->ToString();
59       return nullptr;
60   }
61 }
62 
63 template <typename S>
CastScalarToTensor(S in,const TypeId & type_id)64 ValuePtr CastScalarToTensor(S in, const TypeId &type_id) {
65   switch (type_id) {
66     case kNumberTypeInt32:
67       return std::make_shared<tensor::Tensor>(static_cast<int>(in), kInt32);
68     case kNumberTypeFloat16:
69       return std::make_shared<tensor::Tensor>(static_cast<float16>(in), kFloat16);
70     case kNumberTypeFloat32:
71       return std::make_shared<tensor::Tensor>(static_cast<float>(in), kFloat32);
72     case kNumberTypeBool:
73       return std::make_shared<tensor::Tensor>(static_cast<bool>(in), kBool);
74     case kNumberTypeInt64:
75       return std::make_shared<tensor::Tensor>(static_cast<int64_t>(in), kInt64);
76     case kNumberTypeFloat64:
77       return std::make_shared<tensor::Tensor>(static_cast<double>(in), kFloat64);
78     case kNumberTypeInt16:
79       return std::make_shared<tensor::Tensor>(static_cast<int16_t>(in), kInt16);
80     case kNumberTypeInt8:
81       return std::make_shared<tensor::Tensor>(static_cast<int8_t>(in), kInt8);
82     case kNumberTypeUInt64:
83       return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(in), kUInt64);
84     case kNumberTypeUInt32:
85       return std::make_shared<tensor::Tensor>(static_cast<uint32_t>(in), kUInt32);
86     case kNumberTypeUInt16:
87       return std::make_shared<tensor::Tensor>(static_cast<uint16_t>(in), kUInt16);
88     case kNumberTypeUInt8:
89       return std::make_shared<tensor::Tensor>(static_cast<uint8_t>(in), kUInt8);
90     case kNumberTypeBFloat16:
91       return std::make_shared<tensor::Tensor>(static_cast<bfloat16>(in), kBFloat16);
92     default:
93       MS_LOG(DEBUG) << "Not support cast to dst type: " << TypeIdToType(type_id)->ToString();
94       return nullptr;
95   }
96 }
97 
98 template <typename S>
Cast(S in,const std::pair<TypeId,bool> & dst_type)99 ValuePtr Cast(S in, const std::pair<TypeId, bool> &dst_type) {
100   bool has_tensor_input = dst_type.second;
101   if (has_tensor_input) {
102     return CastScalarToTensor(in, dst_type.first);
103   }
104   return CastScalarToScalar(in, dst_type.first);
105 }
106 }  // namespace
107 
GetPrimByTypeId(const TypeId & type_id) const108 PrimitivePtr CastBaseOperation::GetPrimByTypeId(const TypeId &type_id) const {
109   const auto &iter = type_prim_cache_.find(type_id);
110   if (iter != type_prim_cache_.end()) {
111     return iter->second;
112   }
113 
114 #ifndef ENABLE_TEST
115   auto cast_prim = std::make_shared<Primitive>(kCastOpName);
116   std::vector<std::string> input_names = {"x", "dst_type"};
117   std::vector<std::string> output_names = {"output"};
118   (void)cast_prim->AddAttr("input_names", MakeValue(input_names));
119   (void)cast_prim->AddAttr("output_names", MakeValue(output_names));
120   type_prim_cache_[type_id] = cast_prim;
121   cast_prim->EnableSharedMutex();
122   return cast_prim;
123 #else
124   py::gil_scoped_acquire gil;
125   const auto &cast_prim = python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
126   auto prim_adapter = cast_prim.cast<PrimitivePyAdapterPtr>();
127   MS_EXCEPTION_IF_NULL(prim_adapter);
128   auto primitive = prim_adapter->attached_primitive();
129   if (primitive == nullptr) {
130     primitive = std::make_shared<PrimitivePy>(cast_prim);
131     prim_adapter->set_attached_primitive(primitive);
132   }
133   if (!primitive->HasPyObj()) {
134     MS_LOG(EXCEPTION) << "Pyobj is empty";
135   }
136   type_prim_cache_[type_id] = primitive;
137   primitive->EnableSharedMutex();
138   return primitive;
139 #endif
140 }
141 
GetSignatureType(const std::vector<Signature> & signatures,std::vector<SignatureEnumDType> * dtypes) const142 bool CastBaseOperation::GetSignatureType(const std::vector<Signature> &signatures,
143                                          std::vector<SignatureEnumDType> *dtypes) const {
144   MS_EXCEPTION_IF_NULL(dtypes);
145   bool has_sig_dtype = false;
146   (void)std::transform(signatures.begin(), signatures.end(), std::back_inserter(*dtypes),
147                        [&has_sig_dtype](const Signature &sig) {
148                          auto dtype = sig.dtype;
149                          if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
150                            has_sig_dtype = true;
151                          }
152                          return dtype;
153                        });
154   return has_sig_dtype;
155 }
156 
TensorToDstDtypeValue(const ValuePtr & src_value,const TypeId & dst_type_id) const157 tensor::BaseTensorPtr CastBaseOperation::TensorToDstDtypeValue(const ValuePtr &src_value,
158                                                                const TypeId &dst_type_id) const {
159   MS_EXCEPTION_IF_NULL(src_value);
160   auto src_tensor = src_value->cast<tensor::BaseTensorPtr>();
161   MS_EXCEPTION_IF_NULL(src_tensor);
162   (void)src_tensor->set_data_type(dst_type_id);
163   return src_tensor;
164 }
165 
166 // This function is used to convert scalar value to another scalar value with destination data type.
167 // The scope of scalar type includes common data types, such as `FP64`, `FP32`, `FP16, `Int64`, `Int32`, ...
168 // The following sort is based on the hot spots of the data type.
ScalarToDstDtypeValue(const ValuePtr & src_value,const std::pair<TypeId,bool> & dst_type) const169 ValuePtr CastBaseOperation::ScalarToDstDtypeValue(const ValuePtr &src_value,
170                                                   const std::pair<TypeId, bool> &dst_type) const {
171   MS_EXCEPTION_IF_NULL(src_value);
172   // Tensor not do scalar cast
173   if (src_value->isa<tensor::BaseTensor>()) {
174     return nullptr;
175   } else if (src_value->isa<Int64Imm>()) {
176     const auto &int64_v = src_value->cast<Int64ImmPtr>();
177     return Cast<int64_t>(int64_v->value(), dst_type);
178   } else if (src_value->isa<FP32Imm>()) {
179     const auto &fp32_v = src_value->cast<FP32ImmPtr>();
180     return Cast<float>(fp32_v->value(), dst_type);
181   } else if (src_value->isa<Int32Imm>()) {
182     const auto &int32_v = src_value->cast<Int32ImmPtr>();
183     return Cast<int32_t>(int32_v->value(), dst_type);
184   } else if (src_value->isa<FP64Imm>()) {
185     const auto &fp64_v = src_value->cast<FP64ImmPtr>();
186     return Cast<double>(fp64_v->value(), dst_type);
187   } else if (src_value->isa<BoolImm>()) {
188     const auto &bool_v = src_value->cast<BoolImmPtr>();
189     return Cast<bool>(bool_v->value(), dst_type);
190   } else if (src_value->isa<Int16Imm>()) {
191     const auto &int16_v = src_value->cast<Int16ImmPtr>();
192     return Cast<int16_t>(int16_v->value(), dst_type);
193   } else {
194     MS_LOG(DEBUG) << "Now, the value [" << src_value->ToString() << "] is not supported to cast directly.";
195     return nullptr;
196   }
197 }
198 }  // namespace pynative
199 }  // namespace mindspore
200