• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2xla/lib/util.h"
16 
17 #include "tensorflow/compiler/tf2xla/type_util.h"
18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 
26 namespace tensorflow {
27 namespace {
28 
29 class CastOp : public XlaOpKernel {
30  public:
CastOp(OpKernelConstruction * ctx)31   explicit CastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
32     OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
33     OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
34     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
35     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
36     OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
37   }
38 
Compile(XlaOpKernelContext * ctx)39   void Compile(XlaOpKernelContext* ctx) override {
40     xla::XlaBuilder* builder = ctx->builder();
41     xla::XlaOp input = ctx->Input(0);
42     xla::XlaOp output;
43 
44     if (src_dtype_ == dst_dtype_) {
45       output = input;
46     } else if (dst_dtype_ == DT_BOOL) {
47       output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_));
48     } else if (xla::primitive_util::IsComplexType(src_type_) &&
49                !xla::primitive_util::IsComplexType(dst_type_)) {
50       // As in cast_op.h, we replicate the numpy behavior of truncating the
51       // imaginary part.
52       output = xla::ConvertElementType(xla::Real(input), dst_type_);
53     } else {
54       if (use_truncation_) {
55         OP_REQUIRES(
56             ctx,
57             xla::primitive_util::IsFloatingPointType(src_type_) &&
58                 xla::primitive_util::IsFloatingPointType(dst_type_),
59             errors::Unimplemented("Truncate attribute is only "
60                                   "implemented for floating point datatypes."));
61         int mantissa_difference =
62             xla::primitive_util::SignificandWidth(src_type_) -
63             xla::primitive_util::SignificandWidth(dst_type_);
64         OP_REQUIRES(ctx, mantissa_difference > 0,
65                     errors::Unimplemented(
66                         "Truncate attribute is only implemented in cases where "
67                         "dst datatype "
68                         "has fewer mantissa bits than the src datatype"));
69         int src_bitwidth = xla::primitive_util::BitWidth(src_type_);
70 
71         // Bitcast to same-width integer, mask off the LSBs, bitcast back to the
72         // source datatype.
73         int64 mask = ~((1L << mantissa_difference) - 1);
74         xla::PrimitiveType same_width_int =
75             xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth);
76         OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID,
77                     errors::Unimplemented("Unexpected type bitwidth"));
78         input = xla::BitcastConvertType(
79             xla::And(
80                 xla::BitcastConvertType(input, same_width_int),
81                 ::tensorflow::IntegerLiteral(builder, same_width_int, mask)),
82             src_type_);
83       }
84       output = xla::ConvertElementType(input, dst_type_);
85     }
86 
87     ctx->SetOutput(0, output);
88   }
89 
90  protected:
91   DataType src_dtype_, dst_dtype_;
92   xla::PrimitiveType src_type_, dst_type_;
93   bool use_truncation_;
94 
95   TF_DISALLOW_COPY_AND_ASSIGN(CastOp);
96 };
97 
98 REGISTER_XLA_OP(Name("Cast"), CastOp);
99 
100 class BitcastOp : public XlaOpKernel {
101  public:
BitcastOp(OpKernelConstruction * ctx)102   explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
103     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_));
104     OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_));
105     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
106     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
107   }
108 
Compile(XlaOpKernelContext * ctx)109   void Compile(XlaOpKernelContext* ctx) override {
110     xla::XlaOp input = ctx->Input(0);
111     xla::XlaOp output;
112 
113     if (src_dtype_ == dst_dtype_) {
114       output = input;
115     } else {
116       // Error out if the bitcast has a complex source or destination type and
117       // the bitcast is not trivial.
118       OP_REQUIRES(ctx,
119                   !xla::primitive_util::IsComplexType(src_type_) &&
120                       !xla::primitive_util::IsComplexType(dst_type_),
121                   errors::Unimplemented("Complex types not supported."));
122       // XLA bitcast requires that the bit-width of the source and destination
123       // matches, and currently only the simple lowering is performed.
124       OP_REQUIRES(ctx,
125                   xla::primitive_util::BitWidth(src_type_) ==
126                       xla::primitive_util::BitWidth(dst_type_),
127                   errors::Unimplemented(
128                       "Only bitcasts between equally sized types supported."));
129       output = xla::BitcastConvertType(input, dst_type_);
130     }
131 
132     ctx->SetOutput(0, output);
133   }
134 
135  protected:
136   DataType src_dtype_, dst_dtype_;
137   xla::PrimitiveType src_type_, dst_type_;
138 
139   TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp);
140 };
141 
142 REGISTER_XLA_OP(Name("Bitcast"), BitcastOp);
143 
144 }  // anonymous namespace
145 }  // namespace tensorflow
146