• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #if GOOGLE_CUDA
19 #define EIGEN_USE_GPU
20 #endif
21 
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/kernels/assign_op.h"
26 #include "tensorflow/core/kernels/dense_update_functor.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 template <typename Device, typename T>
34 class AssignOpT : public AssignOp {
35  public:
36   using AssignOp::AssignOp;
37 
Copy(OpKernelContext * context,Tensor * lhs,const Tensor & rhs)38   void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override {
39     functor::DenseUpdate<Device, T, ASSIGN> copy;
40     copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>());
41   }
42 };
43 
44 // TODO(jeff): Get rid of use_exclusive_lock_ option
45 template <typename Device, typename T, DenseUpdateType OP>
46 class DenseUpdateOp : public OpKernel {
47  public:
DenseUpdateOp(OpKernelConstruction * context)48   explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) {
49     OP_REQUIRES_OK(context,
50                    context->GetAttr("use_locking", &use_exclusive_lock_));
51     const DataType dt = DataTypeToEnum<T>::v();
52     OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt},
53                                                     {MakeRefType(dt)}));
54   }
55 
Compute(OpKernelContext * context)56   void Compute(OpKernelContext* context) override {
57     // We always return the input ref.
58     context->forward_ref_input_to_ref_output(0, 0);
59 
60     if (use_exclusive_lock_) {
61       mutex_lock l(*context->input_ref_mutex(0));
62       DoUpdate(context);
63     } else {
64       DoUpdate(context);
65     }
66   }
67 
68  private:
DoUpdate(OpKernelContext * context)69   void DoUpdate(OpKernelContext* context) {
70     Tensor Tparams = context->mutable_input(0, use_exclusive_lock_);
71     const Tensor& Tupdate = context->input(1);
72     OP_REQUIRES(context, Tparams.IsInitialized(),
73                 errors::FailedPrecondition("Attempting to use uninitialized "
74                                            "parameters: ",
75                                            requested_input(0)));
76     OP_REQUIRES(
77         context, Tparams.IsSameSize(Tupdate),
78         errors::InvalidArgument("Parameters and update must be the same size"));
79 
80     functor::DenseUpdate<Device, T, OP> update_functor;
81     update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(),
82                    Tupdate.flat<T>());
83   }
84 
85   bool use_exclusive_lock_;
86 };
87 
88 typedef Eigen::ThreadPoolDevice CPUDevice;
89 typedef Eigen::GpuDevice GPUDevice;
90 #ifdef TENSORFLOW_USE_SYCL
91 typedef Eigen::SyclDevice SYCLDevice;
92 #endif  // TENSORFLOW_USE_SYCL
93 
94 #define REGISTER_KERNELS(type)                                     \
95   REGISTER_KERNEL_BUILDER(                                         \
96       Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
97       AssignOpT<CPUDevice, type>);
98 
99 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
100 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
101 // quint16 not included in QUANTZIED_TYPES
102 TF_CALL_quint16(REGISTER_KERNELS);
103 #undef REGISTER_KERNELS
104 
#if GOOGLE_CUDA
// Register "Assign" on GPU only for the types that 'Variable' also supports
// there (see variable_ops.cc): TF_CALL_GPU_ALL_TYPES plus int64.
#define REGISTER_GPU_KERNELS(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Assign").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      AssignOpT<GPUDevice, TYPE>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA
117 
#ifdef TENSORFLOW_USE_SYCL
// Register "Assign" on SYCL for the TF_CALL_GPU_NUMBER_TYPES_NO_HALF types.
#define REGISTER_SYCL_KERNELS(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Assign").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      AssignOpT<SYCLDevice, TYPE>);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif  // TENSORFLOW_USE_SYCL
127 
128 #define REGISTER_KERNELS(type)                                        \
129   REGISTER_KERNEL_BUILDER(                                            \
130       Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
131       DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>);          \
132   REGISTER_KERNEL_BUILDER(                                            \
133       Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
134       DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>);
135 
136 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
137 #undef REGISTER_KERNELS
138 
#if GOOGLE_CUDA
// Register "AssignAdd" and "AssignSub" on GPU for TF_CALL_GPU_NUMBER_TYPES plus
// int64; for each type, AssignAdd is registered before AssignSub.
#define REGISTER_GPU_DENSE_UPDATE_KERNEL(OP_NAME, UPDATE, TYPE)   \
  REGISTER_KERNEL_BUILDER(                                        \
      Name(OP_NAME).Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      DenseUpdateOp<GPUDevice, TYPE, UPDATE>);
#define REGISTER_GPU_KERNELS(TYPE)                                          \
  REGISTER_GPU_DENSE_UPDATE_KERNEL("AssignAdd", DenseUpdateType::ADD, TYPE) \
  REGISTER_GPU_DENSE_UPDATE_KERNEL("AssignSub", DenseUpdateType::SUB, TYPE)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#undef REGISTER_GPU_DENSE_UPDATE_KERNEL
#endif  // GOOGLE_CUDA
151 
#ifdef TENSORFLOW_USE_SYCL
// Register "AssignAdd" and "AssignSub" on SYCL for the
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF types; for each type, AssignAdd is
// registered before AssignSub.
#define REGISTER_SYCL_DENSE_UPDATE_KERNEL(OP_NAME, UPDATE, TYPE)   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name(OP_NAME).Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      DenseUpdateOp<SYCLDevice, TYPE, UPDATE>);
#define REGISTER_SYCL_KERNELS(TYPE)                                          \
  REGISTER_SYCL_DENSE_UPDATE_KERNEL("AssignAdd", DenseUpdateType::ADD, TYPE) \
  REGISTER_SYCL_DENSE_UPDATE_KERNEL("AssignSub", DenseUpdateType::SUB, TYPE)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#undef REGISTER_SYCL_DENSE_UPDATE_KERNEL
#endif  // TENSORFLOW_USE_SYCL
164 }  // namespace tensorflow
165