1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "runtime/device/convert_tensor_utils.h"
17 #include <vector>
18 namespace mindspore {
19 namespace device {
HalfToFloat(void * dst,const void * src,size_t elem_num)20 void HalfToFloat(void *dst, const void *src, size_t elem_num) {
21 if (dst == nullptr || src == nullptr) {
22 return;
23 }
24 auto half_data = static_cast<const float16 *>(src);
25 auto float_data = static_cast<float *>(dst);
26 for (size_t i = 0; i < elem_num; ++i) {
27 float tmp = half_to_float(half_data[i]);
28 float_data[i] = tmp;
29 }
30 }
31
FloatToHalf(void * dst,const void * src,size_t elem_num)32 void FloatToHalf(void *dst, const void *src, size_t elem_num) {
33 if (dst == nullptr || src == nullptr) {
34 return;
35 }
36 auto float_data = static_cast<const float *>(src);
37 auto half_data = static_cast<float16 *>(dst);
38 for (size_t i = 0; i < elem_num; ++i) {
39 half_data[i] = float16(float_data[i]);
40 }
41 }
42
/**
 * Narrows a buffer of doubles into floats, element by element.
 *
 * @param dst Destination buffer; receives elem_num floats.
 * @param src Source buffer; read as elem_num doubles.
 * @param elem_num Number of elements to convert.
 * Nothing is written when either pointer is null.
 */
void DoubleToFloat(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const double *in = static_cast<const double *>(src);
    float *out = static_cast<float *>(dst);
    size_t remaining = elem_num;
    while (remaining > 0) {
      *out = static_cast<float>(*in);
      ++out;
      ++in;
      --remaining;
    }
  }
}
53
/**
 * Widens a buffer of floats into doubles, element by element.
 *
 * @param dst Destination buffer; receives elem_num doubles.
 * @param src Source buffer; read as elem_num floats.
 * @param elem_num Number of elements to convert.
 * Nothing is written when either pointer is null.
 */
void FloatToDouble(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const float *in = static_cast<const float *>(src);
    double *out = static_cast<double *>(dst);
    for (size_t pos = 0; pos < elem_num; pos++) {
      const double widened = in[pos];
      out[pos] = widened;
    }
  }
}
64
/**
 * Widens a buffer of int16_t values into int values, element by element.
 *
 * @param dst Destination buffer; receives elem_num ints.
 * @param src Source buffer; read as elem_num int16_t values.
 * @param elem_num Number of elements to convert.
 * The call is a no-op when either pointer is null.
 */
void ShortToInt(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const int16_t *narrow = static_cast<const int16_t *>(src);
  int *wide = static_cast<int *>(dst);
  for (size_t k = 0; k != elem_num; ++k) {
    wide[k] = narrow[k];  // int16_t -> int is always value-preserving
  }
}
75
/**
 * Narrows a buffer of int values into int16_t values, element by element.
 *
 * Each value goes through static_cast<int16_t>; values outside the int16_t range are not
 * range-checked.
 *
 * @param dst Destination buffer; receives elem_num int16_t values.
 * @param src Source buffer; read as elem_num ints.
 * @param elem_num Number of elements to convert.
 * The call is a no-op when either pointer is null.
 */
void IntToShort(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const int *in = static_cast<const int *>(src);
  int16_t *out = static_cast<int16_t *>(dst);
  const int *const stop = in + elem_num;
  while (in != stop) {
    *out++ = static_cast<int16_t>(*in++);
  }
}
86
/**
 * Narrows a buffer of int64_t values into int values, element by element.
 *
 * Each value goes through static_cast<int>; values outside the int range are not range-checked.
 *
 * @param dst Destination buffer; receives elem_num ints.
 * @param src Source buffer; read as elem_num int64_t values.
 * @param elem_num Number of elements to convert.
 * Nothing is written when either pointer is null.
 */
void LongToInt(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const int64_t *in = static_cast<const int64_t *>(src);
    int *out = static_cast<int *>(dst);
    size_t n = 0;
    while (n < elem_num) {
      out[n] = static_cast<int>(in[n]);
      ++n;
    }
  }
}
97
/**
 * Widens a buffer of int values into int64_t values, element by element.
 *
 * @param dst Destination buffer; receives elem_num int64_t values.
 * @param src Source buffer; read as elem_num ints.
 * @param elem_num Number of elements to convert.
 * Nothing is written when either pointer is null.
 */
void IntToLong(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const int *in = static_cast<const int *>(src);
    int64_t *out = static_cast<int64_t *>(dst);
    for (const int *const end = in + elem_num; in != end; ++in) {
      *out = static_cast<int64_t>(*in);
      ++out;
    }
  }
}
108
ConvertSameType(void * const dst,const void * src,size_t size,TypeId type)109 void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type) {
110 if (dst == nullptr || src == nullptr) {
111 return;
112 }
113 if (type == kNumberTypeFloat16) {
114 auto dst_data = static_cast<float16 *>(dst);
115 auto src_data = static_cast<const float16 *>(src);
116 ConvertSameType(dst_data, src_data, size >> 1);
117 } else if (type == kNumberTypeFloat32) {
118 auto dst_data = static_cast<float *>(dst);
119 auto src_data = static_cast<const float *>(src);
120 ConvertSameType(dst_data, src_data, size / sizeof(float));
121 } else if (type == kNumberTypeFloat64) {
122 auto dst_data = static_cast<double *>(dst);
123 auto src_data = static_cast<const double *>(src);
124 ConvertSameType(dst_data, src_data, size / sizeof(double));
125 } else if (type == kNumberTypeInt16) {
126 auto dst_data = static_cast<int16_t *>(dst);
127 auto src_data = static_cast<const int16_t *>(src);
128 ConvertSameType(dst_data, src_data, size >> 1);
129 } else if (type == kNumberTypeInt32) {
130 auto dst_data = static_cast<int *>(dst);
131 auto src_data = static_cast<const int *>(src);
132 ConvertSameType(dst_data, src_data, size / sizeof(int));
133 } else if (type == kNumberTypeInt64) {
134 auto dst_data = static_cast<int64_t *>(dst);
135 auto src_data = static_cast<const int64_t *>(src);
136 ConvertSameType(dst_data, src_data, size / sizeof(int64_t));
137 } else {
138 MS_LOG(EXCEPTION) << "Invalid Type: " << TypeIdLabel(type);
139 }
140 }
141 } // namespace device
142 } // namespace mindspore
143