/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device.h"

#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <list>
#include <map>
#include <tuple>
#include <vector>

#include "tensorflow/core/common_runtime/device/device_event_mgr.h"
#include "tensorflow/core/common_runtime/device/device_id.h"
#include "tensorflow/core/common_runtime/device/device_id_manager.h"
#include "tensorflow/core/common_runtime/device/device_id_utils.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h"
#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// This factory helps ensure that different PluggableDevice objects that
// refer to the same physical device and stream group id use the same stream
// group object (and therefore the same device streams). This is necessary since
// there is a single memory allocator per device (see
// ProcessState::GetPluggableDeviceAllocator) and allocators must not be shared
// across streams.
// TODO(penpornk): Consider refactoring StreamGroupFactory to
// common_runtime/device.
class PluggableDevice::StreamGroupFactory {
 public:
  // Returns the unique stream group for use with the stream defined by
  // {tf_device_id, stream_group_within_device}, creating it if it does not yet
  // exist.
  // This function is thread safe.
  PluggableDevice::StreamGroup* GetOrCreate(const std::string& device_type,
                                            TfDeviceId tf_device_id,
                                            int stream_group_within_device,
                                            se::StreamExecutor* executor,
                                            const GPUOptions& options) {
    mutex_lock guard(lock_);
    StreamGroup* group = &streams_[key_type(device_type, tf_device_id.value(),
                                            stream_group_within_device)];
    if (!group->compute) {
      group->compute = new se::Stream(executor);
      group->compute->Init();
      VLOG(2) << "Created stream[" << stream_group_within_device
              << "] = " << group->compute;

      group->host_to_device = new se::Stream(executor);
      group->host_to_device->Init();
      VLOG(2) << "Created host_to_device_stream[" << stream_group_within_device
              << "] = " << group->host_to_device;

      group->device_to_host = new se::Stream(executor);
      group->device_to_host->Init();
      VLOG(2) << "Created device_to_host_stream[" << stream_group_within_device
              << "] = " << group->device_to_host;

      int num_d2d_streams =
          options.experimental().num_dev_to_dev_copy_streams();
      if (num_d2d_streams == 0) num_d2d_streams = 1;
      if (num_d2d_streams < 1 || num_d2d_streams > 4) {
        LOG(ERROR)
            << "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams="
            << num_d2d_streams << " set to 1 instead.";
        num_d2d_streams = 1;
      }
      for (int i = 0; i < num_d2d_streams; ++i) {
        se::Stream* stream = new se::Stream(executor);
        stream->Init();
        group->device_to_device.push_back(stream);
        VLOG(2) << "Created device_to_device_stream["
                << stream_group_within_device
                << "] = " << group->device_to_device.back();
      }
    }
    return group;
  }

  // Returns a reference to the StreamGroupFactory singleton. Note that this is
  // never destroyed, so the objects it owns are never deleted.
  static StreamGroupFactory& Global() {
    static StreamGroupFactory* instance = new StreamGroupFactory();
    return *instance;
  }

 private:
  mutex lock_;
  using key_type = std::tuple<std::string, int, int>;
  std::map<key_type, StreamGroup> streams_;

  // StreamGroupFactory cannot be created directly; Call
  // StreamGroupFactory::Global to get the global instance.
  StreamGroupFactory() = default;
  TF_DISALLOW_COPY_AND_ASSIGN(StreamGroupFactory);
};
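
// Illustrative usage only (it mirrors the call made in PluggableDevice::Init
// below): every PluggableDevice that maps to the same physical device and
// stream group id obtains its streams through the singleton, e.g.
//   StreamGroup* group = StreamGroupFactory::Global().GetOrCreate(
//       device_type, tf_device_id, /*stream_group_within_device=*/0,
//       executor, options.config.gpu_options());
// so the streams are shared rather than recreated per device object.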

PluggableDevice::PluggableDevice(
    const SessionOptions& options, const std::string& name,
    const std::string& device_type, const std::string& platform_name,
    Bytes memory_limit, const DeviceLocality& locality, TfDeviceId tf_device_id,
    const std::string& physical_device_desc, Allocator* device_allocator,
    Allocator* cpu_allocator, bool sync_every_op)
    : LocalDevice(options, Device::BuildDeviceAttributes(
                               name, device_type.c_str(), memory_limit,
                               locality, physical_device_desc)),
      device_allocator_(device_allocator),
      cpu_allocator_(cpu_allocator),
      tf_device_id_(tf_device_id),
      platform_name_(platform_name),
      sync_every_op_(sync_every_op) {
  if (options.config.has_gpu_options()) {
    force_gpu_compatible_ = options.config.gpu_options().force_gpu_compatible();
  }
  PluggableDeviceProcessState::singleton(device_type, platform_name)
      ->EnablePluggableDevice();
}

PluggableDevice::~PluggableDevice() {
  delete pluggable_device_info_;
  device_context_->Unref();
}

Status PluggableDevice::Init(const SessionOptions& options) {
  se::Platform* platform = PluggableDeviceMachineManager(platform_name_);
  auto executor_status = DeviceIdUtil::ExecutorForTfDeviceId(
      DeviceType(device_type()), platform, tf_device_id_);
  if (!executor_status.status().ok()) {
    return errors::Internal("Failed to get StreamExecutor for device ",
                            tf_device_id_.value());
  }
  executor_ = executor_status.ValueOrDie();

  em_ = EventMgrFactory::Singleton()->GetEventMgr(executor_,
                                                  options.config.gpu_options());

  stream_ = StreamGroupFactory::Global().GetOrCreate(
      device_type(), tf_device_id_, 0, executor_, options.config.gpu_options());
  device_context_ = new PluggableDeviceContext(
      0, stream_->compute, stream_->host_to_device, stream_->device_to_host,
      stream_->device_to_device);
  pluggable_device_info_ = new DeviceBase::AcceleratorDeviceInfo;
  pluggable_device_info_->stream = stream_->compute;
  pluggable_device_info_->default_context = device_context_;
  pluggable_device_info_->event_mgr = em_;
  PlatformDeviceId platform_device_id;
  TF_RETURN_IF_ERROR(DeviceIdManager::TfToPlatformDeviceId(
      DeviceType(device_type()), tf_device_id_, &platform_device_id));
  pluggable_device_info_->gpu_id = platform_device_id.value();
  set_tensorflow_accelerator_device_info(pluggable_device_info_);

  // Whether and how the PluggableDevice uses its own threadpool.
  // This option is experimental. Once we confirm the best setting, we
  // may change the default behavior and completely remove this flag.
  // Default values might change in future releases.
  // Possible values:
  //   * global: PluggableDevice uses threads shared with CPU in the main
  //       compute thread-pool. This is currently the default.
  //   * gpu_private: PluggableDevice uses threads dedicated to this device.
  //   * gpu_shared: All PluggableDevices share a dedicated thread pool.

  // TODO(penpornk): Read the following configurations from a PluggableDevice
  // callback instead of GPU environment variables: TF_GPU_THREAD_MODE,
  // TF_GPU_THREAD_COUNT, TF_FORCE_GPU_ALLOC_GROWTH,
  // TF_ENABLE_GPU_GARBAGE_COLLECTION, and TF_GPU_HOST_MEM_LIMIT_IN_MB.
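  // Example (illustrative only): a user could request a private two-thread
  // pool per PluggableDevice by launching the program with, e.g.,
  //   TF_GPU_THREAD_MODE=gpu_private TF_GPU_THREAD_COUNT=2 python train.py
  // The code below reads these variables via ReadStringFromEnvVar and
  // ReadInt64FromEnvVar.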
  string device_thread_mode;
  TF_RETURN_IF_ERROR(ReadStringFromEnvVar("TF_GPU_THREAD_MODE", "global",
                                          &device_thread_mode));
  device_thread_mode = absl::AsciiStrToLower(device_thread_mode);
  if (device_thread_mode != "global") {
    int64_t device_thread_count = -1;
    // Default to two threads. One for device compute and another for memory
    // copies.
    TF_RETURN_IF_ERROR(
        ReadInt64FromEnvVar("TF_GPU_THREAD_COUNT", 2, &device_thread_count));
    if (device_thread_mode == "gpu_private") {
      thread_pool_.reset(new thread::ThreadPool(
          options.env, ThreadOptions(),
          strings::StrCat("gpu_private_", tf_device_id_.value()),
          static_cast<int32>(device_thread_count),
          !options.config.experimental().disable_thread_spinning(),
          /*allocator=*/nullptr));
      set_tensorflow_device_thread_pool(thread_pool_.get());
    } else if (device_thread_mode == "gpu_shared") {
      static thread::ThreadPool* thread_pool = new thread::ThreadPool(
          options.env, ThreadOptions(), "gpu_shared",
          static_cast<int32>(device_thread_count),
          !options.config.experimental().disable_thread_spinning(),
          /*allocator=*/nullptr);
      set_tensorflow_device_thread_pool(thread_pool);
    } else {
      string error_message =
          strings::StrCat("Invalid gpu_thread_mode: ", device_thread_mode);
      LOG(WARNING) << error_message;
      return errors::InvalidArgument(error_message);
    }
  }

  return OkStatus();
}

Allocator* PluggableDevice::GetAllocator(AllocatorAttributes attr) {
  DCHECK(cpu_allocator_) << "CPU allocator must be set";
  if (attr.on_host()) {
    if (attr.gpu_compatible() || force_gpu_compatible_) {
      PluggableDeviceProcessState* ps =
          PluggableDeviceProcessState::singleton(device_type(), platform_name_);
      return ps->GetPluggableDeviceHostAllocator(0);
    } else {
      return cpu_allocator_;
    }
  } else {
    return device_allocator_;
  }
}

string PluggableDevice::ComputeOpKernelDebugString(const OpKernel& op_kernel,
                                                   const int stream_id) {
  return strings::StrCat(op_kernel.name(), " op ", op_kernel.type_string(),
                         " on ", platform_name_, tf_device_id_.value(),
                         " stream[", stream_id, "]");
}

void PluggableDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  PluggableDeviceContext* pluggable_device_context = device_context_;
  if (context->op_device_context() != nullptr) {
    pluggable_device_context =
        static_cast<PluggableDeviceContext*>(context->op_device_context());
  }
  const auto stream_id = pluggable_device_context->stream_id();

  const bool vlog_1 = VLOG_IS_ON(1);

  if (vlog_1) {
    VLOG(1) << "PluggableDevice::ComputeHelper "
            << ComputeOpKernelDebugString(*op_kernel, stream_id);
  }

  op_kernel->Compute(context);
  if (context->status().ok()) {
    if (sync_every_op_) {
      context->SetStatus(PluggableDeviceUtil::Sync(this));
      if (vlog_1) {
        VLOG(1) << "PluggableDevice::ComputeHelper finished "
                << ComputeOpKernelDebugString(*op_kernel, stream_id);
      }
    } else if (vlog_1) {
      VLOG(1) << "PluggableDevice::ComputeHelper scheduled "
              << ComputeOpKernelDebugString(*op_kernel, stream_id);
    }
  } else {
    if (vlog_1) {
      VLOG(1) << "PluggableDevice::ComputeHelper failed to schedule "
              << ComputeOpKernelDebugString(*op_kernel, stream_id);
    }
  }
}

// Based on the semantics of Device::Sync, this call should wait for
// all streams not just the current one.
Status PluggableDevice::Sync() { return PluggableDeviceUtil::SyncAll(this); }

void PluggableDevice::ComputeAsync(AsyncOpKernel* op_kernel,
                                   OpKernelContext* context,
                                   AsyncOpKernel::DoneCallback done) {
  PluggableDeviceContext* device_context = device_context_;
  if (context->op_device_context() != nullptr) {
    device_context =
        static_cast<PluggableDeviceContext*>(context->op_device_context());
  }
  const auto stream_id = device_context->stream_id();

  VLOG(1) << "PluggableDevice::ComputeAsync " << op_kernel->name() << " op "
          << op_kernel->type_string() << " on " << device_type()
          << tf_device_id_ << " stream[" << stream_id << "]";
  op_kernel->ComputeAsync(context, std::move(done));
}

Status PluggableDevice::MaybeCopyTensorToPluggableDevice(
    const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to,
    StatusCallback done) {
  if (alloc_attrs.on_host()) {
    *to = from;
    done(OkStatus());
    return OkStatus();
  } else {
    if (!DMAHelper::CanUseDMA(&from)) {
      Status err = errors::Internal("PluggableDevice copy from non-DMA ",
                                    DataTypeString(from.dtype()), " tensor");
      done(err);
      return err;
    }
    AllocationAttributes allocation_attr;
    auto* copy = new Tensor(GetAllocator(alloc_attrs), from.dtype(),
                            from.shape(), allocation_attr);

    // If the tensor is not initialized, we likely ran out of memory.
    if (!copy->IsInitialized()) {
      delete copy;
      Status err = errors::ResourceExhausted(
          "OOM when allocating tensor of shape ", from.shape().DebugString(),
          " and type ", DataTypeString(from.dtype()));
      done(err);
      return err;
    }

    auto wrapped_done = [to, copy, done = std::move(done)](const Status& s) {
      if (s.ok()) {
        *to = std::move(*copy);
      }
      delete copy;
      done(s);
    };

    device_context_->CopyCPUTensorToDevice(
        &from, this, copy, std::move(wrapped_done), false /*sync_dst_compute*/);
    return OkStatus();
  }
}

Status PluggableDevice::MakeTensorFromProto(
    const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
    Tensor* tensor) {
  AllocatorAttributes attr;
  attr.set_on_host(true);
  attr.set_gpu_compatible(true);
  Allocator* host_alloc = GetAllocator(attr);
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(host_alloc, tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  if (parsed.dtype() == DT_VARIANT) {
    const Variant* from = parsed.flat<Variant>().data();
    int numa_node = attributes().locality().numa_node();
    Tensor copy(cpu_allocator(numa_node), DT_VARIANT, parsed.shape());
    Variant* copy_variant = copy.flat<Variant>().data();

    std::list<Notification> notifications;
    Status copy_status;
    auto copier = [this, &alloc_attrs, &notifications, &copy_status](
                      const Tensor& from, Tensor* to) {
      // Copier isn't run in a multithreaded environment, so we don't
      // have to worry about the notifications list being modified in parallel.
      notifications.emplace_back();
      Notification& n = *notifications.rbegin();
      return MaybeCopyTensorToPluggableDevice(
          alloc_attrs, from, to, [&n, &copy_status](const Status& s) {
            if (copy_status.ok()) {
              copy_status.Update(s);
            }
            n.Notify();
          });
    };
    Status s;
    for (int64_t ix = 0; ix < parsed.NumElements(); ++ix) {
      s = VariantDeviceCopy(VariantDeviceCopyDirection::HOST_TO_DEVICE,
                            from[ix], &copy_variant[ix], copier);
      if (!s.ok()) {
        break;
      }
    }
    for (auto& n : notifications) {
      n.WaitForNotification();
    }
    if (!s.ok()) {
      return s;
    }
    *tensor = std::move(copy);
    return copy_status;
  } else {
    Notification n;
    Status status;
    TF_RETURN_IF_ERROR(MaybeCopyTensorToPluggableDevice(
        alloc_attrs, parsed, tensor, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        }));
    n.WaitForNotification();
    return status;
  }
}

void PluggableDevice::CopyTensorInSameDevice(
    const Tensor* input_tensor, Tensor* output_tensor,
    const DeviceContext* device_context, StatusCallback done) {
  PluggableDeviceUtil::CopyPluggableDeviceTensorToSameDevice(
      static_cast<Device*>(this), device_context, input_tensor, output_tensor,
      std::move(done));
}

}  // namespace tensorflow