• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/synchronization/mutex.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/protobuf/error_codes.pb.h"
26 #include "tensorflow/stream_executor/stream.h"
27 
28 namespace xla {
29 
LocalDeviceState(se::StreamExecutor * executor,LocalClient * client,AllocationModel allocation_model,bool asynchronous,bool allow_event_reuse)30 LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
31                                    LocalClient* client,
32                                    AllocationModel allocation_model,
33                                    bool asynchronous, bool allow_event_reuse)
34     : allocation_model_(allocation_model),
35       event_pool_(allow_event_reuse),
36       compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
37       executor_(executor),
38       client_(client),
39       prng_seed_generator_(prng_seed_device_()),
40       prng_seed_distribution_(std::numeric_limits<int>::min(),
41                               std::numeric_limits<int>::max()) {
42   compute_stream_ = absl::make_unique<se::Stream>(executor);
43   host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
44   callback_stream_ = absl::make_unique<se::Stream>(executor);
45   compute_stream_->Init();
46   host_to_device_stream_->Init();
47   callback_stream_->Init();
48   device_to_host_streams_.reserve(kNumDeviceToHostStreams);
49   for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
50     auto stream = absl::make_unique<se::Stream>(executor);
51     stream->Init();
52     device_to_host_streams_.push_back(std::move(stream));
53   }
54   device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
55   for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
56     auto stream = absl::make_unique<se::Stream>(executor);
57     stream->Init();
58     device_to_device_streams_.push_back(std::move(stream));
59   }
60   execute_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
61                                                     "py_xla_execute");
62   callback_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
63                                                      "py_xla_callback");
64 }
65 
~LocalDeviceState()66 LocalDeviceState::~LocalDeviceState() {
67   Status status = SynchronizeAllActivity();
68   if (!status.ok()) {
69     LOG(ERROR) << "Error when closing device: " << status;
70   }
71 }
72 
SynchronizeAllActivity()73 Status LocalDeviceState::SynchronizeAllActivity() {
74   Status status;
75   // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
76   // suffice. However on the Host platform SynchronizeAllActivity is a dummy
77   // implementation that doesn't actually block. To make sure activity has
78   // stopped, also block on the compute stream. If SynchronizeAllActivity is
79   // fixed, we could remove the BlockHostUntilDone call.
80   status.Update(compute_stream_->BlockHostUntilDone());
81   status.Update(callback_stream_->BlockHostUntilDone());
82   bool ok = compute_stream_->parent()->SynchronizeAllActivity();
83   if (!ok) {
84     status.Update(Unknown("SynchronizeAllActivity failed."));
85   }
86   return status;
87 }
88 
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)89 Status LocalDeviceState::ThenMemcpyDeviceToDevice(
90     se::Stream* transfer_stream, se::Stream* dst_stream,
91     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
92   // The default implementation simply calls ThenMemcpyD2D, and assumes that
93   // the buffer addresses identify the devices. This does not work
94   // on all platforms; this method is virtual so it can be overridden.
95   transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
96   return Status::OK();
97 }
98 
ThenExecuteOnCallbackThread(se::Stream * stream,std::function<void ()> callback) const99 void LocalDeviceState::ThenExecuteOnCallbackThread(
100     se::Stream* stream, std::function<void()> callback) const {
101   stream->ThenDoHostCallback([this, callback]() mutable {
102     callback_thread_->Schedule(std::move(callback));
103   });
104 }
105 
GetDeviceToHostStream()106 se::Stream* LocalDeviceState::GetDeviceToHostStream() {
107   absl::MutexLock lock(&mu_);
108   int i = next_device_to_host_stream_;
109   next_device_to_host_stream_ =
110       (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
111   return device_to_host_streams_.at(i).get();
112 }
113 
GetDeviceToDeviceStream()114 se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
115   absl::MutexLock lock(&mu_);
116   int i = next_device_to_device_stream_;
117   next_device_to_device_stream_ =
118       (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
119   return device_to_device_streams_.at(i).get();
120 }
121 
BorrowStreamFromPool()122 std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
123   absl::MutexLock lock(&mu_);
124   if (usage_stream_pool_.empty()) {
125     auto stream = absl::make_unique<se::Stream>(compute_stream_->parent());
126     stream->Init();
127     return stream;
128   } else {
129     std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
130     usage_stream_pool_.pop();
131     auto status = stream->RefreshStatus();  // Can return error::Unimplemented
132     // Stream may fail with "ABORTED: Bad connection".
133     if (status.code() != tensorflow::error::ABORTED) {
134       CHECK(stream->ok()) << status;
135     }
136     return stream;
137   }
138 }
139 
ReturnStreamToPool(std::unique_ptr<se::Stream> stream)140 void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
141   auto status = stream->RefreshStatus();  // Can return error::Unimplemented
142   // Stream may fail with "ABORTED: Bad connection".
143   if (status.code() != tensorflow::error::ABORTED) {
144     CHECK(stream->ok()) << status;
145   }
146   absl::MutexLock lock(&mu_);
147   usage_stream_pool_.push(std::move(stream));
148 }
149 
GetNewPrngSeed()150 int LocalDeviceState::GetNewPrngSeed() {
151   absl::MutexLock lock(&mu_);
152   int x = 0;
153   do {
154     x = prng_seed_distribution_(prng_seed_generator_);
155   } while (x == 0);
156   return x;
157 }
158 
159 }  // namespace xla
160