/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/device/device_event_mgr.h"
25 #include "tensorflow/core/common_runtime/device/device_id.h"
26 #include "tensorflow/core/common_runtime/device/device_id_manager.h"
27 #include "tensorflow/core/common_runtime/device/device_id_utils.h"
28 #include "tensorflow/core/common_runtime/local_device.h"
29 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h"
30 #include "tensorflow/core/common_runtime/shared_counter.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/device_base.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/gtl/inlined_vector.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/stream_executor.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/public/session_options.h"
41 
42 namespace tensorflow {
43 
44 class PluggableDevice : public LocalDevice {
45  public:
46   PluggableDevice(const SessionOptions& options, const std::string& name,
47                   const string& device_type, const string& platform_name,
48                   Bytes memory_limit, const DeviceLocality& locality,
49                   TfDeviceId tf_device_id,
50                   const std::string& physical_device_desc,
51                   Allocator* device_allocator, Allocator* cpu_allocator,
52                   bool sync_every_op);
53 
54   ~PluggableDevice() override;
55 
56   // Initialize the device and return the status of initialization.
57   Status Init(const SessionOptions& options);
58 
59   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
60                     AsyncOpKernel::DoneCallback done) override;
61 
62   void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
63 
64   Status Sync() override;
65 
66   Allocator* GetAllocator(AllocatorAttributes attr) override;
67 
68   Status MakeTensorFromProto(const TensorProto& tensor_proto,
69                              const AllocatorAttributes alloc_attrs,
70                              Tensor* tensor) override;
71 
72   void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
73                               const DeviceContext* device_context,
74                               StatusCallback done) override;
75 
76   // The executor that provides control for the pluggable device;
executor()77   se::StreamExecutor* executor() const { return executor_; }
78 
79  private:
80   Allocator* device_allocator_;
81   Allocator* cpu_allocator_;
82 
83   se::StreamExecutor* executor_ = nullptr;
84   struct StreamGroup {
85     se::Stream* compute = nullptr;
86     se::Stream* host_to_device = nullptr;
87     se::Stream* device_to_host = nullptr;
88     gtl::InlinedVector<se::Stream*, 4> device_to_device;
89   };
90 
91   class StreamGroupFactory;
92 
93   StreamGroup* stream_;
94   PluggableDeviceContext* device_context_;
95   // TODO(penpornk): Investigate renaming `GpuDeviceInfo` to `DeviceInfo`.
96   DeviceBase::AcceleratorDeviceInfo* pluggable_device_info_ = nullptr;
97   TfDeviceId tf_device_id_;
98   const string platform_name_;
99   const bool sync_every_op_ = false;
100   EventMgr* em_ = nullptr;
101   std::unique_ptr<thread::ThreadPool> thread_pool_;
102   bool force_gpu_compatible_ = false;
103   std::string ComputeOpKernelDebugString(const OpKernel& op_kernel,
104                                          const int stream_id);
105 
106   // This method returns an initialization status, in addition to
107   // calling the "done" StatusCallback, if there is a failure to
108   // allocate memory or if the tensor "from" is not DMA-copyable.
109   // If there is no error prior to enqueueing the copy, an OK status
110   // is returned.
111   Status MaybeCopyTensorToPluggableDevice(
112       const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to,
113       StatusCallback done);
114 };
115 
116 }  // namespace tensorflow
117 
118 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_
119