/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_xent_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

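// Returns InvalidArgument if any value in `labels` falls outside the valid
// class range [0, max_index); otherwise returns OK. An empty labels tensor
// is trivially valid.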
template <typename Index>
Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) {
  if (labels.NumElements() == 0) return Status::OK();
  const auto label_values = labels.vec<Index>();
  int64_t bad_index;
  auto min_max_dim_value = std::minmax_element(
      label_values.data(), label_values.data() + label_values.size());
  if (*min_max_dim_value.first < 0 || *min_max_dim_value.second >= max_index) {
    bad_index = (*min_max_dim_value.first < 0) ? *min_max_dim_value.first
                                               : *min_max_dim_value.second;
    return errors::InvalidArgument(
        "Received a label value of ", bad_index,
        " which is outside the valid range of [0, ", max_index,
        "). Label values: ", labels.SummarizeValue(labels.NumElements()));
  }
  return Status::OK();
}

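// Returns true if exceptions for the missing deterministic GPU implementation
// have been disabled via the environment variable below. The variable is read
// once and the result is cached.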
bool DisableSparseSoftmaxXentWithLogitsOpDeterminismExceptions() {
  static bool cached_disable = [] {
    bool disable = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
        "TF_DISABLE_SPARSE_SOFTMAX_XENT_WITH_LOGITS_OP_DETERMINISM_EXCEPTIONS",
        /*default_val=*/false, &disable));
    return disable;
  }();
  return cached_disable;
}

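// Computes softmax cross entropy between 2-D logits and 1-D integer class
// labels, producing a per-example loss (output 0) and the gradient with
// respect to the logits (output 1).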
template <typename Device, typename T, typename Index>
class SparseSoftmaxXentWithLogitsOp : public OpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& logits = context->input(0);
    const Tensor& labels = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels.shape().DebugString()));
    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits.shape().DebugString(), " and labels shape ",
                    labels.shape().DebugString()));
    OP_REQUIRES(context, logits.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits.shape().DebugString()));

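    // No deterministic GPU implementation of this op exists, so fail when
    // determinism is required, unless these exceptions have been explicitly
    // disabled via the environment variable checked above.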
    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(
          context,
          !OpDeterminismRequired() ||
              DisableSparseSoftmaxXentWithLogitsOpDeterminismExceptions(),
          errors::Unimplemented(
              "Deterministic GPU implementation of"
              " SparseSoftmaxXentWithLogitsOp not available."));
    }

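    // Per-example scratch buffer used by the functor during the computation;
    // one element of type T per row of logits, matching the labels shape.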
    Tensor scratch;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                   labels.shape(), &scratch));

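    // Try to reuse an input buffer for each output (loss over the labels
    // buffer, backprop over the logits buffer); a fresh tensor is allocated
    // when forwarding is not possible.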
    Tensor* loss_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 0, labels.shape(), &loss_out));
    Tensor* back_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 1, logits.shape(), &back_out));

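    // Skip the computation entirely for an empty batch. Label range
    // validation is performed only on the CPU path.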
    if (logits.dim_size(0) > 0) {
      if (std::is_same<Device, CPUDevice>::value) {
        OP_REQUIRES_OK(
            context, CheckInvalidLabelIndex<Index>(labels, logits.dim_size(1)));
      }
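      // Compute the per-example loss and the gradient w.r.t. logits in a
      // single functor invocation.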
      functor::SparseXentFunctor<Device, T, Index> functor;
      functor(context, logits.matrix<T>(), labels.vec<Index>(),
              scratch.vec<T>(), loss_out->vec<T>(), back_out->matrix<T>());
    }
  }
};

// Partial specialization for a CPUDevice that uses the Eigen implementation
// from XentEigenImpl.
namespace functor {
template <typename T, typename Index>
struct SparseXentFunctor<CPUDevice, T, Index> {
  void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<Index>::ConstVec labels,
                  typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop) {
    SparseXentEigenImpl<CPUDevice, T, Index>::Compute(ctx, logits, labels,
                                                      scratch, loss, backprop);
  }
};
}  // namespace functor

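// Registers the op for the given device, logit type T, and label index type.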
#define REGISTER(Dev, T, Index)                   \
  REGISTER_KERNEL_BUILDER(                        \
      Name("SparseSoftmaxCrossEntropyWithLogits") \
          .Device(DEVICE_##Dev)                   \
          .TypeConstraint<T>("T")                 \
          .TypeConstraint<Index>("Tlabels"),      \
      SparseSoftmaxXentWithLogitsOp<Dev##Device, T, Index>);
REGISTER(CPU, float, int32)
REGISTER(CPU, float, int64)
REGISTER(CPU, double, int32)
REGISTER(CPU, double, int64)
REGISTER(CPU, Eigen::half, int32)
REGISTER(CPU, Eigen::half, int64)

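// GPU kernels are registered only for float and Eigen::half; double is
// supported on CPU only.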
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU, float, int32)
REGISTER(GPU, float, int64)
REGISTER(GPU, Eigen::half, int32)
REGISTER(GPU, Eigen::half, int64)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER

}  // namespace tensorflow