1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/kernels/xent_op.h" 27 #include "tensorflow/core/util/bcast.h" 28 29 namespace tensorflow { 30 31 typedef Eigen::ThreadPoolDevice CPUDevice; 32 typedef Eigen::GpuDevice GPUDevice; 33 #ifdef TENSORFLOW_USE_SYCL 34 typedef Eigen::SyclDevice SYCLDevice; 35 #endif // TENSORFLOW_USE_SYCL 36 37 template <typename Device, typename T> 38 class SoftmaxXentWithLogitsOp : public OpKernel { 39 public: SoftmaxXentWithLogitsOp(OpKernelConstruction * context)40 explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) 41 : OpKernel(context) {} 42 Compute(OpKernelContext * context)43 void Compute(OpKernelContext* context) override { 44 const Tensor& logits_in = context->input(0); 45 const Tensor& labels_in = context->input(1); 46 47 TensorShape shape_in = logits_in.shape(); 48 49 BCast bcast(BCast::FromShape(logits_in.shape()), 50 BCast::FromShape(labels_in.shape())); 51 if (!logits_in.IsSameSize(labels_in)) { 52 OP_REQUIRES(context, bcast.IsValid(), 53 errors::InvalidArgument( 54 "logits and labels must be broadcastable: logits_size=", 55 logits_in.shape().DebugString(), 56 " labels_size=", labels_in.shape().DebugString())); 57 shape_in = BCast::ToShape(bcast.output_shape()); 58 } 59 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in), 60 errors::InvalidArgument("logits and labels must be beither " 61 "2-dimensional, or roadcasted to " 62 "2-dimensional")); 63 64 // loss is 1-D (one per example), and size is batch_size. 65 66 Tensor scratch; 67 OP_REQUIRES_OK( 68 context, context->allocate_temp(DataTypeToEnum<T>::value, 69 TensorShape({shape_in.dim_size(0), 1}), 70 &scratch)); 71 72 Tensor* loss_out = nullptr; 73 OP_REQUIRES_OK(context, 74 context->allocate_output( 75 0, TensorShape({shape_in.dim_size(0)}), &loss_out)); 76 Tensor* back_out = nullptr; 77 // Try to reuse the logits_in buffer for the backprop output. 78 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 79 {0}, 1, shape_in, &back_out)); 80 if (shape_in.dim_size(0) > 0) { 81 functor::XentFunctor<Device, T> functor; 82 if (logits_in.IsSameSize(labels_in)) { 83 functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), 84 Eigen::array<Eigen::DenseIndex, 2>{1, 1}, 85 Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(), 86 labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), 87 back_out->matrix<T>()); 88 } else { 89 functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), 90 BCast::ToIndexArray<2>(bcast.x_bcast()), 91 BCast::ToIndexArray<2>(bcast.y_bcast()), 92 logits_in.template shaped<T, 2>(bcast.x_reshape()), 93 labels_in.template shaped<T, 2>(bcast.y_reshape()), 94 scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>()); 95 } 96 } 97 } 98 }; 99 100 // Partial specialization for a CPUDevice, that uses the Eigen implementation 101 // from XentEigenImpl. 102 namespace functor { 103 template <typename Device, typename T> 104 struct XentFunctorBase { operator ()tensorflow::functor::XentFunctorBase105 void operator()(const Device& d, 106 const Eigen::DSizes<Eigen::DenseIndex, 2>& shape, 107 const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast, 108 const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast, 109 typename TTypes<T>::ConstMatrix logits, 110 typename TTypes<T>::ConstMatrix labels, 111 typename TTypes<T>::Matrix scratch, 112 typename TTypes<T>::Vec loss, 113 typename TTypes<T>::Matrix backprop) { 114 XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast, 115 logits, labels, scratch, loss, backprop); 116 } 117 }; 118 119 template <typename T> 120 struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {}; 121 122 #ifdef TENSORFLOW_USE_SYCL 123 template <typename T> 124 struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {}; 125 #endif // TENSORFLOW_USE_SYCL 126 } // namespace functor 127 128 #define REGISTER_CPU(T) \ 129 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \ 130 .Device(DEVICE_CPU) \ 131 .TypeConstraint<T>("T"), \ 132 SoftmaxXentWithLogitsOp<CPUDevice, T>); 133 TF_CALL_half(REGISTER_CPU); 134 TF_CALL_float(REGISTER_CPU); 135 TF_CALL_double(REGISTER_CPU); 136 137 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 138 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 139 .Device(DEVICE_GPU) 140 .TypeConstraint<Eigen::half>("T"), 141 SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>); 142 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 143 .Device(DEVICE_GPU) 144 .TypeConstraint<float>("T"), 145 SoftmaxXentWithLogitsOp<GPUDevice, float>); 146 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 147 .Device(DEVICE_GPU) 148 .TypeConstraint<double>("T"), 149 SoftmaxXentWithLogitsOp<GPUDevice, double>); 150 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 151 152 #ifdef TENSORFLOW_USE_SYCL 153 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 154 .Device(DEVICE_SYCL) 155 .TypeConstraint<float>("T"), 156 SoftmaxXentWithLogitsOp<SYCLDevice, float>); 157 #endif // TENSORFLOW_USE_SYCL 158 159 } // namespace tensorflow 160