• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
17     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
18 
19 #define EIGEN_USE_GPU
20 
21 #include "tensorflow/core/framework/bfloat16.h"
22 #define SPECIALIZE_FOR_GPUS
23 #include "tensorflow/core/kernels/cast_op.h"
24 #undef SPECIALIZE_FOR_GPUS
25 
26 namespace tensorflow {
27 namespace functor {
28 
// Short alias for the Eigen GPU device type; every CastFunctor instantiation
// in this file is made for this device.
29 typedef Eigen::GpuDevice GPUDevice;
30 
// CAST_FUNCTORS / CAST_FUNCTORS_SUBSET are not defined in this file
// (presumably they come from cast_op.h, which is included above with
// SPECIALIZE_FOR_GPUS defined -- TODO confirm). When
// MLIR_GENERATED_GPU_KERNELS_ENABLED is defined, the SUBSET variant is
// expanded instead of the full one.
31 #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
32 CAST_FUNCTORS_SUBSET(GPUDevice);
33 #else
34 CAST_FUNCTORS(GPUDevice);
35 #endif
36 
// Emits an explicit instantiation definition of CastFunctor<GPUDevice, O, I>.
// Going by the use-site comments in this file (e.g. "cast from float to
// double" for DEFINE(double, float)), O is the output (destination) type and
// I is the input (source) type. The macro has no trailing semicolon; the
// caller supplies it.
37 #define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
38 
// Instantiates CastFunctor<GPUDevice, in_type, T> for each of the 14 types T
// listed below: bool, the 8/16/32/64-bit unsigned and signed integers,
// Eigen::half, float, double, std::complex<float> and std::complex<double>.
// bfloat16 is not part of the list.
//
// NOTE(review): the argument is forwarded to DEFINE's first (O) parameter. If
// O is the output type, as the use-site comments in this file suggest, then
// this macro covers the casts *to* `in_type` from every listed T, which is
// the opposite of what the names "FROM" / "in_type" imply. Confirm against
// the CastFunctor declaration before relying on the name.
//
// The last entry has no trailing semicolon; the caller supplies it.
39 #define DEFINE_ALL_FROM(in_type)        \
40   DEFINE(in_type, bool);                \
41   DEFINE(in_type, uint8);               \
42   DEFINE(in_type, uint16);              \
43   DEFINE(in_type, uint32);              \
44   DEFINE(in_type, uint64);              \
45   DEFINE(in_type, int8);                \
46   DEFINE(in_type, int16);               \
47   DEFINE(in_type, int32);               \
48   DEFINE(in_type, int64);               \
49   DEFINE(in_type, Eigen::half);         \
50   DEFINE(in_type, float);               \
51   DEFINE(in_type, double);              \
52   DEFINE(in_type, std::complex<float>); \
53   DEFINE(in_type, std::complex<double>)
54 
// CastFunctor<GPUDevice, float, bfloat16> is instantiated unconditionally,
// regardless of MLIR_GENERATED_GPU_KERNELS_ENABLED.
55 DEFINE(float, bfloat16);
56 
57 #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
58 
// With MLIR-generated GPU kernels enabled, this branch keeps only the single
// instantiation that the comment below says is still required.
59 // The cast from float to double is still needed for resize_bilinear_op.cc
60 DEFINE(double, float);
61 
62 #else
63 
// Without MLIR-generated kernels: expand the full 14-entry DEFINE_ALL_FROM
// list for each type below (the type goes into DEFINE's first template
// argument). float, Eigen::half and std::complex<float> are intentionally not
// listed here; they are instantiated later in this file through
// DEFINE_ALL_TO_FLOAT / DEFINE_ALL_TO_HALF with shorter partner-type lists.
64 DEFINE_ALL_FROM(bool);
65 DEFINE_ALL_FROM(uint8);
66 DEFINE_ALL_FROM(uint16);
67 DEFINE_ALL_FROM(uint32);
68 DEFINE_ALL_FROM(uint64);
69 DEFINE_ALL_FROM(int8);
70 DEFINE_ALL_FROM(int16);
71 DEFINE_ALL_FROM(int32);
72 DEFINE_ALL_FROM(int64);
73 DEFINE_ALL_FROM(double);
74 DEFINE_ALL_FROM(std::complex<double>);
75 #endif
76 
// Instantiates CastFunctor<GPUDevice, out_type, T> for each of the 12 types T
// listed below. This is the DEFINE_ALL_FROM list without `double` and
// `std::complex<double>`.
// NOTE(review): the reason those two partner types are left out is not
// visible in this file; presumably those combinations are provided elsewhere
// (e.g. by the GPU specializations enabled via SPECIALIZE_FOR_GPUS) -- confirm.
//
// The last entry has no trailing semicolon; the caller supplies it.
77 #define DEFINE_ALL_TO_FLOAT(out_type) \
78   DEFINE(out_type, bool);             \
79   DEFINE(out_type, uint8);            \
80   DEFINE(out_type, uint16);           \
81   DEFINE(out_type, uint32);           \
82   DEFINE(out_type, uint64);           \
83   DEFINE(out_type, int8);             \
84   DEFINE(out_type, int16);            \
85   DEFINE(out_type, int32);            \
86   DEFINE(out_type, int64);            \
87   DEFINE(out_type, Eigen::half);      \
88   DEFINE(out_type, float);            \
89   DEFINE(out_type, std::complex<float>)
90 
// Instantiates CastFunctor<GPUDevice, out_type, T> for each of the 10 types T
// listed below: bool, the 8/16/32/64-bit unsigned and signed integers, and
// Eigen::half. Compared with DEFINE_ALL_TO_FLOAT, `float` and
// `std::complex<float>` are additionally left out.
//
// The last entry has no trailing semicolon; the caller supplies it.
91 #define DEFINE_ALL_TO_HALF(out_type) \
92   DEFINE(out_type, bool);            \
93   DEFINE(out_type, uint8);           \
94   DEFINE(out_type, uint16);          \
95   DEFINE(out_type, uint32);          \
96   DEFINE(out_type, uint64);          \
97   DEFINE(out_type, int8);            \
98   DEFINE(out_type, int16);           \
99   DEFINE(out_type, int32);           \
100   DEFINE(out_type, int64);           \
101   DEFINE(out_type, Eigen::half)
102 
// bfloat16 as DEFINE's first template argument, paired with bool, the integer
// types and Eigen::half. Instantiated unconditionally, regardless of
// MLIR_GENERATED_GPU_KERNELS_ENABLED.
103 DEFINE_ALL_TO_HALF(bfloat16);
104 
105 #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
// With MLIR-generated GPU kernels enabled, only the individual instantiations
// that the comments below name as still required are kept.
106 // The cast from Eigen::half to float is still needed for depthwise_conv_grad_op.cc.
107 DEFINE(float, Eigen::half);
108 // The cast from float to float is still needed for resize_bilinear_op.cc.
109 DEFINE(float, float);
110 // The casts from the real element type to the corresponding complex type are
111 // still needed for self_adjoint_eig_v2_op_gpu.cc
112 DEFINE(std::complex<float>, float);
113 DEFINE(std::complex<double>, double);
114 #else
// Without MLIR-generated kernels: expand the reduced partner-type lists for
// Eigen::half, float and std::complex<float>, the three types that were not
// given a DEFINE_ALL_FROM expansion earlier in this file.
115 DEFINE_ALL_TO_HALF(Eigen::half);
116 DEFINE_ALL_TO_FLOAT(float);
117 DEFINE_ALL_TO_FLOAT(std::complex<float>);
118 #endif
119 
// Remove the helper macros defined in this file now that all instantiations
// have been emitted.
120 #undef DEFINE_ALL_TO_FLOAT
121 #undef DEFINE_ALL_TO_HALF
122 #undef DEFINE_ALL_FROM
123 #undef DEFINE
124 
125 }  // end namespace functor
126 }  // end namespace tensorflow
127 
128 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
129