/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

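// Invokes FN(arg0, T) once for every destination type T that Cast supports.
// Used below together with the per-device REGISTER_CAST_* macros to register
// kernels for the full SrcT x DstT cross product.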
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

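// Reads the SrcT, DstT and Truncate attrs. Quantized external dtypes are
// mapped to the plain integer dtypes the cast kernels actually operate on
// (e.g. DT_QUINT8 -> DT_UINT8, DT_QINT32 -> DT_INT32).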
CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their
  // non-quantized versions, so we use the non-quantized implementation for
  // casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

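// When no work_ functor is installed (SrcT == DstT) the input tensor is
// forwarded to the output unchanged. Otherwise a quantized input is first
// bitcast to its underlying integer dtype, the cast functor fills the output,
// and the output dtype is then restored to the externally visible DstT.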
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else {
    Tensor in;
    if (external_src_dtype_ != src_dtype_) {
      // If the type is a quantized type we need to do a bitcast since the
      // src_dtype_ is different from external_src_dtype_.
      OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    } else {
      in = inp;
    }
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not least based on F16C for Haswell or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

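// GPU implementation: the same dispatch as CpuCastOp, but resolved against
// the GetGpuCastFrom* functors.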
#if GOOGLE_CUDA
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

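// Each CURRY_TYPES2 line below registers one GPU Cast kernel per destination
// type. For illustration, CURRY_TYPES2(REGISTER_CAST_GPU, bool) expands to a
// registration of the form
//   REGISTER_KERNEL_BUILDER(Name("Cast")
//                               .TypeConstraint<bool>("SrcT")
//                               .TypeConstraint<float>("DstT")
//                               .Device(DEVICE_GPU),
//                           GpuCastOp);
// repeated for every destination type listed in CURRY_TYPES2.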
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
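// bfloat16 casts on GPU are only registered to and from float.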
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

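// The SYCL implementation follows the same pattern but only supports casting
// from bool, int32, int64, float and double sources.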
#ifdef TENSORFLOW_USE_SYCL
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2


// HostCast differs from Cast in that its input and output are in host memory.
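// Even on GPU (and SYCL) devices the CPU kernel is used; the HostMemory
// constraints below keep both the input and the output in host memory.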
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow