• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/device/convert_tensor_utils.h"
17 #include <complex>
18 #include <vector>
19 namespace mindspore {
20 namespace device {
HalfToFloat(void * dst,const void * src,size_t elem_num)21 void HalfToFloat(void *dst, const void *src, size_t elem_num) {
22   if (dst == nullptr || src == nullptr) {
23     return;
24   }
25   auto half_data = static_cast<const float16 *>(src);
26   auto float_data = static_cast<float *>(dst);
27   for (size_t i = 0; i < elem_num; ++i) {
28     float tmp = half_to_float(half_data[i]);
29     float_data[i] = tmp;
30   }
31 }
32 
FloatToHalf(void * dst,const void * src,size_t elem_num)33 void FloatToHalf(void *dst, const void *src, size_t elem_num) {
34   if (dst == nullptr || src == nullptr) {
35     return;
36   }
37   auto float_data = static_cast<const float *>(src);
38   auto half_data = static_cast<float16 *>(dst);
39   for (size_t i = 0; i < elem_num; ++i) {
40     half_data[i] = float16(float_data[i]);
41   }
42 }
43 
// Narrows `elem_num` doubles read from `src` into floats written to `dst`
// using static_cast<float>. Null `dst` or `src` makes the call a no-op.
void DoubleToFloat(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const double *in = static_cast<const double *>(src);
    float *out = static_cast<float *>(dst);
    for (const double *const in_end = in + elem_num; in != in_end; ++in, ++out) {
      *out = static_cast<float>(*in);
    }
  }
}
54 
// Widens `elem_num` floats read from `src` into doubles written to `dst`.
// Null `dst` or `src` makes the call a no-op.
void FloatToDouble(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const float *in = static_cast<const float *>(src);
  double *out = static_cast<double *>(dst);
  size_t idx = 0;
  while (idx < elem_num) {
    out[idx] = static_cast<double>(in[idx]);
    ++idx;
  }
}
65 
// Widens `elem_num` int16_t values read from `src` into int values written to
// `dst`; the conversion is value-preserving. Null `dst` or `src` is a no-op.
void ShortToInt(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const int16_t *shorts = static_cast<const int16_t *>(src);
    int *ints = static_cast<int *>(dst);
    for (size_t k = 0; k < elem_num; k++) {
      ints[k] = static_cast<int>(shorts[k]);
    }
  }
}
76 
// Narrows `elem_num` int values read from `src` into int16_t values written to
// `dst` with static_cast. Out-of-range inputs are not checked: the result is
// implementation-defined before C++20 and modular from C++20 on.
// Null `dst` or `src` makes the call a no-op.
void IntToShort(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const int *ints = static_cast<const int *>(src);
  int16_t *shorts = static_cast<int16_t *>(dst);
  for (const int *const ints_end = ints + elem_num; ints != ints_end; ++ints, ++shorts) {
    *shorts = static_cast<int16_t>(*ints);
  }
}
87 
// Narrows `elem_num` int64_t values read from `src` into int values written to
// `dst` with static_cast. Out-of-range inputs are not checked: the result is
// implementation-defined before C++20 and modular from C++20 on.
// Null `dst` or `src` makes the call a no-op.
void LongToInt(void *dst, const void *src, size_t elem_num) {
  if (dst != nullptr && src != nullptr) {
    const int64_t *longs = static_cast<const int64_t *>(src);
    int *ints = static_cast<int *>(dst);
    size_t pos = 0;
    while (pos != elem_num) {
      ints[pos] = static_cast<int>(longs[pos]);
      ++pos;
    }
  }
}
98 
// Widens `elem_num` int values read from `src` into int64_t values written to
// `dst`; the conversion is value-preserving. Null `dst` or `src` is a no-op.
void IntToLong(void *dst, const void *src, size_t elem_num) {
  if (src == nullptr || dst == nullptr) {
    return;
  }
  const int *ints = static_cast<const int *>(src);
  int64_t *longs = static_cast<int64_t *>(dst);
  for (size_t n = 0; n != elem_num; ++n) {
    longs[n] = static_cast<int64_t>(ints[n]);
  }
}
109 
ConvertSameType(void * const dst,const void * src,size_t size,TypeId type)110 void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type) {
111   if (dst == nullptr || src == nullptr) {
112     return;
113   }
114   if (type == kNumberTypeFloat16) {
115     auto dst_data = static_cast<float16 *>(dst);
116     auto src_data = static_cast<const float16 *>(src);
117     ConvertSameType(dst_data, src_data, size >> 1);
118   } else if (type == kNumberTypeFloat32) {
119     auto dst_data = static_cast<float *>(dst);
120     auto src_data = static_cast<const float *>(src);
121     ConvertSameType(dst_data, src_data, size / sizeof(float));
122   } else if (type == kNumberTypeFloat64) {
123     auto dst_data = static_cast<double *>(dst);
124     auto src_data = static_cast<const double *>(src);
125     ConvertSameType(dst_data, src_data, size / sizeof(double));
126   } else if (type == kNumberTypeBFloat16) {
127     auto dst_data = static_cast<bfloat16 *>(dst);
128     auto src_data = static_cast<const bfloat16 *>(src);
129     ConvertSameType(dst_data, src_data, size >> 1);
130   } else if (type == kNumberTypeInt8) {
131     auto dst_data = static_cast<int8_t *>(dst);
132     auto src_data = static_cast<const int8_t *>(src);
133     ConvertSameType(dst_data, src_data, size / sizeof(int8_t));
134   } else if (type == kNumberTypeInt16) {
135     auto dst_data = static_cast<int16_t *>(dst);
136     auto src_data = static_cast<const int16_t *>(src);
137     ConvertSameType(dst_data, src_data, size >> 1);
138   } else if (type == kNumberTypeInt32) {
139     auto dst_data = static_cast<int *>(dst);
140     auto src_data = static_cast<const int *>(src);
141     ConvertSameType(dst_data, src_data, size / sizeof(int));
142   } else if (type == kNumberTypeInt64) {
143     auto dst_data = static_cast<int64_t *>(dst);
144     auto src_data = static_cast<const int64_t *>(src);
145     ConvertSameType(dst_data, src_data, size / sizeof(int64_t));
146   } else if (type == kNumberTypeBool) {
147     auto dst_data = static_cast<bool *>(dst);
148     auto src_data = static_cast<const bool *>(src);
149     ConvertSameType(dst_data, src_data, size / sizeof(bool));
150   } else if (type == kNumberTypeUInt8) {
151     auto dst_data = static_cast<uint8_t *>(dst);
152     auto src_data = static_cast<const uint8_t *>(src);
153     ConvertSameType(dst_data, src_data, size / sizeof(uint8_t));
154   } else if (type == kNumberTypeUInt16) {
155     auto dst_data = static_cast<uint16_t *>(dst);
156     auto src_data = static_cast<const uint16_t *>(src);
157     ConvertSameType(dst_data, src_data, size / sizeof(uint16_t));
158   } else if (type == kNumberTypeUInt32) {
159     auto dst_data = static_cast<uint32_t *>(dst);
160     auto src_data = static_cast<const uint32_t *>(src);
161     ConvertSameType(dst_data, src_data, size / sizeof(uint32_t));
162   } else if (type == kNumberTypeUInt64) {
163     auto dst_data = static_cast<uint64_t *>(dst);
164     auto src_data = static_cast<const uint64_t *>(src);
165     ConvertSameType(dst_data, src_data, size / sizeof(uint64_t));
166   } else if (type == kNumberTypeComplex64) {
167     auto dst_data = static_cast<std::complex<float> *>(dst);
168     auto src_data = static_cast<const std::complex<float> *>(src);
169     ConvertSameType(dst_data, src_data, size / sizeof(std::complex<float>));
170   } else if (type == kNumberTypeComplex128) {
171     auto dst_data = static_cast<std::complex<double> *>(dst);
172     auto src_data = static_cast<const std::complex<double> *>(src);
173     ConvertSameType(dst_data, src_data, size / sizeof(std::complex<double>));
174   } else {
175     MS_LOG(EXCEPTION) << "Invalid Type: " << TypeIdLabel(type);
176   }
177 }
178 }  // namespace device
179 }  // namespace mindspore
180