/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/util.h"

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
namespace {

class CastOp : public XlaOpKernel {
 public:
  explicit CastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp output;

    if (src_dtype_ == dst_dtype_) {
      output = input;
    } else if (dst_dtype_ == DT_BOOL) {
      output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_));
    } else if (xla::primitive_util::IsComplexType(src_type_) &&
               !xla::primitive_util::IsComplexType(dst_type_)) {
      // As in cast_op.h, we replicate the numpy behavior of truncating the
      // imaginary part.
      output = xla::ConvertElementType(xla::Real(input), dst_type_);
    } else {
      if (use_truncation_) {
        OP_REQUIRES(
            ctx,
            xla::primitive_util::IsFloatingPointType(src_type_) &&
                xla::primitive_util::IsFloatingPointType(dst_type_),
            errors::Unimplemented("Truncate attribute is only "
                                  "implemented for floating point datatypes."));
        int mantissa_difference =
            xla::primitive_util::SignificandWidth(src_type_) -
            xla::primitive_util::SignificandWidth(dst_type_);
        OP_REQUIRES(ctx, mantissa_difference > 0,
                    errors::Unimplemented(
                        "Truncate attribute is only implemented in cases where "
                        "dst datatype "
                        "has fewer mantissa bits than the src datatype"));
        int src_bitwidth = xla::primitive_util::BitWidth(src_type_);

        // Bitcast to same-width integer, mask off the LSBs, bitcast back to
        // the source datatype.
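        // For example, an F32 -> BF16 cast with Truncate=true gives
        // mantissa_difference = 24 - 8 = 16 (SignificandWidth counts the
        // implicit leading bit), so the mask below clears the low 16 bits of
        // the 32-bit float pattern before the element type conversion.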
        int64 mask = ~((1L << mantissa_difference) - 1);
        xla::PrimitiveType same_width_int =
            xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth);
        OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID,
                    errors::Unimplemented("Unexpected type bitwidth"));
        input = xla::BitcastConvertType(
            xla::And(
                xla::BitcastConvertType(input, same_width_int),
                ::tensorflow::IntegerLiteral(builder, same_width_int, mask)),
            src_type_);
      }
      output = xla::ConvertElementType(input, dst_type_);
    }

    ctx->SetOutput(0, output);
  }

 protected:
  DataType src_dtype_, dst_dtype_;
  xla::PrimitiveType src_type_, dst_type_;
  bool use_truncation_;

  TF_DISALLOW_COPY_AND_ASSIGN(CastOp);
};

REGISTER_XLA_OP(Name("Cast"), CastOp);

class BitcastOp : public XlaOpKernel {
 public:
  explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp output;

    if (src_dtype_ == dst_dtype_) {
      output = input;
    } else {
      // Error out if the bitcast has a complex source or destination type and
      // the bitcast is not trivial.
      OP_REQUIRES(ctx,
                  !xla::primitive_util::IsComplexType(src_type_) &&
                      !xla::primitive_util::IsComplexType(dst_type_),
                  errors::Unimplemented("Complex types not supported."));
      // XLA bitcast requires that the bit-width of the source and destination
      // matches, and currently only the simple lowering is performed.
      OP_REQUIRES(ctx,
                  xla::primitive_util::BitWidth(src_type_) ==
                      xla::primitive_util::BitWidth(dst_type_),
                  errors::Unimplemented(
                      "Only bitcasts between equally sized types supported."));
      output = xla::BitcastConvertType(input, dst_type_);
    }

    ctx->SetOutput(0, output);
  }

 protected:
  DataType src_dtype_, dst_dtype_;
  xla::PrimitiveType src_type_, dst_type_;

  TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp);
};

REGISTER_XLA_OP(Name("Bitcast"), BitcastOp);

}  // anonymous namespace
}  // namespace tensorflow