/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/base/casts.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace {

template <typename DstT, typename SrcT>
DstT CastTo(SrcT src) {
  return static_cast<DstT>(src);
}

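// TensorProto carries Eigen::half and bfloat16 values as their raw 16-bit
// bit patterns widened to a 32-bit integer field, so the value is recovered
// by bit-casting the low 16 bits rather than by numeric conversion.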
template <typename DstT,
          typename std::enable_if<std::is_same<DstT, Eigen::half>::value ||
                                  std::is_same<DstT, bfloat16>::value>::type* =
              nullptr>
DstT CastTo(int32_t src) {
  return absl::bit_cast<DstT>(static_cast<uint16>(src));
}

// Returns a scalar constant with the value in the tensor if the given proto
// has exactly one value but more than one element. This encoding is used to
// efficiently serialize tensors that have a single value repeated across all
// indices. Returns a default-constructed (invalid) XlaOp when the proto is
// not splat-encoded or its element type is not handled below.
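// For example, a DT_FLOAT tensor of shape [1000] filled with 0.5 may be
// serialized as
//   dtype: DT_FLOAT
//   tensor_shape { dim { size: 1000 } }
//   float_val: 0.5
// rather than materializing all 1000 values in tensor_content.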
xla::XlaOp GetScalarConst(const TensorProto& proto, xla::XlaBuilder* b) {
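  // A proto that serializes its values through the packed tensor_content
  // bytes is not using the repeated-value encoding, so give up early and let
  // the caller materialize the full tensor.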
  if (!proto.tensor_content().empty()) return xla::XlaOp();
  TensorShape shape(proto.tensor_shape());
  if (shape.num_elements() > 1) {
    switch (proto.dtype()) {
#define HANDLE_SPLAT(DTYPE, field_name, xla_type)                            \
  case DTYPE:                                                                 \
    if (proto.field_name##_val_size() == 0) {                                 \
      return xla::ConstantR0(b, CastTo<xla_type>(0));                         \
    } else if (proto.field_name##_val_size() == 1) {                          \
      return xla::ConstantR0(b, CastTo<xla_type>(proto.field_name##_val(0))); \
    }                                                                         \
    break;

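      // For instance, HANDLE_SPLAT(DT_FLOAT, float, float) expands to a
      // DT_FLOAT case that splats proto.float_val(0); an empty float_val is
      // treated as a splat of zero.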
      HANDLE_SPLAT(DT_BOOL, bool, bool);

      HANDLE_SPLAT(DT_INT8, int, xla::int8);
      HANDLE_SPLAT(DT_INT16, int, xla::int16);
      HANDLE_SPLAT(DT_INT32, int, xla::int32);
      HANDLE_SPLAT(DT_INT64, int64, xla::int64);

      HANDLE_SPLAT(DT_UINT8, int, xla::uint8);
      HANDLE_SPLAT(DT_UINT16, int, xla::uint16);
      HANDLE_SPLAT(DT_UINT32, uint32, xla::uint32);
      HANDLE_SPLAT(DT_UINT64, uint64, xla::uint64);

      HANDLE_SPLAT(DT_FLOAT, float, float);
      HANDLE_SPLAT(DT_DOUBLE, double, double);

      HANDLE_SPLAT(DT_BFLOAT16, half, bfloat16);
      HANDLE_SPLAT(DT_HALF, half, Eigen::half);

#undef HANDLE_SPLAT

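// A splatted complex value is stored as two consecutive entries of the
// repeated field: the real part followed by the imaginary part.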
#define HANDLE_COMPLEX_SPLAT(DTYPE, field_name, xla_type)                    \
  case DTYPE:                                                                 \
    if (proto.field_name##_val_size() == 2) {                                 \
      return xla::ConstantR0<xla_type>(                                       \
          b, xla_type(proto.field_name##_val(0), proto.field_name##_val(1))); \
    }                                                                         \
    break;

      HANDLE_COMPLEX_SPLAT(DT_COMPLEX64, scomplex, xla::complex64);
      HANDLE_COMPLEX_SPLAT(DT_COMPLEX128, dcomplex, xla::complex128);

#undef HANDLE_COMPLEX_SPLAT

      default:
        break;
    }
  }

  return xla::XlaOp();
}

class ConstOp : public XlaOpKernel {
 public:
  explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    const TensorProto* proto = nullptr;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
    proto_ = *proto;
    OP_REQUIRES(
        ctx, ctx->output_type(0) == proto_.dtype(),
        errors::InvalidArgument("Type mismatch between value (",
                                DataTypeString(proto_.dtype()), ") and dtype (",
                                DataTypeString(ctx->output_type(0)), ")"));
    OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape()));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    // To avoid blowups for large constants filled with the same value,
    // recognize that case and emit a scalar broadcast instead.
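    // For example, a [1024, 1024] tensor splatted with a single value becomes
    // one scalar constant plus a Broadcast, rather than a literal holding a
    // million elements.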
    TensorShape shape(proto_.tensor_shape());
    if (shape.num_elements() > 1) {
      xla::XlaOp value = GetScalarConst(proto_, b);
      if (value.valid()) {
        ctx->SetOutput(0, xla::Broadcast(value, shape.dim_sizes()));
        return;
      }
    }

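    // Otherwise, materialize the full tensor from the proto and emit it as a
    // constant literal.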
    Tensor tensor(proto_.dtype());
    OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
                errors::InvalidArgument("Cannot parse tensor from proto: ",
                                        proto_.DebugString()));
    ctx->SetConstantOutput(0, tensor);
  }

 private:
  TensorProto proto_;
  TF_DISALLOW_COPY_AND_ASSIGN(ConstOp);
};

// XLA_* devices also register a "real" Const operator so we suppress the
// dummy operator using CompilationOnly().
REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp);

}  // namespace
}  // namespace tensorflow