/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
// Note that the GPU cast functor templates need to be explicitly instantiated,
// unlike the CPU ones, and hence their specializations are different from
// those for CPUs.
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <typename Device>                                        \
  struct CastFunctor<Device, OUT_TYPE, IN_OUT> {                    \
    void operator()(const Device& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };                                                                \
  template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <>                                                       \
  struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> {                    \
    void operator()(const DEVICE& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };
#endif

#define CAST_FUNCTORS(devname)                                        \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, double)                       \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
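// Usage note (illustrative sketch, not part of this header's declarations): a
// device-specific translation unit is expected to invoke CAST_FUNCTORS with
// its Eigen device type to materialize the CastFunctor specializations, e.g.
//
//   // In a CPU source file (device alias assumed here for illustration):
//   typedef Eigen::ThreadPoolDevice CPUDevice;
//   namespace functor {
//   CAST_FUNCTORS(CPUDevice);
//   }  // namespace functor
//
// A GPU compilation unit is presumed to define SPECIALIZE_FOR_GPUS before
// including this header so that the GPU variant above emits explicit
// instantiations of each specialization.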

#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
// If MLIR kernels are enabled, we don't need the specialized cast from
// Eigen::half to double. We still need the specialized cast from Eigen::half
// to float, because it is used in depthwise_conv_grad_op.cc, and the
// specialized cast from float to double, because it is used in
// resize_bilinear_op.cc.
#define CAST_FUNCTORS_SUBSET(devname)                                 \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
#endif

namespace tensorflow {

typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
                           bool trunc)>
    CastFunctorType;
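// Illustrative sketch (assumed wiring, not defined in this header): a
// CastFunctorType value is typically a closure that forwards to a
// functor::CastFunctor specialization for a concrete device and type pair,
// e.g.
//
//   CastFunctorType work = [](OpKernelContext* ctx, const Tensor& inp,
//                             Tensor* out, bool truncate) {
//     functor::CastFunctor<Eigen::ThreadPoolDevice, float, double> func;
//     func(ctx->eigen_device<Eigen::ThreadPoolDevice>(), out->flat<float>(),
//          inp.flat<double>(), truncate);
//   };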

// Common base class of Cast kernels
class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  DataType external_src_dtype_;
  DataType external_dst_dtype_;
  bool use_truncation_;
  CastFunctorType work_ = nullptr;
  Status Unimplemented();

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};

// CPU implementation of Cast
class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx);

 private:
  Status Prepare();
};
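// Illustrative sketch (assumption; the actual registration lives in the .cc
// implementation files, not in this header): the CPU kernel is registered
// roughly as
//
//   REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);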

namespace functor {

template <typename I>
constexpr int MantissaWidth() {
  return std::numeric_limits<I>::digits;
}

template <>
constexpr int MantissaWidth<Eigen::half>() {
  // Remember, there's 1 hidden bit
  return 10 + 1;
}

template <>
constexpr int MantissaWidth<bfloat16>() {
  // Remember, there's 1 hidden bit
  return 7 + 1;
}
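// For example: MantissaWidth<float>() is std::numeric_limits<float>::digits
// == 24, so a truncating float -> bfloat16 cast clears 24 - 8 == 16 low bits,
// and a truncating double -> float cast clears 53 - 24 == 29 low bits (see
// LSBZeroSetter below).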

template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
          typename TTypes<Tin>::ConstFlat i) {
  o.device(d) = i.template cast<Tout>();
}

template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i, bool truncate = false);
};

// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types.
// Specialize for others if needed in future.
template <typename I>
typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint64_t* p = reinterpret_cast<uint64_t*>(&t);
    *p &= (0xFFFFFFFFFFFFFFFF << n);
  }
}

template <typename I>
typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint32_t* p = reinterpret_cast<uint32_t*>(&t);
    *p &= (0xFFFFFFFF << n);
  }
}
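// For example: for a 32-bit float t and n == 16, the mask above becomes
// 0xFFFF0000, so the 16 least significant mantissa bits are cleared while the
// sign, exponent, and high-order mantissa bits are preserved.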

// Set n least significant bits to 0
template <typename I, typename O>
struct LSBZeroSetter {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I t = a;
    LSBZeroSetterHelper(t, bits);
    return t;
  }
};
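// Example (sketch): this functor is what the truncating branch of
// SPECIALIZE_CAST above applies before the element-wise cast, e.g. for a
// float -> bfloat16 cast:
//
//   out_tensor.device(d) =
//       in_tensor.unaryExpr(LSBZeroSetter<float, bfloat16>())
//           .template cast<bfloat16>();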

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, O> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
  // Sets the low-order mantissa bits of both the real and imaginary parts to 0
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};
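// For example: LSBZeroSetter<std::complex<double>, std::complex<float>> clears
// the low 53 - 24 == 29 mantissa bits of both the real and imaginary parts
// before the element-wise cast to std::complex<float>.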

}  // end namespace functor
}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Eigen can't convert to/from complex numbers, because it is limited to cases
// that can be static_casted. But numpy is able to cast to/from complex, which
// we want to replicate. So we add specializations for complex here.
template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, To> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
  operator()(const std::complex<From>& a) const {
    // Replicate numpy behavior of returning just the real part
    return static_cast<To>(a.real());
  }
};

template <typename From, typename To>
struct scalar_cast_op<From, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const From& a) const {
    // Replicate numpy behavior of setting the imaginary part to 0
    return std::complex<To>(static_cast<To>(a), To(0));
  }
};

template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const std::complex<From>& a) const {
    return std::complex<To>(static_cast<To>(a.real()),
                            static_cast<To>(a.imag()));
  }
};
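// Illustrative sketch of the resulting numpy-like behavior (values chosen for
// illustration only):
//
//   std::complex<double> c(3.0, 4.0);
//   float re =
//       Eigen::internal::scalar_cast_op<std::complex<double>, float>()(c);
//   // re == 3.0f: the imaginary part is dropped.
//
//   std::complex<float> z =
//       Eigen::internal::scalar_cast_op<double, std::complex<float>>()(2.5);
//   // z == {2.5f, 0.0f}: the imaginary part is zero-filled.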
290
291 template <typename From, typename To>
292 struct functor_traits_complex_impl {
293 enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
294 };
295
296 template <typename From, typename To>
297 struct functor_traits<scalar_cast_op<std::complex<From>, To>>
298 : functor_traits_complex_impl<std::complex<From>, To> {};
299 template <typename From, typename To>
300 struct functor_traits<scalar_cast_op<From, std::complex<To>>>
301 : functor_traits_complex_impl<From, std::complex<To>> {};
302 // Needed to avoid ambiguous partial specialization
303 template <typename From, typename To>
304 struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
305 : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};
306
307 } // namespace internal
308 } // namespace Eigen
309
310 #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_
311