• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/type_util.h"
17 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 #include "tensorflow/core/framework/kernel_def_builder.h"
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 
25 namespace tensorflow {
26 namespace {
27 
// Generic conversion used when materializing splat constants: a plain
// numeric conversion to the destination type.
template <typename DstT, typename SrcT>
DstT CastTo(SrcT value) {
  return static_cast<DstT>(value);
}
32 
// Overload selected when the destination is a 16-bit floating point type
// (Eigen::half or bfloat16). For these dtypes the splat value arrives as an
// int32 holding the raw 16-bit representation (see the half_val usage in
// GetScalarConst), so the bits are reinterpreted rather than numerically
// converted.
template <typename DstT,
          typename std::enable_if<std::is_same<DstT, Eigen::half>::value ||
                                  std::is_same<DstT, bfloat16>::value>::type* =
              nullptr>
DstT CastTo(int32_t src) {
  // Truncate to the low 16 bits and reinterpret them as the fp16 value.
  return absl::bit_cast<DstT>(static_cast<uint16>(src));
}
40 
41 // Returns scalar constant with the value in the tensor, if the given proto has
42 // exactly one value but more than one elements. This encoding is used to
43 // efficiently serialize tensors that have one value repeated for all the
44 // indices.
GetScalarConst(const TensorProto & proto,xla::XlaBuilder * b)45 xla::XlaOp GetScalarConst(const TensorProto& proto, xla::XlaBuilder* b) {
46   if (!proto.tensor_content().empty()) return xla::XlaOp();
47   TensorShape shape(proto.tensor_shape());
48   if (shape.num_elements() > 1) {
49     switch (proto.dtype()) {
50 #define HANDLE_SPLAT(DTYPE, field_name, xla_type)                             \
51   case DTYPE:                                                                 \
52     if (proto.field_name##_val_size() == 0) {                                 \
53       return xla::ConstantR0(b, CastTo<xla_type>(0));                         \
54     } else if (proto.field_name##_val_size() == 1) {                          \
55       return xla::ConstantR0(b, CastTo<xla_type>(proto.field_name##_val(0))); \
56     }                                                                         \
57     break;
58 
59       HANDLE_SPLAT(DT_BOOL, bool, bool);
60 
61       HANDLE_SPLAT(DT_INT8, int, xla::int8);
62       HANDLE_SPLAT(DT_INT16, int, xla::int16);
63       HANDLE_SPLAT(DT_INT32, int, xla::int32);
64       HANDLE_SPLAT(DT_INT64, int64, xla::int64);
65 
66       HANDLE_SPLAT(DT_UINT8, int, xla::uint8);
67       HANDLE_SPLAT(DT_UINT16, int, xla::uint16);
68       HANDLE_SPLAT(DT_UINT32, uint32, xla::uint32);
69       HANDLE_SPLAT(DT_UINT64, uint64, xla::uint64);
70 
71       HANDLE_SPLAT(DT_FLOAT, float, float);
72       HANDLE_SPLAT(DT_DOUBLE, double, double);
73 
74       HANDLE_SPLAT(DT_BFLOAT16, half, bfloat16);
75       HANDLE_SPLAT(DT_HALF, half, Eigen::half);
76 
77 #undef HANDLE_SPLAT
78 
79 #define HANDLE_COMPLEX_SPLAT(DTYPE, field_name, xla_type)                     \
80   case DTYPE:                                                                 \
81     if (proto.field_name##_val_size() == 2) {                                 \
82       return xla::ConstantR0<xla_type>(                                       \
83           b, xla_type(proto.field_name##_val(0), proto.field_name##_val(1))); \
84     }                                                                         \
85     break;
86 
87       HANDLE_COMPLEX_SPLAT(DT_COMPLEX64, scomplex, xla::complex64);
88       HANDLE_COMPLEX_SPLAT(DT_COMPLEX128, dcomplex, xla::complex128);
89 
90 #undef HANDLE_COMPLEXSPLAT
91 
92       default:
93         break;
94     }
95   }
96 
97   return xla::XlaOp();
98 }
99 
// XLA kernel for the "Const" op: emits the op's "value" attribute as an XLA
// constant.
class ConstOp : public XlaOpKernel {
 public:
  explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    const TensorProto* proto = nullptr;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
    proto_ = *proto;
    // The node's declared output dtype must match the dtype recorded inside
    // the value attribute.
    OP_REQUIRES(
        ctx, ctx->output_type(0) == proto_.dtype(),
        errors::InvalidArgument("Type mismatch between value (",
                                DataTypeString(proto_.dtype()), ") and dtype (",
                                DataTypeString(ctx->output_type(0)), ")"));
    OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape()));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    // To avoid blowups for large constants filled with the same value,
    // recognize that case and emit a scalar broadcast instead.
    TensorShape shape(proto_.tensor_shape());
    if (shape.num_elements() > 1) {
      xla::XlaOp value = GetScalarConst(proto_, b);
      // GetScalarConst returns an invalid op when the proto is not a splat.
      if (value.valid()) {
        ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes()));
        return;
      }
    }

    // General case: fully materialize the tensor on the host and emit it as
    // a constant.
    Tensor tensor(proto_.dtype());
    OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
                errors::InvalidArgument("Cannot parse tensor from proto: ",
                                        proto_.DebugString()));
    ctx->SetConstantOutput(0, tensor);
  }

 private:
  TensorProto proto_;  // The "value" attribute: the constant to emit.
  TF_DISALLOW_COPY_AND_ASSIGN(ConstOp);
};
139 
// Registers ConstOp as the XLA-compilation implementation of "Const".
// XLA_* devices also register a "real" Const operator so we suppress the
// dummy operator using CompilationOnly().
REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp);
143 
144 }  // namespace
145 }  // namespace tensorflow
146