1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/forward/cast_base.h"
18 #include <memory>
19 #include <algorithm>
20 #include "ops/array_ops.h"
21 #include "frontend/operator/composite/do_signature.h"
22
23 namespace mindspore {
24 namespace pynative {
25 namespace {
// Python module that exposes the functional `cast` used to build the Cast primitive in ENABLE_TEST builds.
constexpr char kOpsFunctionModelName[] = "mindspore.ops.functional";
27
28 template <typename S>
CastScalarToScalar(S in,const TypeId & type_id)29 ValuePtr CastScalarToScalar(S in, const TypeId &type_id) {
30 switch (type_id) {
31 case kNumberTypeInt32:
32 return MakeValue(static_cast<int>(in));
33 case kNumberTypeFloat16:
34 return MakeValue(static_cast<float16>(in).int_value());
35 case kNumberTypeFloat32:
36 return MakeValue(static_cast<float>(in));
37 case kNumberTypeBool:
38 return MakeValue(static_cast<bool>(in));
39 case kNumberTypeInt64:
40 return MakeValue(static_cast<int64_t>(in));
41 case kNumberTypeFloat64:
42 return MakeValue(static_cast<double>(in));
43 case kNumberTypeInt16:
44 return MakeValue(static_cast<int16_t>(in));
45 case kNumberTypeInt8:
46 return MakeValue(static_cast<int8_t>(in));
47 case kNumberTypeUInt64:
48 return MakeValue(static_cast<uint64_t>(in));
49 case kNumberTypeUInt32:
50 return MakeValue(static_cast<uint32_t>(in));
51 case kNumberTypeUInt16:
52 return MakeValue(static_cast<uint16_t>(in));
53 case kNumberTypeUInt8:
54 return MakeValue(static_cast<uint8_t>(in));
55 case kNumberTypeBFloat16:
56 return MakeValue(static_cast<float16>(in).int_value());
57 default:
58 MS_LOG(DEBUG) << "Not support cast to dst type: " << TypeIdToType(type_id)->ToString();
59 return nullptr;
60 }
61 }
62
63 template <typename S>
CastScalarToTensor(S in,const TypeId & type_id)64 ValuePtr CastScalarToTensor(S in, const TypeId &type_id) {
65 switch (type_id) {
66 case kNumberTypeInt32:
67 return std::make_shared<tensor::Tensor>(static_cast<int>(in), kInt32);
68 case kNumberTypeFloat16:
69 return std::make_shared<tensor::Tensor>(static_cast<float16>(in), kFloat16);
70 case kNumberTypeFloat32:
71 return std::make_shared<tensor::Tensor>(static_cast<float>(in), kFloat32);
72 case kNumberTypeBool:
73 return std::make_shared<tensor::Tensor>(static_cast<bool>(in), kBool);
74 case kNumberTypeInt64:
75 return std::make_shared<tensor::Tensor>(static_cast<int64_t>(in), kInt64);
76 case kNumberTypeFloat64:
77 return std::make_shared<tensor::Tensor>(static_cast<double>(in), kFloat64);
78 case kNumberTypeInt16:
79 return std::make_shared<tensor::Tensor>(static_cast<int16_t>(in), kInt16);
80 case kNumberTypeInt8:
81 return std::make_shared<tensor::Tensor>(static_cast<int8_t>(in), kInt8);
82 case kNumberTypeUInt64:
83 return std::make_shared<tensor::Tensor>(static_cast<uint64_t>(in), kUInt64);
84 case kNumberTypeUInt32:
85 return std::make_shared<tensor::Tensor>(static_cast<uint32_t>(in), kUInt32);
86 case kNumberTypeUInt16:
87 return std::make_shared<tensor::Tensor>(static_cast<uint16_t>(in), kUInt16);
88 case kNumberTypeUInt8:
89 return std::make_shared<tensor::Tensor>(static_cast<uint8_t>(in), kUInt8);
90 case kNumberTypeBFloat16:
91 return std::make_shared<tensor::Tensor>(static_cast<bfloat16>(in), kBFloat16);
92 default:
93 MS_LOG(DEBUG) << "Not support cast to dst type: " << TypeIdToType(type_id)->ToString();
94 return nullptr;
95 }
96 }
97
98 template <typename S>
Cast(S in,const std::pair<TypeId,bool> & dst_type)99 ValuePtr Cast(S in, const std::pair<TypeId, bool> &dst_type) {
100 bool has_tensor_input = dst_type.second;
101 if (has_tensor_input) {
102 return CastScalarToTensor(in, dst_type.first);
103 }
104 return CastScalarToScalar(in, dst_type.first);
105 }
106 } // namespace
107
GetPrimByTypeId(const TypeId & type_id) const108 PrimitivePtr CastBaseOperation::GetPrimByTypeId(const TypeId &type_id) const {
109 const auto &iter = type_prim_cache_.find(type_id);
110 if (iter != type_prim_cache_.end()) {
111 return iter->second;
112 }
113
114 #ifndef ENABLE_TEST
115 auto cast_prim = std::make_shared<Primitive>(kCastOpName);
116 std::vector<std::string> input_names = {"x", "dst_type"};
117 std::vector<std::string> output_names = {"output"};
118 (void)cast_prim->AddAttr("input_names", MakeValue(input_names));
119 (void)cast_prim->AddAttr("output_names", MakeValue(output_names));
120 type_prim_cache_[type_id] = cast_prim;
121 cast_prim->EnableSharedMutex();
122 return cast_prim;
123 #else
124 py::gil_scoped_acquire gil;
125 const auto &cast_prim = python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
126 auto prim_adapter = cast_prim.cast<PrimitivePyAdapterPtr>();
127 MS_EXCEPTION_IF_NULL(prim_adapter);
128 auto primitive = prim_adapter->attached_primitive();
129 if (primitive == nullptr) {
130 primitive = std::make_shared<PrimitivePy>(cast_prim);
131 prim_adapter->set_attached_primitive(primitive);
132 }
133 if (!primitive->HasPyObj()) {
134 MS_LOG(EXCEPTION) << "Pyobj is empty";
135 }
136 type_prim_cache_[type_id] = primitive;
137 primitive->EnableSharedMutex();
138 return primitive;
139 #endif
140 }
141
GetSignatureType(const std::vector<Signature> & signatures,std::vector<SignatureEnumDType> * dtypes) const142 bool CastBaseOperation::GetSignatureType(const std::vector<Signature> &signatures,
143 std::vector<SignatureEnumDType> *dtypes) const {
144 MS_EXCEPTION_IF_NULL(dtypes);
145 bool has_sig_dtype = false;
146 (void)std::transform(signatures.begin(), signatures.end(), std::back_inserter(*dtypes),
147 [&has_sig_dtype](const Signature &sig) {
148 auto dtype = sig.dtype;
149 if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) {
150 has_sig_dtype = true;
151 }
152 return dtype;
153 });
154 return has_sig_dtype;
155 }
156
TensorToDstDtypeValue(const ValuePtr & src_value,const TypeId & dst_type_id) const157 tensor::BaseTensorPtr CastBaseOperation::TensorToDstDtypeValue(const ValuePtr &src_value,
158 const TypeId &dst_type_id) const {
159 MS_EXCEPTION_IF_NULL(src_value);
160 auto src_tensor = src_value->cast<tensor::BaseTensorPtr>();
161 MS_EXCEPTION_IF_NULL(src_tensor);
162 (void)src_tensor->set_data_type(dst_type_id);
163 return src_tensor;
164 }
165
166 // This function is used to convert scalar value to another scalar value with destination data type.
167 // The scope of scalar type includes common data types, such as `FP64`, `FP32`, `FP16, `Int64`, `Int32`, ...
168 // The following sort is based on the hot spots of the data type.
ScalarToDstDtypeValue(const ValuePtr & src_value,const std::pair<TypeId,bool> & dst_type) const169 ValuePtr CastBaseOperation::ScalarToDstDtypeValue(const ValuePtr &src_value,
170 const std::pair<TypeId, bool> &dst_type) const {
171 MS_EXCEPTION_IF_NULL(src_value);
172 // Tensor not do scalar cast
173 if (src_value->isa<tensor::BaseTensor>()) {
174 return nullptr;
175 } else if (src_value->isa<Int64Imm>()) {
176 const auto &int64_v = src_value->cast<Int64ImmPtr>();
177 return Cast<int64_t>(int64_v->value(), dst_type);
178 } else if (src_value->isa<FP32Imm>()) {
179 const auto &fp32_v = src_value->cast<FP32ImmPtr>();
180 return Cast<float>(fp32_v->value(), dst_type);
181 } else if (src_value->isa<Int32Imm>()) {
182 const auto &int32_v = src_value->cast<Int32ImmPtr>();
183 return Cast<int32_t>(int32_v->value(), dst_type);
184 } else if (src_value->isa<FP64Imm>()) {
185 const auto &fp64_v = src_value->cast<FP64ImmPtr>();
186 return Cast<double>(fp64_v->value(), dst_type);
187 } else if (src_value->isa<BoolImm>()) {
188 const auto &bool_v = src_value->cast<BoolImmPtr>();
189 return Cast<bool>(bool_v->value(), dst_type);
190 } else if (src_value->isa<Int16Imm>()) {
191 const auto &int16_v = src_value->cast<Int16ImmPtr>();
192 return Cast<int16_t>(int16_v->value(), dst_type);
193 } else {
194 MS_LOG(DEBUG) << "Now, the value [" << src_value->ToString() << "] is not supported to cast directly.";
195 return nullptr;
196 }
197 }
198 } // namespace pynative
199 } // namespace mindspore
200