• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/synchronization/mutex.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/platform/types.h"
25 #include "tensorflow/core/profiler/lib/traceme.h"
26 #include "tensorflow/core/protobuf/error_codes.pb.h"
27 #include "tensorflow/stream_executor/stream.h"
28 
29 namespace xla {
30 
LocalDeviceState(se::StreamExecutor * executor,LocalClient * client,AllocationModel allocation_model,int max_inflight_computations,bool allow_event_reuse,bool use_callback_stream)31 LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
32                                    LocalClient* client,
33                                    AllocationModel allocation_model,
34                                    int max_inflight_computations,
35                                    bool allow_event_reuse,
36                                    bool use_callback_stream)
37     : allocation_model_(allocation_model),
38       event_pool_(allow_event_reuse),
39       compute_semaphore_(
40           /*capacity=*/max_inflight_computations),
41       executor_(executor),
42       client_(client),
43       prng_seed_generator_(prng_seed_device_()),
44       prng_seed_distribution_(std::numeric_limits<int>::min(),
45                               std::numeric_limits<int>::max()) {
46   compute_stream_ = std::make_unique<se::Stream>(executor);
47   host_to_device_stream_ = std::make_unique<se::Stream>(executor);
48   compute_stream_->Init();
49   host_to_device_stream_->Init();
50   if (use_callback_stream) {
51     callback_stream_map_ =
52         absl::flat_hash_map<se::Stream*, std::unique_ptr<se::Stream>>();
53   }
54   device_to_host_streams_.reserve(kNumDeviceToHostStreams);
55   for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
56     auto stream = std::make_unique<se::Stream>(executor);
57     stream->Init();
58     device_to_host_streams_.push_back(std::move(stream));
59   }
60   device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
61   for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
62     auto stream = std::make_unique<se::Stream>(executor);
63     stream->Init();
64     device_to_device_streams_.push_back(std::move(stream));
65   }
66   execute_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
67                                                    "py_xla_execute");
68   callback_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
69                                                     "py_xla_callback");
70 }
71 
~LocalDeviceState()72 LocalDeviceState::~LocalDeviceState() {
73   Status status = SynchronizeAllActivity();
74   if (!status.ok()) {
75     LOG(ERROR) << "Error when closing device: " << status;
76   }
77 }
78 
SynchronizeAllActivity()79 Status LocalDeviceState::SynchronizeAllActivity() {
80   Status status;
81   // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
82   // suffice. However on the Host platform SynchronizeAllActivity is a dummy
83   // implementation that doesn't actually block. To make sure activity has
84   // stopped, also block on the compute stream. If SynchronizeAllActivity is
85   // fixed, we could remove the BlockHostUntilDone call.
86   status.Update(compute_stream_->BlockHostUntilDone());
87   if (callback_stream_map_.has_value()) {
88     for (auto& callback_stream : callback_stream_map_.value()) {
89       status.Update(callback_stream.second->BlockHostUntilDone());
90     }
91   }
92   bool ok = compute_stream_->parent()->SynchronizeAllActivity();
93   if (!ok) {
94     status.Update(Unknown("SynchronizeAllActivity failed."));
95   }
96   return status;
97 }
98 
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)99 Status LocalDeviceState::ThenMemcpyDeviceToDevice(
100     se::Stream* transfer_stream, se::Stream* dst_stream,
101     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
102   // The default implementation simply calls ThenMemcpyD2D, and assumes that
103   // the buffer addresses identify the devices. This does not work
104   // on all platforms; this method is virtual so it can be overridden.
105   transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
106   return Status::OK();
107 }
108 
ThenExecuteCallback(se::Stream * stream,std::function<void ()> callback)109 void LocalDeviceState::ThenExecuteCallback(se::Stream* stream,
110                                            std::function<void()> callback) {
111   tensorflow::profiler::TraceMe traceme("ThenExecuteCallback");
112   if (callback_stream_map_.has_value()) {
113     // Prevent concurrent updates to the callback stream map.
114     absl::MutexLock lock(&mu_);
115     auto callback_stream = callback_stream_map_->find(stream);
116     if (callback_stream == callback_stream_map_->end()) {
117       auto new_stream = std::make_unique<se::Stream>(executor_);
118       new_stream->Init();
119       callback_stream =
120           callback_stream_map_->insert({stream, std::move(new_stream)}).first;
121     }
122     callback_stream->second->ThenWaitFor(stream);
123     stream = callback_stream->second.get();
124   }
125   stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable {
126     callback_thread_->Schedule(std::move(callback));
127   });
128 }
129 
GetDeviceToHostStream()130 se::Stream* LocalDeviceState::GetDeviceToHostStream() {
131   absl::MutexLock lock(&mu_);
132   int i = next_device_to_host_stream_;
133   next_device_to_host_stream_ =
134       (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
135   return device_to_host_streams_.at(i).get();
136 }
137 
GetDeviceToDeviceStream()138 se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
139   absl::MutexLock lock(&mu_);
140   int i = next_device_to_device_stream_;
141   next_device_to_device_stream_ =
142       (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
143   return device_to_device_streams_.at(i).get();
144 }
145 
BorrowStreamFromPool()146 std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
147   absl::MutexLock lock(&mu_);
148   if (usage_stream_pool_.empty()) {
149     auto stream = std::make_unique<se::Stream>(compute_stream_->parent());
150     stream->Init();
151     return stream;
152   } else {
153     std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
154     usage_stream_pool_.pop();
155     auto status = stream->RefreshStatus();  // Can return error::Unimplemented
156     // Stream may fail with "ABORTED: Bad connection".
157     if (status.code() != tensorflow::error::ABORTED) {
158       CHECK(stream->ok()) << status;
159     }
160     return stream;
161   }
162 }
163 
ReturnStreamToPool(std::unique_ptr<se::Stream> stream)164 void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
165   auto status = stream->RefreshStatus();  // Can return error::Unimplemented
166   // Stream may fail with "ABORTED: Bad connection".
167   if (status.code() != tensorflow::error::ABORTED) {
168     CHECK(stream->ok()) << status;
169   }
170   absl::MutexLock lock(&mu_);
171   usage_stream_pool_.push(std::move(stream));
172 }
173 
GetNewPrngSeed()174 int LocalDeviceState::GetNewPrngSeed() {
175   absl::MutexLock lock(&mu_);
176   int x = 0;
177   do {
178     x = prng_seed_distribution_(prng_seed_generator_);
179   } while (x == 0);
180   return x;
181 }
182 
183 }  // namespace xla
184