1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 18 19 #define EIGEN_USE_THREADS 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/kernels/cast_op.h" 23 24 namespace tensorflow { 25 26 namespace functor { 27 28 CAST_FUNCTORS(Eigen::ThreadPoolDevice); 29 30 #ifdef TENSORFLOW_USE_SYCL 31 CAST_FUNCTORS(Eigen::SyclDevice); 32 #endif // TENSORFLOW_USE_SYCL 33 34 } // namespace functor 35 36 #define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 37 FN(arg0, arg1, bool); \ 38 FN(arg0, arg1, uint8); \ 39 FN(arg0, arg1, uint16); \ 40 FN(arg0, arg1, uint32); \ 41 FN(arg0, arg1, uint64); \ 42 FN(arg0, arg1, int8); \ 43 FN(arg0, arg1, int16); \ 44 FN(arg0, arg1, int32); \ 45 FN(arg0, arg1, int64); \ 46 FN(arg0, arg1, float); \ 47 FN(arg0, arg1, double); \ 48 FN(arg0, arg1, std::complex<float>); \ 49 FN(arg0, arg1, std::complex<double>) 50 51 #define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 52 CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 53 FN(arg0, arg1, Eigen::half); 54 55 #define CURRY_TYPES3(FN, arg0, arg1) \ 56 CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 57 FN(arg0, arg1, bfloat16); 58 59 #define CAST_CASE(DEVICE, IN, OUT) \ 60 if (DataTypeToEnum<OUT>::value == dst_dtype) { \ 61 return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ 62 bool truncate) { \ 63 functor::CastFunctor<DEVICE, OUT, IN> func; \ 64 func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \ 65 truncate); \ 66 }; \ 67 } 68 69 // The functions below are implemented in the cast_op_impl_*.cc files. 70 CastFunctorType GetCpuCastFromBool(DataType dst_dtype); 71 72 CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); 73 74 CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); 75 76 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 77 78 CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); 79 80 CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); 81 82 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 83 84 CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); 85 86 CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); 87 88 CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); 89 90 CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); 91 92 CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); 93 94 CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); 95 96 CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); 97 98 CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); 99 100 CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); 101 102 #if GOOGLE_CUDA 103 // Same, for GPU. 104 CastFunctorType GetGpuCastFromBool(DataType dst_dtype); 105 106 CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); 107 108 CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); 109 110 CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); 111 112 CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); 113 114 CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); 115 116 CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); 117 118 CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); 119 120 CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); 121 122 CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); 123 124 CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); 125 126 CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); 127 128 CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); 129 130 CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); 131 132 CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); 133 134 #endif // GOOGLE_CUDA 135 136 #ifdef TENSORFLOW_USE_SYCL 137 CastFunctorType GetSyclCastFromBool(DataType dst_dtype); 138 139 CastFunctorType GetSyclCastFromUint8(DataType dst_dtype); 140 141 CastFunctorType GetSyclCastFromUint16(DataType dst_dtype); 142 143 CastFunctorType GetSyclCastFromUint32(DataType dst_dtype); 144 145 CastFunctorType GetSyclCastFromUint64(DataType dst_dtype); 146 147 CastFunctorType GetSyclCastFromInt16(DataType dst_dtype); 148 149 CastFunctorType GetSyclCastFromInt32(DataType dst_dtype); 150 151 CastFunctorType GetSyclCastFromInt64(DataType dst_dtype); 152 153 CastFunctorType GetSyclCastFromFloat(DataType dst_dtype); 154 155 CastFunctorType GetSyclCastFromDouble(DataType dst_dtype); 156 #endif // TENSORFLOW_USE_SYCL 157 158 } // namespace tensorflow 159 160 #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 161