1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "runtime/device/convert_tensor_utils.h"
17 #include <complex>
18 #include <vector>
19 namespace mindspore {
20 namespace device {
HalfToFloat(void * dst,const void * src,size_t elem_num)21 void HalfToFloat(void *dst, const void *src, size_t elem_num) {
22 if (dst == nullptr || src == nullptr) {
23 return;
24 }
25 auto half_data = static_cast<const float16 *>(src);
26 auto float_data = static_cast<float *>(dst);
27 for (size_t i = 0; i < elem_num; ++i) {
28 float tmp = half_to_float(half_data[i]);
29 float_data[i] = tmp;
30 }
31 }
32
FloatToHalf(void * dst,const void * src,size_t elem_num)33 void FloatToHalf(void *dst, const void *src, size_t elem_num) {
34 if (dst == nullptr || src == nullptr) {
35 return;
36 }
37 auto float_data = static_cast<const float *>(src);
38 auto half_data = static_cast<float16 *>(dst);
39 for (size_t i = 0; i < elem_num; ++i) {
40 half_data[i] = float16(float_data[i]);
41 }
42 }
43
/**
 * @brief Narrows `elem_num` double values from `src` into float values in `dst`.
 * @param dst Destination buffer, interpreted as `float`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `double`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null.
 */
void DoubleToFloat(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const double *cursor = static_cast<const double *>(src);
  float *sink = static_cast<float *>(dst);
  for (const double *const stop = cursor + elem_num; cursor != stop; ++cursor, ++sink) {
    *sink = static_cast<float>(*cursor);
  }
}
54
/**
 * @brief Widens `elem_num` float values from `src` into double values in `dst`.
 * @param dst Destination buffer, interpreted as `double`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `float`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null.
 */
void FloatToDouble(void *dst, const void *src, size_t elem_num) {
  if (!src || !dst) {
    return;
  }
  const float *in = static_cast<const float *>(src);
  double *out = static_cast<double *>(dst);
  size_t remaining = elem_num;
  while (remaining > 0) {
    *out = static_cast<double>(*in);
    ++out;
    ++in;
    --remaining;
  }
}
65
/**
 * @brief Sign-extends `elem_num` int16_t values from `src` into int values in `dst`.
 * @param dst Destination buffer, interpreted as `int`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `int16_t`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null.
 */
void ShortToInt(void *dst, const void *src, size_t elem_num) {
  if (dst == nullptr) {
    return;
  }
  if (src == nullptr) {
    return;
  }
  const int16_t *shorts = static_cast<const int16_t *>(src);
  int *ints = static_cast<int *>(dst);
  for (size_t k = 0; k < elem_num; k++) {
    const int16_t value = shorts[k];
    ints[k] = static_cast<int>(value);
  }
}
76
/**
 * @brief Narrows `elem_num` int values from `src` into int16_t values in `dst`.
 * @param dst Destination buffer, interpreted as `int16_t`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `int`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null. Values outside the
 *       int16_t range are narrowed by static_cast (modulo 2^16 since C++20;
 *       implementation-defined in earlier standards).
 */
void IntToShort(void *dst, const void *src, size_t elem_num) {
  if (!dst || !src) {
    return;
  }
  const int *from = static_cast<const int *>(src);
  int16_t *to = static_cast<int16_t *>(dst);
  const int *const from_end = from + elem_num;
  for (; from != from_end; ++from) {
    *to++ = static_cast<int16_t>(*from);
  }
}
87
/**
 * @brief Narrows `elem_num` int64_t values from `src` into int values in `dst`.
 * @param dst Destination buffer, interpreted as `int`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `int64_t`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null. Values outside the
 *       int range are narrowed by static_cast (modulo 2^N since C++20;
 *       implementation-defined in earlier standards).
 */
void LongToInt(void *dst, const void *src, size_t elem_num) {
  const bool have_buffers = (dst != nullptr) && (src != nullptr);
  if (!have_buffers) {
    return;
  }
  const int64_t *wide = static_cast<const int64_t *>(src);
  int *narrow = static_cast<int *>(dst);
  for (size_t pos = 0; pos != elem_num; ++pos) {
    narrow[pos] = static_cast<int>(wide[pos]);
  }
}
98
/**
 * @brief Sign-extends `elem_num` int values from `src` into int64_t values in `dst`.
 * @param dst Destination buffer, interpreted as `int64_t`; written over [0, elem_num).
 * @param src Source buffer, interpreted as `int`; read over [0, elem_num).
 * @param elem_num Number of elements to convert.
 * @note Silently does nothing when either pointer is null.
 */
void IntToLong(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const int *in = static_cast<const int *>(src);
  int64_t *out = static_cast<int64_t *>(dst);
  size_t left = elem_num;
  while (left-- != 0) {
    *out++ = static_cast<int64_t>(*in++);
  }
}
109
ConvertSameType(void * const dst,const void * src,size_t size,TypeId type)110 void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type) {
111 if (dst == nullptr || src == nullptr) {
112 return;
113 }
114 if (type == kNumberTypeFloat16) {
115 auto dst_data = static_cast<float16 *>(dst);
116 auto src_data = static_cast<const float16 *>(src);
117 ConvertSameType(dst_data, src_data, size >> 1);
118 } else if (type == kNumberTypeFloat32) {
119 auto dst_data = static_cast<float *>(dst);
120 auto src_data = static_cast<const float *>(src);
121 ConvertSameType(dst_data, src_data, size / sizeof(float));
122 } else if (type == kNumberTypeFloat64) {
123 auto dst_data = static_cast<double *>(dst);
124 auto src_data = static_cast<const double *>(src);
125 ConvertSameType(dst_data, src_data, size / sizeof(double));
126 } else if (type == kNumberTypeBFloat16) {
127 auto dst_data = static_cast<bfloat16 *>(dst);
128 auto src_data = static_cast<const bfloat16 *>(src);
129 ConvertSameType(dst_data, src_data, size >> 1);
130 } else if (type == kNumberTypeInt8) {
131 auto dst_data = static_cast<int8_t *>(dst);
132 auto src_data = static_cast<const int8_t *>(src);
133 ConvertSameType(dst_data, src_data, size / sizeof(int8_t));
134 } else if (type == kNumberTypeInt16) {
135 auto dst_data = static_cast<int16_t *>(dst);
136 auto src_data = static_cast<const int16_t *>(src);
137 ConvertSameType(dst_data, src_data, size >> 1);
138 } else if (type == kNumberTypeInt32) {
139 auto dst_data = static_cast<int *>(dst);
140 auto src_data = static_cast<const int *>(src);
141 ConvertSameType(dst_data, src_data, size / sizeof(int));
142 } else if (type == kNumberTypeInt64) {
143 auto dst_data = static_cast<int64_t *>(dst);
144 auto src_data = static_cast<const int64_t *>(src);
145 ConvertSameType(dst_data, src_data, size / sizeof(int64_t));
146 } else if (type == kNumberTypeBool) {
147 auto dst_data = static_cast<bool *>(dst);
148 auto src_data = static_cast<const bool *>(src);
149 ConvertSameType(dst_data, src_data, size / sizeof(bool));
150 } else if (type == kNumberTypeUInt8) {
151 auto dst_data = static_cast<uint8_t *>(dst);
152 auto src_data = static_cast<const uint8_t *>(src);
153 ConvertSameType(dst_data, src_data, size / sizeof(uint8_t));
154 } else if (type == kNumberTypeUInt16) {
155 auto dst_data = static_cast<uint16_t *>(dst);
156 auto src_data = static_cast<const uint16_t *>(src);
157 ConvertSameType(dst_data, src_data, size / sizeof(uint16_t));
158 } else if (type == kNumberTypeUInt32) {
159 auto dst_data = static_cast<uint32_t *>(dst);
160 auto src_data = static_cast<const uint32_t *>(src);
161 ConvertSameType(dst_data, src_data, size / sizeof(uint32_t));
162 } else if (type == kNumberTypeUInt64) {
163 auto dst_data = static_cast<uint64_t *>(dst);
164 auto src_data = static_cast<const uint64_t *>(src);
165 ConvertSameType(dst_data, src_data, size / sizeof(uint64_t));
166 } else if (type == kNumberTypeComplex64) {
167 auto dst_data = static_cast<std::complex<float> *>(dst);
168 auto src_data = static_cast<const std::complex<float> *>(src);
169 ConvertSameType(dst_data, src_data, size / sizeof(std::complex<float>));
170 } else if (type == kNumberTypeComplex128) {
171 auto dst_data = static_cast<std::complex<double> *>(dst);
172 auto src_data = static_cast<const std::complex<double> *>(src);
173 ConvertSameType(dst_data, src_data, size / sizeof(std::complex<double>));
174 } else {
175 MS_LOG(EXCEPTION) << "Invalid Type: " << TypeIdLabel(type);
176 }
177 }
178 } // namespace device
179 } // namespace mindspore
180