• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/register_types.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/kernels/xent_op.h"
27 #include "tensorflow/core/util/bcast.h"
28 
29 namespace tensorflow {
30 
31 typedef Eigen::ThreadPoolDevice CPUDevice;
32 typedef Eigen::GpuDevice GPUDevice;
33 #ifdef TENSORFLOW_USE_SYCL
34 typedef Eigen::SyclDevice SYCLDevice;
35 #endif  // TENSORFLOW_USE_SYCL
36 
37 template <typename Device, typename T>
38 class SoftmaxXentWithLogitsOp : public OpKernel {
39  public:
SoftmaxXentWithLogitsOp(OpKernelConstruction * context)40   explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context)
41       : OpKernel(context) {}
42 
Compute(OpKernelContext * context)43   void Compute(OpKernelContext* context) override {
44     const Tensor& logits_in = context->input(0);
45     const Tensor& labels_in = context->input(1);
46 
47     TensorShape shape_in = logits_in.shape();
48 
49     BCast bcast(BCast::FromShape(logits_in.shape()),
50                 BCast::FromShape(labels_in.shape()));
51     if (!logits_in.IsSameSize(labels_in)) {
52       OP_REQUIRES(context, bcast.IsValid(),
53                   errors::InvalidArgument(
54                       "logits and labels must be broadcastable: logits_size=",
55                       logits_in.shape().DebugString(),
56                       " labels_size=", labels_in.shape().DebugString()));
57       shape_in = BCast::ToShape(bcast.output_shape());
58     }
59     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
60                 errors::InvalidArgument("logits and labels must be beither "
61                                         "2-dimensional, or roadcasted to "
62                                         "2-dimensional"));
63 
64     // loss is 1-D (one per example), and size is batch_size.
65 
66     Tensor scratch;
67     OP_REQUIRES_OK(
68         context, context->allocate_temp(DataTypeToEnum<T>::value,
69                                         TensorShape({shape_in.dim_size(0), 1}),
70                                         &scratch));
71 
72     Tensor* loss_out = nullptr;
73     OP_REQUIRES_OK(context,
74                    context->allocate_output(
75                        0, TensorShape({shape_in.dim_size(0)}), &loss_out));
76     Tensor* back_out = nullptr;
77     // Try to reuse the logits_in buffer for the backprop output.
78     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
79                                 {0}, 1, shape_in, &back_out));
80     if (shape_in.dim_size(0) > 0) {
81       functor::XentFunctor<Device, T> functor;
82       if (logits_in.IsSameSize(labels_in)) {
83         functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
84                 Eigen::array<Eigen::DenseIndex, 2>{1, 1},
85                 Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(),
86                 labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
87                 back_out->matrix<T>());
88       } else {
89         functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
90                 BCast::ToIndexArray<2>(bcast.x_bcast()),
91                 BCast::ToIndexArray<2>(bcast.y_bcast()),
92                 logits_in.template shaped<T, 2>(bcast.x_reshape()),
93                 labels_in.template shaped<T, 2>(bcast.y_reshape()),
94                 scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
95       }
96     }
97   }
98 };
99 
100 // Partial specialization for a CPUDevice, that uses the Eigen implementation
101 // from XentEigenImpl.
102 namespace functor {
103 template <typename Device, typename T>
104 struct XentFunctorBase {
operator ()tensorflow::functor::XentFunctorBase105   void operator()(const Device& d,
106                   const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
107                   const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
108                   const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
109                   typename TTypes<T>::ConstMatrix logits,
110                   typename TTypes<T>::ConstMatrix labels,
111                   typename TTypes<T>::Matrix scratch,
112                   typename TTypes<T>::Vec loss,
113                   typename TTypes<T>::Matrix backprop) {
114     XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast,
115                                       logits, labels, scratch, loss, backprop);
116   }
117 };
118 
119 template <typename T>
120 struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};
121 
122 #ifdef TENSORFLOW_USE_SYCL
123 template <typename T>
124 struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {};
125 #endif  // TENSORFLOW_USE_SYCL
126 }  // namespace functor
127 
// Register the CPU kernel for each supported floating-point type.
#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<T>("T"),          \
                          SoftmaxXentWithLogitsOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
136 
// GPU registrations (CUDA and ROCm) for half, float, and double. The
// GPU functor specialization is defined elsewhere (xent_op_gpu.cu.cc).
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T"),
                        SoftmaxXentWithLogitsOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
151 
// SYCL registration (float only).
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<float>("T"),
                        SoftmaxXentWithLogitsOp<SYCLDevice, float>);
#endif  // TENSORFLOW_USE_SYCL
158 
159 }  // namespace tensorflow
160