/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xent_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

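// Computes both the cross-entropy loss and its gradient with respect to the
// logits in one pass. For each row i (one example), with
// p = softmax(logits(i, :)):
//
//   loss(i)        = -sum_j labels(i, j) * log(p(j))
//   backprop(i, j) = p(j) - labels(i, j)
//
// The numerically stable log-sum-exp formulation actually used lives in
// XentEigenImpl (see xent_op.h).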
template <typename Device, typename T>
class SoftmaxXentWithLogitsOp : public OpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in = context->input(0);
    const Tensor& labels_in = context->input(1);

    TensorShape shape_in = logits_in.shape();

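    // logits and labels need not have identical shapes; they only need to be
    // broadcast-compatible (e.g. [batch, classes] logits against
    // [1, classes] labels). BCast computes the resulting common shape.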
    BCast bcast(BCast::FromShape(logits_in.shape()),
                BCast::FromShape(labels_in.shape()),
                /*fewer_dims_optimization=*/false);
    if (!logits_in.IsSameSize(labels_in)) {
      OP_REQUIRES(context, bcast.IsValid(),
                  errors::InvalidArgument(
                      "logits and labels must be broadcastable: logits_size=",
                      logits_in.shape().DebugString(),
                      " labels_size=", labels_in.shape().DebugString()));
      shape_in = BCast::ToShape(bcast.output_shape());
    }
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
                errors::InvalidArgument("logits and labels must be either "
                                        "2-dimensional, or broadcasted to be "
                                        "2-dimensional"));

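    // The GPU kernel registered here is nondeterministic, so reject it up
    // front when op determinism has been requested.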
    if (std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(context, !OpDeterminismRequired(),
                  errors::Unimplemented(
                      "The GPU implementation of SoftmaxCrossEntropyWithLogits"
                      " that would have been executed is not deterministic."
                      " Note that the Python API uses an alternative,"
                      " deterministic, GPU-accelerated path when determinism is"
                      " enabled."));
    }

    // loss is 1-D (one per example), and size is batch_size.

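    // scratch holds one per-row intermediate ([batch_size, 1]), presumably
    // the row maximum and then the sum of shifted exponentials used by
    // XentEigenImpl's stable computation.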
    Tensor scratch;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<T>::value,
                                        TensorShape({shape_in.dim_size(0), 1}),
                                        &scratch));

    Tensor* loss_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({shape_in.dim_size(0)}), &loss_out));
    Tensor* back_out = nullptr;
    // Try to reuse the logits_in buffer for the backprop output.
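    // (forward_input_or_allocate_output forwards input 0 into output 1 when
    // the input buffer can be reused in place, and otherwise allocates.)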
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 1, shape_in, &back_out));
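    // An empty batch needs no computation; the outputs above are already
    // correctly shaped.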
    if (shape_in.dim_size(0) > 0) {
      functor::XentFunctor<Device, T> functor;
      functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
              BCast::ToIndexArray<2>(bcast.x_bcast()),
              BCast::ToIndexArray<2>(bcast.y_bcast()),
              logits_in.template shaped<T, 2>(bcast.x_reshape()),
              labels_in.template shaped<T, 2>(bcast.y_reshape()),
              scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
    }
  }
};

// Partial specialization for a CPUDevice that uses the Eigen implementation
// from XentEigenImpl.
namespace functor {
template <typename Device, typename T>
struct XentFunctorBase {
  void operator()(const Device& d,
                  const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
                  const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
                  const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<T>::ConstMatrix labels,
                  typename TTypes<T>::Matrix scratch,
                  typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop) {
    XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast,
                                      logits, labels, scratch, loss, backprop);
  }
};

template <typename T>
struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};

}  // namespace functor
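
// The corresponding GPU specialization of XentFunctor is expected to be
// defined in the CUDA/ROCm build (conventionally in xent_op_gpu.cu.cc).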

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<T>("T"),          \
                          SoftmaxXentWithLogitsOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow