/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

// Note that the GPU cast functor templates need to be explicitly instantiated,
// unlike the CPU ones, so their specializations differ from the CPU versions
// (see the usage sketch after CAST_FUNCTORS below).
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <typename Device>                                        \
  struct CastFunctor<Device, OUT_TYPE, IN_OUT> {                    \
    void operator()(const Device& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };                                                                \
  template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <>                                                       \
  struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> {                    \
    void operator()(const DEVICE& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };
#endif

#define CAST_FUNCTORS(devname)                                        \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, double)                       \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };

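// Usage sketch (editor's note, hedged): a CPU translation unit would typically
// just expand the macro for its device type, e.g.
//
//   CAST_FUNCTORS(Eigen::ThreadPoolDevice);
//
// whereas a GPU .cu.cc file is expected to define SPECIALIZE_FOR_GPUS before
// including this header, so that the same macro emits explicit template
// instantiations for Eigen::GpuDevice instead of full specializations.
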
#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
// If MLIR kernels are enabled, we don't need the specialized cast from
// Eigen::half to double. We still need the specialized cast from Eigen::half
// to float, because it is used in depthwise_conv_grad_op.cc, and the
// specialized cast from float to double, because it is used in
// resize_bilinear_op.cc.
#define CAST_FUNCTORS_SUBSET(devname)                                 \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
#endif
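
// Selection sketch (editor's note, hedged): a GPU kernel source file could
// choose between the two macro sets depending on whether the MLIR-generated
// kernels already cover the remaining casts, e.g.
//
//   #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
//   CAST_FUNCTORS_SUBSET(Eigen::GpuDevice);
//   #else
//   CAST_FUNCTORS(Eigen::GpuDevice);
//   #endif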

namespace tensorflow {

typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
                           bool trunc)>
    CastFunctorType;

// Common base class of Cast kernels
class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  DataType external_src_dtype_;
  DataType external_dst_dtype_;
  bool use_truncation_;
  CastFunctorType work_ = nullptr;
  Status Unimplemented();

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};
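
// Wiring sketch (editor's example, hypothetical): a derived kernel's
// preparation step would typically bind work_ to a lambda that forwards to the
// CastFunctor for the resolved dtype pair. Assuming CPUDevice is an alias for
// Eigen::ThreadPoolDevice, a float <- double cast might look like:
//
//   work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out,
//              bool truncate) {
//     functor::CastFunctor<CPUDevice, float, double> func;
//     func(ctx->eigen_device<CPUDevice>(), out->flat<float>(),
//          inp.flat<double>(), truncate);
//   };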

// CPU implementation of Cast
class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx);

 private:
  Status Prepare();
};

namespace functor {

template <typename I>
constexpr int MantissaWidth() {
  return std::numeric_limits<I>::digits;
}

template <>
constexpr int MantissaWidth<Eigen::half>() {
  // Remember, there's 1 hidden bit
  return 10 + 1;
}

template <>
constexpr int MantissaWidth<bfloat16>() {
  // Remember, there's 1 hidden bit
  return 7 + 1;
}
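
// Sanity-check examples (editor's addition): the mantissa widths used for
// truncation, counting the hidden bit.
static_assert(MantissaWidth<float>() == 24, "float has 24 mantissa bits");
static_assert(MantissaWidth<double>() == 53, "double has 53 mantissa bits");
static_assert(MantissaWidth<Eigen::half>() == 11,
              "Eigen::half has 11 mantissa bits");
static_assert(MantissaWidth<bfloat16>() == 8, "bfloat16 has 8 mantissa bits");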

template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
          typename TTypes<Tin>::ConstFlat i) {
  o.device(d) = i.template cast<Tout>();
}

template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i, bool truncate = false);
};

// Only enable LSBZeroSetterHelper for 64- and 32-bit input data types.
// Specialize for others if needed in the future.
template <typename I>
typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint64_t* p = reinterpret_cast<uint64_t*>(&t);
    *p &= (0xFFFFFFFFFFFFFFFF << n);
  }
}

template <typename I>
typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint32_t* p = reinterpret_cast<uint32_t*>(&t);
    *p &= (0xFFFFFFFF << n);
  }
}
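
// Worked example (editor's illustration): for a 32-bit input and n == 16 the
// mask above is (0xFFFFFFFF << 16) == 0xFFFF0000, so the 16 low-order bits of
// the value's bit pattern are cleared while the sign, exponent, and high
// mantissa bits are preserved.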

// Set n least significant bits to 0
template <typename I, typename O>
struct LSBZeroSetter {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I t = a;
    LSBZeroSetterHelper(t, bits);
    return t;
  }
};
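
// Example (editor's illustration): for LSBZeroSetter<float, bfloat16>, bits is
// MantissaWidth<float>() - MantissaWidth<bfloat16>() == 24 - 8 == 16, so the
// 16 low-order bits of each float are zeroed before the cast. Since bfloat16
// keeps only the high 16 bits of a float's pattern, this reproduces a
// truncating (round-toward-zero) float -> bfloat16 conversion.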

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, O> {
  EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
  // Zeroes the low-order mantissa bits of both the real and imaginary
  // components.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

}  // end namespace functor
}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Eigen can't convert to/from complex numbers on its own, because it is
// limited to conversions that can be expressed with a static_cast. But numpy
// can cast to/from complex, and we want to replicate that behavior, so we add
// specializations for complex here.
template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, To> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
  operator()(const std::complex<From>& a) const {
    // Replicate numpy behavior of returning just the real part
    return static_cast<To>(a.real());
  }
};

template <typename From, typename To>
struct scalar_cast_op<From, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const From& a) const {
    // Replicate numpy behavior of setting the imaginary part to 0
    return std::complex<To>(static_cast<To>(a), To(0));
  }
};

template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const std::complex<From>& a) const {
    return std::complex<To>(static_cast<To>(a.real()),
                            static_cast<To>(a.imag()));
  }
};
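
// Example (editor's illustration): with these specializations, casting
// std::complex<double>(3.0, 4.0) to float yields 3.0f (the imaginary part is
// dropped, matching numpy), and casting 2.5f to std::complex<double> yields
// (2.5, 0.0).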

template <typename From, typename To>
struct functor_traits_complex_impl {
  enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
};

template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, To>>
    : functor_traits_complex_impl<std::complex<From>, To> {};
template <typename From, typename To>
struct functor_traits<scalar_cast_op<From, std::complex<To>>>
    : functor_traits_complex_impl<From, std::complex<To>> {};
// Needed to avoid ambiguous partial specialization
template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
    : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_H_