• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "tensorflow/core/kernels/sparse_xent_op.h"
21 
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/util/determinism.h"
28 #include "tensorflow/core/util/env_var.h"
29 
30 namespace tensorflow {
31 
32 typedef Eigen::ThreadPoolDevice CPUDevice;
33 typedef Eigen::GpuDevice GPUDevice;
34 
35 template <typename Index>
CheckInvalidLabelIndex(const Tensor & labels,int64_t max_index)36 Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) {
37   if (labels.NumElements() == 0) return Status::OK();
38   const auto label_values = labels.vec<Index>();
39   int64_t bad_index;
40   auto min_max_dim_value = std::minmax_element(
41       label_values.data(), label_values.data() + label_values.size());
42   if (*min_max_dim_value.first < 0 || *min_max_dim_value.second >= max_index) {
43     bad_index = (*min_max_dim_value.first < 0) ? *min_max_dim_value.first
44                                                : *min_max_dim_value.second;
45     return errors::InvalidArgument(
46         "Received a label value of ", bad_index,
47         " which is outside the valid range of [0, ", max_index,
48         ").  Label values: ", labels.SummarizeValue(labels.NumElements()));
49   }
50   return Status::OK();
51 }
52 
DisableSparseSoftmaxXentWithLogitsOpDeterminismExceptions()53 bool DisableSparseSoftmaxXentWithLogitsOpDeterminismExceptions() {
54   static bool cached_disable = [] {
55     bool disable = false;
56     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
57         "TF_DISABLE_SPARSE_SOFTMAX_XENT_WITH_LOGITS_OP_DETERMINISM_EXCEPTIONS",
58         /*default_val=*/false, &disable));
59     return disable;
60   }();
61   return cached_disable;
62 }
63 
64 template <typename Device, typename T, typename Index>
65 class SparseSoftmaxXentWithLogitsOp : public OpKernel {
66  public:
SparseSoftmaxXentWithLogitsOp(OpKernelConstruction * context)67   explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context)
68       : OpKernel(context) {}
69 
Compute(OpKernelContext * context)70   void Compute(OpKernelContext* context) override {
71     const Tensor& logits = context->input(0);
72     const Tensor& labels = context->input(1);
73     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
74                 errors::InvalidArgument("logits must be 2-D, but got shape ",
75                                         logits.shape().DebugString()));
76     OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
77                 errors::InvalidArgument("labels must be 1-D, but got shape ",
78                                         labels.shape().DebugString()));
79     OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
80                 errors::InvalidArgument(
81                     "logits and labels must have the same first dimension, "
82                     "got logits shape ",
83                     logits.shape().DebugString(), " and labels shape ",
84                     labels.shape().DebugString()));
85     OP_REQUIRES(context, logits.dim_size(1) > 0,
86                 errors::InvalidArgument(
87                     "Must have at least one class, but got logits shape ",
88                     logits.shape().DebugString()));
89 
90     if (std::is_same<Device, GPUDevice>::value) {
91       OP_REQUIRES(
92           context,
93           !OpDeterminismRequired() ||
94               DisableSparseSoftmaxXentWithLogitsOpDeterminismExceptions(),
95           errors::Unimplemented(
96               "Deterministic GPU implementation of"
97               " SparseSoftmaxXentWithLogitsOp not available."));
98     }
99 
100     Tensor scratch;
101     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
102                                                    labels.shape(), &scratch));
103 
104     Tensor* loss_out = nullptr;
105     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
106                                 {1}, 0, labels.shape(), &loss_out));
107     Tensor* back_out = nullptr;
108     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
109                                 {0}, 1, logits.shape(), &back_out));
110 
111     if (logits.dim_size(0) > 0) {
112       if (std::is_same<Device, CPUDevice>::value) {
113         OP_REQUIRES_OK(
114             context, CheckInvalidLabelIndex<Index>(labels, logits.dim_size(1)));
115       }
116       functor::SparseXentFunctor<Device, T, Index> functor;
117       functor(context, logits.matrix<T>(), labels.vec<Index>(),
118               scratch.vec<T>(), loss_out->vec<T>(), back_out->matrix<T>());
119     }
120   }
121 };
122 
123 // Partial specialization for a CPUDevice, that uses the Eigen implementation
124 // from XentEigenImpl.
125 namespace functor {
126 template <typename T, typename Index>
127 struct SparseXentFunctor<CPUDevice, T, Index> {
operator ()tensorflow::functor::SparseXentFunctor128   void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits,
129                   typename TTypes<Index>::ConstVec labels,
130                   typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
131                   typename TTypes<T>::Matrix backprop) {
132     SparseXentEigenImpl<CPUDevice, T, Index>::Compute(ctx, logits, labels,
133                                                       scratch, loss, backprop);
134   }
135 };
136 }  // namespace functor
137 
138 #define REGISTER(Dev, T, Index)                   \
139   REGISTER_KERNEL_BUILDER(                        \
140       Name("SparseSoftmaxCrossEntropyWithLogits") \
141           .Device(DEVICE_##Dev)                   \
142           .TypeConstraint<T>("T")                 \
143           .TypeConstraint<Index>("Tlabels"),      \
144       SparseSoftmaxXentWithLogitsOp<Dev##Device, T, Index>);
145 REGISTER(CPU, float, int32)
146 REGISTER(CPU, float, int64)
147 REGISTER(CPU, double, int32)
148 REGISTER(CPU, double, int64)
149 REGISTER(CPU, Eigen::half, int32)
150 REGISTER(CPU, Eigen::half, int64)
151 
152 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
153 REGISTER(GPU, float, int32)
154 REGISTER(GPU, float, int64)
155 REGISTER(GPU, Eigen::half, int32)
156 REGISTER(GPU, Eigen::half, int64)
157 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
158 
159 #undef REGISTER
160 
161 }  // namespace tensorflow
162