/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include "tensorflow/core/lib/bfloat16/bfloat16.h"

#include <math.h>
#include <algorithm>
#include <numeric>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/cuda.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#if GOOGLE_CUDA
template <typename T>
struct CheckNumericsLaunch {
  void Run(const GPUDevice& d, const T* data, int size,
           int abnormal_detected[2]);
};

extern template struct CheckNumericsLaunch<Eigen::half>;
extern template struct CheckNumericsLaunch<float>;
extern template struct CheckNumericsLaunch<double>;
#endif

namespace {

template <typename Device, typename T>
class CheckNumericsOp;

// Partial specialization for CPU
// TODO(jeff,rmlarsen): We should make this variant be an AsyncOpKernel, as
// was done for the GPU case below.
template <typename T>
class CheckNumericsOp<CPUDevice, T> : public OpKernel {
 public:
  explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) {
    // message_ is used as the prefix for the assertion error message. For
    // instance, this can be the name of the input op that produced the tensor.
    OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
  }

  void Compute(OpKernelContext* context) override {
    // pass along the input to the output
    context->set_output(0, context->input(0));

    auto in = context->input(0).flat<T>();
    const T* data = in.data();
    const int64 size = in.size();
    // Check to see if any element of the tensor is NaN or Inf.
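    // The reduction below makes a single pass over the buffer and folds what
    // it finds into a small bit mask: kInfBit is OR-ed in when an Inf element
    // is seen and kNaNBit when a NaN element is seen, so one traversal
    // classifies the tensor as clean, "Inf", "NaN", or "Inf and NaN".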
    int fp_props =
        std::accumulate(data, data + size, 0, [](const int& x, const T& y) {
          int result = x;
          if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
            // Do nothing: common case
          } else if (Eigen::numext::isinf(y)) {
            result |= kInfBit;
          } else if (Eigen::numext::isnan(y)) {
            result |= kNaNBit;
          }
          return result;
        });
    if (fp_props != 0) {
      string status;
      if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
        status = "Inf and NaN";
      } else {
        if (fp_props & kInfBit) {
          status = "Inf";
        }
        if (fp_props & kNaNBit) {
          status = "NaN";
        }
      }
      if (!status.empty()) {
        context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
                                                   status, " values"));
      }
    }
  }

 private:
  string message_;
  static const int kInfBit = 0x01;
  static const int kNaNBit = 0x02;
};

#if GOOGLE_CUDA
// Partial specialization for GPU
template <typename T>
class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
 public:
  typedef GPUDevice Device;

  explicit CheckNumericsOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    // message_ is used as the prefix for the assertion error message. For
    // instance, this can be the name of the input op that produced the tensor.
    OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // pass along the input to the output
    context->set_output(0, context->input(0));
    if (context->input(0).NumElements() == 0) {
      done();
      return;
    }
    auto input = context->input(0).flat<T>();

    // Allocate and initialize the elements to hold the check results
    const int abnormal_detected_size = 2;
    Tensor abnormal_detected;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_INT32, TensorShape({abnormal_detected_size}),
                                &abnormal_detected));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(context, stream != nullptr,
                      errors::Internal("No GPU stream available."), done);

    se::DeviceMemoryBase abnormal_detected_ptr(
        abnormal_detected.flat<int>().data(),
        abnormal_detected.flat<int>().size());
    stream->ThenMemset32(&abnormal_detected_ptr, 0,
                         abnormal_detected.flat<int>().size() * sizeof(int));

    // Call the Cuda kernels for the numerical checks
    const Device& d = context->eigen_device<Device>();
    CheckNumericsLaunch<T>().Run(d, input.data(), input.size(),
                                 abnormal_detected.flat<int>().data());

    // Copy the results from device to host
    AllocatorAttributes attr;
    attr.set_on_host(true);
    attr.set_gpu_compatible(true);
    Tensor abnormal_detected_host;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_temp(DT_INT32, TensorShape({abnormal_detected_size}),
                               &abnormal_detected_host, attr),
        done);
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(abnormal_detected_host.flat<int>().data(),
                         abnormal_detected_ptr,
                         abnormal_detected_size * sizeof(int))
            .ok(),
        errors::Internal("cudaMemcpy from device to host failed"), done);

    // We have observed crashes on some network stacks when not holding
    // this tensor reference.
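    // The event manager runs check_cb only after ComputeAsync has returned,
    // so the temporary `abnormal_detected` tensor could otherwise be
    // deallocated while the asynchronous memcpy above is still reading its
    // buffer; the TensorReference pins it until Unref() in the callback.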
    TensorReference abnormal_detected_ref(abnormal_detected);
    auto check_cb = [this, stream, abnormal_detected_ref,
                     abnormal_detected_host, context, done]() {
      se::cuda::ScopedActivateExecutorContext scoped_activation{
          stream->parent()};
      auto abnormal_detected_host_flat = abnormal_detected_host.flat<int>();
      int is_nan = abnormal_detected_host_flat(0);
      int is_inf = abnormal_detected_host_flat(1);
      abnormal_detected_ref.Unref();
      if (is_nan || is_inf) {
        string status;
        LOG(ERROR) << "abnormal_detected_host @"
                   << abnormal_detected_host_flat.data() << " = {" << is_nan
                   << ", " << is_inf << "} " << message_;

        // Results should always be 1 or 0. If we see anything else then
        // there has been some GPU memory corruption.
        CHECK_GE(is_nan, 0);
        CHECK_GE(is_inf, 0);
        CHECK_LE(is_nan, 1);
        CHECK_LE(is_inf, 1);

        if (is_nan && is_inf) {
          status = "Inf and NaN";
        } else if (is_nan) {
          status = "NaN";
        } else if (is_inf) {
          status = "Inf";
        }
        context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
                                                   status, " values"));
      }
      done();
    };
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, std::move(check_cb));
  }

 private:
  string message_;
};
#endif  // GOOGLE_CUDA

}  // namespace

#define REGISTER_CPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("CheckNumerics").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CheckNumericsOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    CheckNumericsOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    CheckNumericsOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    CheckNumericsOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow
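// A minimal usage sketch (not part of this file; it assumes the standard
// Python API, where tf.debugging.check_numerics lowers to the "CheckNumerics"
// op registered above):
//
//   x = tf.constant([1.0, float("nan")])
//   tf.debugging.check_numerics(x, message="x")
//   # InvalidArgumentError: x : Tensor had NaN values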