/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/assign_op.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

template <typename Device, typename T>
class AssignOpT : public AssignOp {
 public:
  using AssignOp::AssignOp;

  void Copy(OpKernelContext* context, Tensor* lhs,
            const Tensor& rhs) override {
    functor::DenseUpdate<Device, T, ASSIGN> copy;
    copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>());
  }
};
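
// Note for readers: the ASSIGN specialization of functor::DenseUpdate used
// above is declared in dense_update_functor.h. As a rough sketch (not the
// authoritative definition), the CPU case amounts to a device-placed Eigen
// copy of the update into the parameter buffer:
//
//   template <typename T>
//   struct DenseUpdate<CPUDevice, T, ASSIGN> {
//     void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
//                     typename TTypes<T>::ConstFlat update) {
//       params.device(d) = update;  // element-wise copy on the Eigen device
//     }
//   };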

// TODO(jeff): Get rid of use_exclusive_lock_ option
template <typename Device, typename T, DenseUpdateType OP>
class DenseUpdateOp : public OpKernel {
 public:
  explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_locking", &use_exclusive_lock_));
    const DataType dt = DataTypeToEnum<T>::v();
    OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt},
                                                    {MakeRefType(dt)}));
  }

  void Compute(OpKernelContext* context) override {
    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    if (use_exclusive_lock_) {
      mutex_lock l(*context->input_ref_mutex(0));
      DoUpdate(context);
    } else {
      DoUpdate(context);
    }
  }

 private:
  void DoUpdate(OpKernelContext* context) {
    Tensor Tparams = context->mutable_input(0, use_exclusive_lock_);
    const Tensor& Tupdate = context->input(1);
    OP_REQUIRES(context, Tparams.IsInitialized(),
                errors::FailedPrecondition("Attempting to use uninitialized "
                                           "parameters: ",
                                           requested_input(0)));
    OP_REQUIRES(
        context, Tparams.IsSameSize(Tupdate),
        errors::InvalidArgument("Parameters and update must be the same size"));

    functor::DenseUpdate<Device, T, OP> update_functor;
    update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(),
                   Tupdate.flat<T>());
  }

  bool use_exclusive_lock_;
};
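
// Sketch only: the ADD and SUB specializations of functor::DenseUpdate (also
// declared in dense_update_functor.h) accumulate into the parameter tensor
// instead of overwriting it, roughly
//
//   params.device(d) += update;  // DenseUpdateType::ADD
//   params.device(d) -= update;  // DenseUpdateType::SUB
//
// so the AssignAdd/AssignSub registrations below reuse DenseUpdateOp and
// differ only in the OP template argument.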

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      AssignOpT<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// uint32 not included in ALL_TYPES
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// quint16 not included in QUANTIZED_TYPES
TF_CALL_quint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
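
// For reference, REGISTER_KERNELS(float) above expands to a single CPU
// kernel registration of the following form (sketch of the macro expansion,
// not additional code):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Assign").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       AssignOpT<CPUDevice, float>);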

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Only register 'Assign' on GPU for the subset of types also supported by
// 'Variable' (see variable_ops.cc.)
#define REGISTER_GPU_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      AssignOpT<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_uint8(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>);          \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
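
// Note: AssignAdd/AssignSub are registered only for numeric types
// (TF_CALL_NUMBER_TYPES), since the ADD/SUB functors rely on arithmetic
// '+='/'-=' that is not defined for every type plain 'Assign' accepts via
// TF_CALL_ALL_TYPES above (e.g. string, bool).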

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                    \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>);          \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint8(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow