1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 19 #define EIGEN_USE_GPU 20 #endif 21 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/kernels/assign_op.h" 26 #include "tensorflow/core/kernels/dense_update_functor.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/types.h" 30 31 namespace tensorflow { 32 33 template <typename Device, typename T> 34 class AssignOpT : public AssignOp { 35 public: 36 using AssignOp::AssignOp; 37 Copy(OpKernelContext * context,Tensor * lhs,const Tensor & rhs)38 void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override { 39 functor::DenseUpdate<Device, T, ASSIGN> copy; 40 copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>()); 41 } 42 }; 43 44 // TODO(jeff): Get rid of use_exclusive_lock_ option 45 template <typename Device, typename T, DenseUpdateType OP> 46 class DenseUpdateOp : public OpKernel { 47 public: DenseUpdateOp(OpKernelConstruction * context)48 explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) { 49 OP_REQUIRES_OK(context, 50 context->GetAttr("use_locking", &use_exclusive_lock_)); 51 const DataType dt = DataTypeToEnum<T>::v(); 52 OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt}, 53 {MakeRefType(dt)})); 54 } 55 Compute(OpKernelContext * context)56 void Compute(OpKernelContext* context) override { 57 // We always return the input ref. 58 context->forward_ref_input_to_ref_output(0, 0); 59 60 if (use_exclusive_lock_) { 61 mutex_lock l(*context->input_ref_mutex(0)); 62 DoUpdate(context); 63 } else { 64 DoUpdate(context); 65 } 66 } 67 68 private: DoUpdate(OpKernelContext * context)69 void DoUpdate(OpKernelContext* context) { 70 Tensor Tparams = context->mutable_input(0, use_exclusive_lock_); 71 const Tensor& Tupdate = context->input(1); 72 OP_REQUIRES(context, Tparams.IsInitialized(), 73 errors::FailedPrecondition("Attempting to use uninitialized " 74 "parameters: ", 75 requested_input(0))); 76 OP_REQUIRES( 77 context, Tparams.IsSameSize(Tupdate), 78 errors::InvalidArgument("Parameters and update must be the same size")); 79 80 functor::DenseUpdate<Device, T, OP> update_functor; 81 update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(), 82 Tupdate.flat<T>()); 83 } 84 85 bool use_exclusive_lock_; 86 }; 87 88 typedef Eigen::ThreadPoolDevice CPUDevice; 89 typedef Eigen::GpuDevice GPUDevice; 90 91 #define REGISTER_KERNELS(type) \ 92 REGISTER_KERNEL_BUILDER( \ 93 Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 94 AssignOpT<CPUDevice, type>); 95 96 TF_CALL_ALL_TYPES(REGISTER_KERNELS); 97 // uint32 not included in ALL_TYPES 98 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); 99 // quint16 not included in QUANTIZIED_TYPES 100 TF_CALL_quint16(REGISTER_KERNELS); 101 #undef REGISTER_KERNELS 102 103 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 104 // Only register 'Assign' on GPU for the subset of types also supported by 105 // 'Variable' (see variable_ops.cc.) 106 #define REGISTER_GPU_KERNELS(type) \ 107 REGISTER_KERNEL_BUILDER( \ 108 Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 109 AssignOpT<GPUDevice, type>); 110 111 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); 112 TF_CALL_int64(REGISTER_GPU_KERNELS); 113 TF_CALL_uint32(REGISTER_GPU_KERNELS); 114 #undef REGISTER_GPU_KERNELS 115 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 116 117 118 #define REGISTER_KERNELS(type) \ 119 REGISTER_KERNEL_BUILDER( \ 120 Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 121 DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>); \ 122 REGISTER_KERNEL_BUILDER( \ 123 Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 124 DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>); 125 126 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); 127 #undef REGISTER_KERNELS 128 129 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 130 #define REGISTER_GPU_KERNELS(type) \ 131 REGISTER_KERNEL_BUILDER( \ 132 Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 133 DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>); \ 134 REGISTER_KERNEL_BUILDER( \ 135 Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 136 DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>); 137 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); 138 TF_CALL_int64(REGISTER_GPU_KERNELS); 139 #undef REGISTER_GPU_KERNELS 140 #endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM 141 142 } // namespace tensorflow 143