/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
// Functor definition for SparseXentOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {

namespace sparse_xent_helpers {

template <typename T>
typename TTypes<const T, 1>::Tensor32Bit To32BitConst(
    typename TTypes<T>::Vec in) {
  return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions()));
}

template <typename T>
typename TTypes<const T, 2>::Tensor32Bit To32BitConst(
    typename TTypes<T>::Matrix in) {
  return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions()));
}

}  // namespace sparse_xent_helpers
namespace generator {

// Generator for calculation of the sparse Xent loss.
// This generator takes the logits, the sum of the exponentiated
// logits, and the label indices. For each minibatch entry, ignoring
// the batch index b, it calculates:
//
//   loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label }
//
// for j = 0 .. num_classes - 1. These values must be summed over all j
// to obtain the final loss for the entry.
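//
// For example (a hypothetical entry with num_classes = 3 and label = 1):
// given shifted logits [1, 2, 3] and sum_exp_logits = e^1 + e^2 + e^3, the
// generator yields [0, log(e^1 + e^2 + e^3) - 2, 0]; summing over j gives
// -log(softmax(logits)[1]), the cross-entropy loss for that entry.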
template <typename T, typename Index>
class SparseXentLossGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator(
      typename TTypes<const T, 2>::Tensor32Bit logits,
      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
      typename TTypes<const Index, 1>::Tensor32Bit labels,
      const Index max_depth)
      : logits_(logits),
        sum_exp_logits_(sum_exp_logits),
        labels_(labels),
        max_depth_(max_depth) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 2>& coords) const {
    const int batch = coords[0];
    const int depth = coords[1];
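    // Copy the label before validating it so that the value checked is the
    // value used (see bounds_check.h); an out-of-range label yields NaN in
    // the output rather than an out-of-bounds read.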
    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
    if (!FastBoundsCheck(label, max_depth_)) {
      return Eigen::NumTraits<T>::quiet_NaN();
    }
    return TF_PREDICT_FALSE(label == depth)
               ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords))
               : T(0.0);
  }

 private:
  typename TTypes<const T, 2>::Tensor32Bit logits_;
  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
  typename TTypes<const Index, 1>::Tensor32Bit labels_;
  const Index max_depth_;
};

// Generator for calculation of the sparse Xent gradient.
// This generator takes the exponentiated logits, their sums, and the label
// indices. For each minibatch entry, ignoring the batch index b, it
// calculates:
//
//   grad[j] = exp_logits[j] / sum_exp_logits - 1{ j == label }
//
// for j = 0 .. num_classes - 1.
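//
// For example, with label = 1 and num_classes = 3, a row of the gradient
// is softmax(logits) - [0, 1, 0]: the usual softmax cross-entropy gradient
// with respect to the logits.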
template <typename T, typename Index>
class SparseXentGradGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator(
      typename TTypes<const T, 2>::Tensor32Bit exp_logits,
      typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits,
      typename TTypes<const Index, 1>::Tensor32Bit labels,
      const Index max_depth)
      : exp_logits_(exp_logits),
        sum_exp_logits_(sum_exp_logits),
        labels_(labels),
        max_depth_(max_depth) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 2>& coords) const {
    const int batch = coords[0];
    const int depth = coords[1];
    const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch));
    if (!FastBoundsCheck(label, max_depth_)) {
      return Eigen::NumTraits<T>::quiet_NaN();
    }
    T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0);
    return exp_logits_(coords) / sum_exp_logits_(batch) - subtract;
  }

 private:
  typename TTypes<const T, 2>::Tensor32Bit exp_logits_;
  typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_;
  typename TTypes<const Index, 1>::Tensor32Bit labels_;
  const Index max_depth_;
};

}  // namespace generator

namespace functor {

template <typename Device, typename T>
struct RowMaxReduction {
  // Computes the maximum across the rows of logits.
  //
  // logits: batch_size, num_classes.
  // maximum: temporary tensor, dims: batch_size.
  static inline void Compute(OpKernelContext* ctx,
                             typename TTypes<T>::ConstMatrix logits,
                             typename TTypes<T>::Vec maximum) {
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<int, 1> along_row;
    along_row[0] = 1;
#else
    Eigen::IndexList<Eigen::type2index<1> > along_row;
#endif
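    // along_row indexes dimension 1 (the class dimension), so the reduction
    // below computes maximum(b) = max_j logits(b, j) for every row b.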
    Device d = ctx->eigen_device<Device>();
    To32Bit(maximum).device(d) = To32Bit(logits).maximum(along_row);
  }
};

// Functor used by SparseXentOp to do the computations.
template <typename Device, typename T, typename Index>
struct SparseXentFunctor {
  // Computes Cross Entropy loss and backprop.
  //
  // logits: batch_size, num_classes.
  // labels: batch_size, with values in [0, num_classes).
  // scratch: temporary tensor, dims: batch_size.
  // loss: output tensor for the loss, dims: batch_size.
  // backprop: output tensor for the backprop, dims: batch_size, num_classes.
  void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<Index>::ConstVec labels,
                  typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop);
};
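
// A minimal usage sketch (assuming the usual CPUDevice alias for
// Eigen::ThreadPoolDevice; the per-device definitions of operator() live in
// the corresponding .cc/.cu.cc files):
//
//   functor::SparseXentFunctor<CPUDevice, float, int64> functor;
//   functor(ctx, logits.matrix<float>(), labels.vec<int64>(),
//           scratch.vec<float>(), loss.vec<float>(),
//           backprop.matrix<float>());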

// Eigen code implementing SparseXentFunctor::operator().
// This code works for both CPU and GPU and is used by the functor
// specializations for both device types.
template <typename Device, typename T, typename Index>
struct SparseXentEigenImpl {
  static void Compute(OpKernelContext* ctx,
                      typename TTypes<T>::ConstMatrix logits,
                      typename TTypes<Index>::ConstVec labels,
                      typename TTypes<T>::Vec scratch,
                      typename TTypes<T>::Vec loss,
                      typename TTypes<T>::Matrix backprop) {
    // NOTE(touts): This duplicates some of the computations in softmax_op
    // because we need the intermediate (logits - max(logits)) values to
    // avoid a log(exp()) in the computation of the loss.
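    //
    // Concretely, with m = max_j logits[j], the per-example loss computed
    // below is
    //   log(sum_j exp(logits[j] - m)) - (logits[label] - m),
    // which equals log(sum_j exp(logits[j])) - logits[label] exactly, but
    // never exponentiates a large positive value.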

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);

    // These arrays are used to reduce along the class dimension, and
    // broadcast the resulting value to all classes.
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<int, 1> along_class;
    along_class[0] = kClassDim;
    Eigen::array<int, 1> batch_only;
    batch_only[0] = batch_size;
    Eigen::array<int, 2> batch_by_one;
    batch_by_one[0] = batch_size;
    batch_by_one[1] = 1;
    Eigen::array<int, 2> one_by_class;
    one_by_class[0] = 1;
    one_by_class[1] = num_classes;
#else
    Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
    Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
    batch_by_one.set(0, batch_size);
    Eigen::IndexList<int> batch_only;
    batch_only.set(0, batch_size);
    Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
    one_by_class.set(1, num_classes);
#endif

    // scratch = max_logits along classes.
    RowMaxReduction<Device, T>::Compute(ctx, logits, scratch);

    Device d = ctx->eigen_device<Device>();
    // backprop = logits - max_logits.
    To32Bit(backprop).device(d) =
        To32Bit(logits) -
        To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class);

    // scratch = sum(exp(logits - max_logits)) along classes.
    To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class);

    // loss = sum(-1{j == label} *
    //            ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
    //        along classes.
    generator::SparseXentLossGenerator<T, Index> sparse_xent_loss_gen(
        sparse_xent_helpers::To32BitConst<T>(backprop),
        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
        backprop.dimension(1) /* max_depth */);
    To32Bit(loss).device(d) =
        To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class);
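    // The loss generator is nonzero only at j == label, so this sum reduces
    // to log(sum(exp(logits - max_logits))) - (logits[label] - max_logits).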

    // backprop = prob - 1{j == label}, where
    //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
    To32Bit(backprop).device(d) = To32Bit(backprop).exp();
    generator::SparseXentGradGenerator<T, Index> sparse_xent_grad_gen(
        sparse_xent_helpers::To32BitConst<T>(backprop),
        sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels),
        backprop.dimension(1) /* max_depth */);
    To32Bit(backprop).device(d) =
        To32Bit(backprop).generate(sparse_xent_grad_gen);
  }
};

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_