/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

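// CURRY_TYPES2(FN, arg0) expands FN(arg0, dsttype) once for every supported
// destination type, so a single invocation covers the full set of casts from
// arg0; e.g. CURRY_TYPES2(REGISTER_CAST_GPU, bool) registers a GPU Cast
// kernel from bool to each of the types listed below.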
#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64);               \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their non-quantized
  // versions, so we use the non-quantized implementation for casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

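// Compute() forwards the input tensor unchanged when no conversion is needed
// (work_ == nullptr). Otherwise it bitcasts quantized inputs to their plain
// integer counterparts, allocates the output, runs the selected cast functor,
// and restores the external (possibly quantized) dtype on the result.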
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else {
    Tensor in;
    if (external_src_dtype_ != src_dtype_) {
      // If the type is a quantized type we need to do a bitcast since the
      // src_dtype_ is different from external_src_dtype_.
      OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    } else {
      in = inp;
    }
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

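// Prepare() selects the cast functor based on src_dtype_; each
// GetCpuCastFrom* helper dispatches further on dst_dtype_. If no functor
// exists for the requested pair, work_ stays null and the constructor fails
// with Unimplemented().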
Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return Status::OK();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not the least based on F16C for Haswell
  // or newer).

  return work_ == nullptr ? Unimplemented() : Status::OK();
}

#if GOOGLE_CUDA
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};
#endif  // GOOGLE_CUDA

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if GOOGLE_CUDA
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)

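// Register a GPU Cast kernel for every source/destination combination of the
// types covered by CURRY_TYPES2, plus the bfloat16 <-> float pair below.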
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
class SyclCastOp : public CastOpBase {
 public:
  explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return Status::OK();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetSyclCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetSyclCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetSyclCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetSyclCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetSyclCastFromDouble(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : Status::OK();
  }
};

#define REGISTER_CAST_SYCL(srctype, dsttype)                   \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_SYCL),            \
                          SyclCastOp)
CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
CURRY_TYPES2(REGISTER_CAST_SYCL, float);
CURRY_TYPES2(REGISTER_CAST_SYCL, double);

#undef REGISTER_CAST_SYCL

#endif  // TENSORFLOW_USE_SYCL

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
    CpuCastOp);
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow