1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 19 #define EIGEN_USE_GPU 20 #endif 21 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/kernels/assign_op.h" 26 #include "tensorflow/core/kernels/dense_update_functor.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/types.h" 30 31 namespace tensorflow { 32 33 template <typename Device, typename T> 34 class AssignOpT : public AssignOp { 35 public: 36 using AssignOp::AssignOp; 37 Copy(OpKernelContext * context,Tensor * lhs,const Tensor & rhs)38 void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override { 39 functor::DenseUpdate<Device, T, ASSIGN> copy; 40 copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>()); 41 } 42 }; 43 44 // TODO(jeff): Get rid of use_exclusive_lock_ option 45 template <typename Device, typename T, DenseUpdateType OP> 46 class DenseUpdateOp : public OpKernel { 47 public: DenseUpdateOp(OpKernelConstruction * context)48 explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) { 49 OP_REQUIRES_OK(context, 50 context->GetAttr("use_locking", &use_exclusive_lock_)); 51 const DataType dt = DataTypeToEnum<T>::v(); 52 OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt}, 53 {MakeRefType(dt)})); 54 } 55 Compute(OpKernelContext * context)56 void Compute(OpKernelContext* context) override { 57 // We always return the input ref. 58 context->forward_ref_input_to_ref_output(0, 0); 59 60 if (use_exclusive_lock_) { 61 mutex_lock l(*context->input_ref_mutex(0)); 62 DoUpdate(context); 63 } else { 64 DoUpdate(context); 65 } 66 } 67 68 private: DoUpdate(OpKernelContext * context)69 void DoUpdate(OpKernelContext* context) { 70 Tensor Tparams = context->mutable_input(0, use_exclusive_lock_); 71 const Tensor& Tupdate = context->input(1); 72 OP_REQUIRES(context, Tparams.IsInitialized(), 73 errors::FailedPrecondition("Attempting to use uninitialized " 74 "parameters: ", 75 requested_input(0))); 76 OP_REQUIRES( 77 context, Tparams.IsSameSize(Tupdate), 78 errors::InvalidArgument("Parameters and update must be the same size")); 79 80 functor::DenseUpdate<Device, T, OP> update_functor; 81 update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(), 82 Tupdate.flat<T>()); 83 } 84 85 bool use_exclusive_lock_; 86 }; 87 88 typedef Eigen::ThreadPoolDevice CPUDevice; 89 typedef Eigen::GpuDevice GPUDevice; 90 91 #define REGISTER_KERNELS(type) \ 92 REGISTER_KERNEL_BUILDER( \ 93 Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 94 AssignOpT<CPUDevice, type>); 95 96 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 97 // uint32 not included in ALL_TYPES 98 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); 99 // quint16 not included in QUANTIZIED_TYPES 100 TF_CALL_quint16(REGISTER_KERNELS); 101 #undef REGISTER_KERNELS 102 103 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 104 // Only register 'Assign' on GPU for the subset of types also supported by 105 // 'Variable' (see variable_ops.cc.) 106 #define REGISTER_GPU_KERNELS(type) \ 107 REGISTER_KERNEL_BUILDER( \ 108 Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 109 AssignOpT<GPUDevice, type>); 110 111 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); 112 TF_CALL_int64(REGISTER_GPU_KERNELS); 113 TF_CALL_uint32(REGISTER_GPU_KERNELS); 114 TF_CALL_uint8(REGISTER_GPU_KERNELS); 115 #undef REGISTER_GPU_KERNELS 116 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 117 118 119 #define REGISTER_KERNELS(type) \ 120 REGISTER_KERNEL_BUILDER( \ 121 Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 122 DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>); \ 123 REGISTER_KERNEL_BUILDER( \ 124 Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 125 DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>); 126 127 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); 128 #undef REGISTER_KERNELS 129 130 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 131 #define REGISTER_GPU_KERNELS(type) \ 132 REGISTER_KERNEL_BUILDER( \ 133 Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 134 DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>); \ 135 REGISTER_KERNEL_BUILDER( \ 136 Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 137 DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>); 138 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); 139 TF_CALL_int64(REGISTER_GPU_KERNELS); 140 TF_CALL_uint8(REGISTER_GPU_KERNELS); 141 #undef REGISTER_GPU_KERNELS 142 #endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM 143 144 } // namespace tensorflow 145