/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for softmax.

#include <numeric>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class SoftmaxOp : public XlaOpKernel {
 public:
  explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
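    // One kernel serves both ops registered below: the "Softmax" and
    // "LogSoftmax" op names differ only in whether the log of the result is
    // emitted, which is recorded in log_.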
    log_ = absl::StartsWith(type_string(), "Log");
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_shape.DebugString()));

    // Major dimensions are batch dimensions, minor dimension is the class
    // dimension.
    std::vector<int64> batch_dims(logits_shape.dims() - 1);
    std::iota(batch_dims.begin(), batch_dims.end(), 0);
    const int kClassDim = logits_shape.dims() - 1;

    const DataType type = input_type(0);
    const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
    auto logits = ctx->Input(0);

    xla::XlaBuilder* const b = ctx->builder();
    const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);

    // Find the max in each batch, resulting in a tensor of shape [batch].
    auto logits_max =
        xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
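    // Shifting by the per-batch max is the usual numerical-stability trick:
    // it keeps exp() below overflow and leaves the softmax result unchanged.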
    auto shifted_logits = xla::Sub(logits, logits_max, batch_dims);
    auto exp_shifted = xla::Exp(shifted_logits);
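    // Accumulate the class-dimension sum in a wider type where needed so the
    // reduction keeps enough precision for low-precision inputs (e.g. F16 or
    // BF16), then convert the result back to the input type.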
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    xla::PrimitiveType xla_accumulation_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type,
                                                &xla_accumulation_type));
    auto converted =
        xla::ConvertElementType(exp_shifted, xla_accumulation_type);
    auto reduce =
        xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
                    *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
    auto sum = XlaHelpers::ConvertElementType(reduce, type);
    auto softmax =
        log_
            // log_softmax = shifted_logits - log(sum(exp(shifted_logits)))
            ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims)
            // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
            : xla::Div(exp_shifted, sum, batch_dims);
    ctx->SetOutput(0, softmax);
  }

 private:
  bool log_;
};

REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);

std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
    XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type,
    xla::XlaOp logits, xla::XlaOp labels) {
  const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);

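  // Both logits and labels arrive here as dense [batch, num_classes] tensors;
  // the sparse op below converts its integer labels to one-hot before calling.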
  const int kBatchDim = 0;
  const int kClassDim = 1;

  xla::XlaBuilder* b = ctx->builder();
  // Find the max in each batch, resulting in a tensor of shape [batch].
  auto logits_max =
      xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});

  // Subtract the max in batch b from every element in batch b.
  // Broadcasts along the batch dimension.
  auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});

  // exp(logits - max_logits)
  auto exp_shifted_logits = xla::Exp(shifted_logits);

  // sum_{class} (exp(logits - max_logits))
  const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
  auto converted =
      XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type);
  auto reduce =
      xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                  *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
  auto sum_exp = XlaHelpers::ConvertElementType(reduce, type);

  // log(sum(exp(logits - max_logits)))
  auto log_sum_exp = xla::Log(sum_exp);

  // sum(-labels *
  //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
  // along classes
  // (The subtraction broadcasts along the batch dimension.)
  auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim});
  auto mul = xla::Mul(xla::Neg(labels), sub);
  auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type),
                         XlaHelpers::Zero(b, accumulation_type),
                         *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
  auto loss = XlaHelpers::ConvertElementType(sum, type);

  // backprop: prob - labels, where
  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
  // (where the division broadcasts along the batch dimension)
  xla::XlaOp backprop =
      xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
  return {loss, backprop};
}

class SoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                errors::InvalidArgument(
                    "logits and labels must be same size: logits_size=",
                    logits_shape.DebugString(),
                    " labels_size=", labels_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    // Since we have already checked that both inputs have the same shape,
    // there is no need to check that "labels" is a matrix as well.

    const DataType type = input_type(0);
    const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
    auto logits = ctx->Input(0);
    auto labels = ctx->Input(1);

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp);

class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits_shape.DebugString(), " and labels shape ",
                    labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits_shape.DebugString()));

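    // logits is a [batch_size, depth] matrix; the labels input is a vector of
    // batch_size class indices expected to lie in [0, depth).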
    int64 batch_size = logits_shape.dim_size(0);
    int64 depth = logits_shape.dim_size(1);

    const DataType logits_type = input_type(0);
    const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0);
    const DataType indices_type = input_type(1);

    xla::XlaOp indices = ctx->Input(1);

    xla::XlaBuilder* builder = ctx->builder();
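    // Expand the sparse class indices into a dense one-hot matrix of the same
    // shape as the logits so the dense CrossEntropyWithLogits helper above can
    // be reused.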
    xla::XlaOp labels;
    OP_REQUIRES_OK(ctx,
                   XlaHelpers::OneHot(
                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
                       indices, XlaHelpers::One(builder, logits_type),
                       XlaHelpers::Zero(builder, logits_type), &labels));

    // If any of the indices are out of range, we must populate the labels with
    // NaNs to obey the interface contract of
    // tf.nn.sparse_softmax_cross_entropy_with_logits.
    // Build a vector of length batch_size that is 0 where the index is in
    // range and NaN otherwise; then add that vector to the labels to force the
    // out-of-range rows to NaN.
    xla::XlaOp nan_or_zero = xla::Select(
        xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices),
                 xla::Lt(indices, XlaHelpers::IntegerLiteral(
                                      builder, indices_type, depth))),
        xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}),
        xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
                       {batch_size}));
    labels = xla::Add(labels, nan_or_zero, {0});

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) = CrossEntropyWithLogits(
        ctx, logits_type, xla_logits_type, ctx->Input(0), labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"),
                SparseSoftmaxXentWithLogitsOp);

}  // namespace
}  // namespace tensorflow