1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/xent_op.h" 21 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/util/bcast.h" 28 #include "tensorflow/core/util/determinism.h" 29 #include "tensorflow/core/util/env_var.h" 30 31 namespace tensorflow { 32 33 typedef Eigen::ThreadPoolDevice CPUDevice; 34 typedef Eigen::GpuDevice GPUDevice; 35 36 template <typename Device, typename T> 37 class SoftmaxXentWithLogitsOp : public OpKernel { 38 public: SoftmaxXentWithLogitsOp(OpKernelConstruction * context)39 explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) 40 : OpKernel(context) {} 41 Compute(OpKernelContext * context)42 void Compute(OpKernelContext* context) override { 43 const Tensor& logits_in = context->input(0); 44 const Tensor& labels_in = context->input(1); 45 46 TensorShape shape_in = logits_in.shape(); 47 48 BCast bcast(BCast::FromShape(logits_in.shape()), 49 BCast::FromShape(labels_in.shape()), 50 /*fewer_dims_optimization=*/false); 51 if (!logits_in.IsSameSize(labels_in)) { 52 OP_REQUIRES(context, bcast.IsValid(), 53 errors::InvalidArgument( 54 "logits and labels must be broadcastable: logits_size=", 55 logits_in.shape().DebugString(), 56 " labels_size=", labels_in.shape().DebugString())); 57 shape_in = BCast::ToShape(bcast.output_shape()); 58 } 59 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in), 60 errors::InvalidArgument("logits and labels must be either " 61 "2-dimensional, or broadcasted to be " 62 "2-dimensional")); 63 64 if (std::is_same<Device, GPUDevice>::value) { 65 OP_REQUIRES(context, !OpDeterminismRequired(), 66 errors::Unimplemented( 67 "The GPU implementation of SoftmaxCrossEntropyWithLogits" 68 " that would have been executed is not deterministic." 69 " Note that the Python API uses an alternative," 70 " deterministic, GPU-accelerated path when determinism is" 71 " enabled.")); 72 } 73 74 // loss is 1-D (one per example), and size is batch_size. 75 76 Tensor scratch; 77 OP_REQUIRES_OK( 78 context, context->allocate_temp(DataTypeToEnum<T>::value, 79 TensorShape({shape_in.dim_size(0), 1}), 80 &scratch)); 81 82 Tensor* loss_out = nullptr; 83 OP_REQUIRES_OK(context, 84 context->allocate_output( 85 0, TensorShape({shape_in.dim_size(0)}), &loss_out)); 86 Tensor* back_out = nullptr; 87 // Try to reuse the logits_in buffer for the backprop output. 88 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 89 {0}, 1, shape_in, &back_out)); 90 if (shape_in.dim_size(0) > 0) { 91 functor::XentFunctor<Device, T> functor; 92 functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), 93 BCast::ToIndexArray<2>(bcast.x_bcast()), 94 BCast::ToIndexArray<2>(bcast.y_bcast()), 95 logits_in.template shaped<T, 2>(bcast.x_reshape()), 96 labels_in.template shaped<T, 2>(bcast.y_reshape()), 97 scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>()); 98 } 99 } 100 }; 101 102 // Partial specialization for a CPUDevice, that uses the Eigen implementation 103 // from XentEigenImpl. 104 namespace functor { 105 template <typename Device, typename T> 106 struct XentFunctorBase { operator ()tensorflow::functor::XentFunctorBase107 void operator()(const Device& d, 108 const Eigen::DSizes<Eigen::DenseIndex, 2>& shape, 109 const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast, 110 const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast, 111 typename TTypes<T>::ConstMatrix logits, 112 typename TTypes<T>::ConstMatrix labels, 113 typename TTypes<T>::Matrix scratch, 114 typename TTypes<T>::Vec loss, 115 typename TTypes<T>::Matrix backprop) { 116 XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast, 117 logits, labels, scratch, loss, backprop); 118 } 119 }; 120 121 template <typename T> 122 struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {}; 123 124 } // namespace functor 125 126 #define REGISTER_CPU(T) \ 127 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \ 128 .Device(DEVICE_CPU) \ 129 .TypeConstraint<T>("T"), \ 130 SoftmaxXentWithLogitsOp<CPUDevice, T>); 131 TF_CALL_half(REGISTER_CPU); 132 TF_CALL_float(REGISTER_CPU); 133 TF_CALL_double(REGISTER_CPU); 134 135 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ 136 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) 137 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 138 .Device(DEVICE_GPU) 139 .TypeConstraint<Eigen::half>("T"), 140 SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>); 141 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 142 .Device(DEVICE_GPU) 143 .TypeConstraint<float>("T"), 144 SoftmaxXentWithLogitsOp<GPUDevice, float>); 145 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 146 .Device(DEVICE_GPU) 147 .TypeConstraint<double>("T"), 148 SoftmaxXentWithLogitsOp<GPUDevice, double>); 149 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 150 151 152 } // namespace tensorflow 153