/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for softmax.
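//
// The kernels below use the numerically stable formulation: with
// m = max_j(logits_j) taken over the class dimension,
//   softmax(logits)_i     = exp(logits_i - m) / sum_j exp(logits_j - m)
//   log_softmax(logits)_i = (logits_i - m) - log(sum_j exp(logits_j - m))
// These are mathematically identical to the unshifted definitions but avoid
// overflow in exp() for large logits.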

#include "absl/strings/match.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class SoftmaxOp : public XlaOpKernel {
 public:
  explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    log_ = absl::StartsWith(type_string(), "Log");
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_shape.DebugString()));

    // Major dimensions are batch dimensions, minor dimension is the class
    // dimension.
    std::vector<int64> batch_dims(logits_shape.dims() - 1);
    std::iota(batch_dims.begin(), batch_dims.end(), 0);
    const int kClassDim = logits_shape.dims() - 1;

    const DataType type = input_type(0);
    const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
    auto logits = ctx->Input(0);

    xla::XlaBuilder* const b = ctx->builder();
    const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);

    // Find the max in each batch, resulting in a tensor of shape [batch]
    auto logits_max =
        xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
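    // Shifting by the per-batch max leaves the result unchanged (the common
    // factor exp(-max) cancels) but keeps exp() from overflowing.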
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
    auto shifted_logits = xla::Sub(logits, logits_max, batch_dims);
    auto exp_shifted = xla::Exp(shifted_logits);
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    xla::PrimitiveType xla_accumulation_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type,
                                                &xla_accumulation_type));
    auto converted =
        xla::ConvertElementType(exp_shifted, xla_accumulation_type);
    auto reduce =
        xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
                    *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
    auto sum = XlaHelpers::ConvertElementType(reduce, type);
    auto softmax =
        log_
            // softmax = shifted_logits - log(sum(exp(shifted_logits)))
            ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims)
            // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
            : xla::Div(exp_shifted, sum, batch_dims);
    ctx->SetOutput(0, softmax);
  }

 private:
  bool log_;
};

REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);

std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
    XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type,
    xla::XlaOp logits, xla::XlaOp labels) {
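  // Given [batch, class] logits and dense [batch, class] labels, returns
  //   loss[b]        = -sum_c labels[b, c] * log_softmax(logits)[b, c]
  //   backprop[b, c] = softmax(logits)[b, c] - labels[b, c]
  // computed with the same max-shifting trick as SoftmaxOp above.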
  const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);

  const int kBatchDim = 0;
  const int kClassDim = 1;

  xla::XlaBuilder* b = ctx->builder();
  // Find the max in each batch, resulting in a tensor of shape [batch]
  auto logits_max =
      xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});

  // Subtract the max in batch b from every element in batch b.
  // Broadcasts along the batch dimension.
  auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});

  // exp(logits - max_logits)
  auto exp_shifted_logits = xla::Exp(shifted_logits);

  // sum_{class} (exp(logits - max_logits))
  const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
  auto converted =
      XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type);
  auto reduce =
      xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                  *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
  auto sum_exp = XlaHelpers::ConvertElementType(reduce, type);

  // log(sum(exp(logits - max_logits)))
  auto log_sum_exp = xla::Log(sum_exp);

  // sum(-labels *
  //    ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
  // along classes
  // (The subtraction broadcasts along the batch dimension.)
  auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim});
  auto mul = xla::Mul(xla::Neg(labels), sub);
  auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type),
                         XlaHelpers::Zero(b, accumulation_type),
                         *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
  auto loss = XlaHelpers::ConvertElementType(sum, type);

  // backprop: prob - labels, where
  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
  //     (where the division broadcasts along the batch dimension)
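  // When each row of labels sums to 1, this equals the analytic gradient of
  // the loss with respect to logits: d(loss)/d(logits) = softmax(logits) - labels.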
  xla::XlaOp backprop =
      xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
  return {loss, backprop};
}

class SoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                errors::InvalidArgument(
                    "logits and labels must be same size: logits_size=",
                    logits_shape.DebugString(),
                    " labels_size=", labels_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    // Since we have already checked that both inputs have the same shape,
    // there is no need to verify that "labels" is a matrix too.

    const DataType type = input_type(0);
    const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
    auto logits = ctx->Input(0);
    auto labels = ctx->Input(1);

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp);

class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits_shape.DebugString(), " and labels shape ",
                    labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits_shape.DebugString()));

    int64 batch_size = logits_shape.dim_size(0);
    int64 depth = logits_shape.dim_size(1);

    const DataType logits_type = input_type(0);
    const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0);
    const DataType indices_type = input_type(1);

    xla::XlaOp indices = ctx->Input(1);

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp labels;
    OP_REQUIRES_OK(ctx,
                   XlaHelpers::OneHot(
                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
                       indices, XlaHelpers::One(builder, logits_type),
                       XlaHelpers::Zero(builder, logits_type), &labels));
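    // For illustration (hypothetical values): with depth = 3 and
    // indices = [0, 2], the resulting dense labels matrix is
    //   [[1, 0, 0],
    //    [0, 0, 1]].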

    // If any of the indices are out of range, we must populate the labels with
    // NaNs to obey the interface contract of
    // tf.nn.sparse_softmax_cross_entropy_with_logits.
    // Build a vector of {batch_size} elements that is 0 if the index is in
    // range, or NaN otherwise; then add that vector to the labels to force
    // out-of-range values to NaN.
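    // For example (hypothetical values), with depth = 3 and indices = [1, 5],
    // nan_or_zero is [0, NaN], so only the out-of-range row becomes NaN.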
    xla::XlaOp nan_or_zero = xla::Select(
        xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices),
                 xla::Lt(indices, XlaHelpers::IntegerLiteral(
                                      builder, indices_type, depth))),
        xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}),
        xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
                       {batch_size}));
    labels = xla::Add(labels, nan_or_zero, {0});

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) = CrossEntropyWithLogits(
        ctx, logits_type, xla_logits_type, ctx->Input(0), labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"),
                SparseSoftmaxXentWithLogitsOp);

}  // namespace
}  // namespace tensorflow