1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #if GOOGLE_CUDA 19 #define EIGEN_USE_GPU 20 #endif 21 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/kernels/assign_op.h" 26 #include "tensorflow/core/kernels/dense_update_functor.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/types.h" 30 31 namespace tensorflow { 32 33 template <typename Device, typename T> 34 class AssignOpT : public AssignOp { 35 public: 36 using AssignOp::AssignOp; 37 Copy(OpKernelContext * context,Tensor * lhs,const Tensor & rhs)38 void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override { 39 functor::DenseUpdate<Device, T, ASSIGN> copy; 40 copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>()); 41 } 42 }; 43 44 // TODO(jeff): Get rid of use_exclusive_lock_ option 45 template <typename Device, typename T, DenseUpdateType OP> 46 class DenseUpdateOp : public OpKernel { 47 public: DenseUpdateOp(OpKernelConstruction * context)48 explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) { 49 OP_REQUIRES_OK(context, 50 context->GetAttr("use_locking", &use_exclusive_lock_)); 51 const DataType dt = DataTypeToEnum<T>::v(); 52 OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt}, 53 {MakeRefType(dt)})); 54 } 55 Compute(OpKernelContext * context)56 void Compute(OpKernelContext* context) override { 57 // We always return the input ref. 58 context->forward_ref_input_to_ref_output(0, 0); 59 60 if (use_exclusive_lock_) { 61 mutex_lock l(*context->input_ref_mutex(0)); 62 DoUpdate(context); 63 } else { 64 DoUpdate(context); 65 } 66 } 67 68 private: DoUpdate(OpKernelContext * context)69 void DoUpdate(OpKernelContext* context) { 70 Tensor Tparams = context->mutable_input(0, use_exclusive_lock_); 71 const Tensor& Tupdate = context->input(1); 72 OP_REQUIRES(context, Tparams.IsInitialized(), 73 errors::FailedPrecondition("Attempting to use uninitialized " 74 "parameters: ", 75 requested_input(0))); 76 OP_REQUIRES( 77 context, Tparams.IsSameSize(Tupdate), 78 errors::InvalidArgument("Parameters and update must be the same size")); 79 80 functor::DenseUpdate<Device, T, OP> update_functor; 81 update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(), 82 Tupdate.flat<T>()); 83 } 84 85 bool use_exclusive_lock_; 86 }; 87 88 typedef Eigen::ThreadPoolDevice CPUDevice; 89 typedef Eigen::GpuDevice GPUDevice; 90 #ifdef TENSORFLOW_USE_SYCL 91 typedef Eigen::SyclDevice SYCLDevice; 92 #endif // TENSORFLOW_USE_SYCL 93 94 #define REGISTER_KERNELS(type) \ 95 REGISTER_KERNEL_BUILDER( \ 96 Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 97 AssignOpT<CPUDevice, type>); 98 99 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 100 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); 101 // quint16 not included in QUANTZIED_TYPES 102 TF_CALL_quint16(REGISTER_KERNELS); 103 #undef REGISTER_KERNELS 104 105 #if GOOGLE_CUDA 106 // Only register 'Assign' on GPU for the subset of types also supported by 107 // 'Variable' (see variable_ops.cc.) 108 #define REGISTER_GPU_KERNELS(type) \ 109 REGISTER_KERNEL_BUILDER( \ 110 Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 111 AssignOpT<GPUDevice, type>); 112 113 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); 114 TF_CALL_int64(REGISTER_GPU_KERNELS); 115 #undef REGISTER_GPU_KERNELS 116 #endif // GOOGLE_CUDA 117 118 #ifdef TENSORFLOW_USE_SYCL 119 #define REGISTER_SYCL_KERNELS(type) \ 120 REGISTER_KERNEL_BUILDER( \ 121 Name("Assign").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 122 AssignOpT<SYCLDevice, type>); 123 124 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); 125 #undef REGISTER_SYCL_KERNELS 126 #endif // TENSORFLOW_USE_SYCL 127 128 #define REGISTER_KERNELS(type) \ 129 REGISTER_KERNEL_BUILDER( \ 130 Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 131 DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>); \ 132 REGISTER_KERNEL_BUILDER( \ 133 Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 134 DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>); 135 136 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); 137 #undef REGISTER_KERNELS 138 139 #if GOOGLE_CUDA 140 #define REGISTER_GPU_KERNELS(type) \ 141 REGISTER_KERNEL_BUILDER( \ 142 Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 143 DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>); \ 144 REGISTER_KERNEL_BUILDER( \ 145 Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 146 DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>); 147 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); 148 TF_CALL_int64(REGISTER_GPU_KERNELS); 149 #undef REGISTER_GPU_KERNELS 150 #endif // end GOOGLE_CUDA 151 152 #ifdef TENSORFLOW_USE_SYCL 153 #define REGISTER_SYCL_KERNELS(type) \ 154 REGISTER_KERNEL_BUILDER( \ 155 Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 156 DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \ 157 REGISTER_KERNEL_BUILDER( \ 158 Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 159 DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>); 160 161 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS); 162 #undef REGISTER_SYCL_KERNELS 163 #endif // TENSORFLOW_USE_SYCL 164 } // namespace tensorflow 165