1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16
17 #include "absl/strings/numbers.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/notification.h"
24
25 namespace tensorflow {
26
~BufRendezvous()27 BufRendezvous::~BufRendezvous() {
28 mutex_lock l(mu_);
29 if (!hook_table_.empty()) {
30 PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
31 &hook_table_);
32 }
33 }
34
StartAbort(const Status & s)35 void BufRendezvous::StartAbort(const Status& s) {
36 CHECK(!s.ok());
37 HookTable dummy_table;
38 {
39 mutex_lock l(mu_);
40 // Use a "derived" status as the status for the rendezvous. Derived
41 // status messages are ignored when aggregating errors across devices: this
42 // allows us to prefer our original status message over any cancellation
43 // related errors.
44 status_.Update(StatusGroup::MakeDerived(s));
45 hook_table_.swap(dummy_table);
46 }
47 PurgeTable(s, &dummy_table);
48 }
49
PurgeTable(const Status & s,HookTable * table)50 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
51 for (auto& it : *table) {
52 Hook* h = it.second;
53 if (h->cons_cb != nullptr) {
54 h->cons_cb(s, nullptr);
55 }
56 if (h->prod_cb != nullptr) {
57 h->prod_cb(s);
58 }
59 delete h;
60 }
61 table->clear();
62 }
63
DebugString() const64 string BufRendezvous::Hook::DebugString() const {
65 return absl::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
66 ", ctx:", reinterpret_cast<uint64>(prod_ctx),
67 ", val:", reinterpret_cast<uint64>(prod_value),
68 ", pcb:", reinterpret_cast<uint64>(&prod_cb),
69 ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
70 }
71
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done)72 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
73 DeviceContext* dev_ctx, const Tensor* v,
74 const AllocatorAttributes& attr,
75 const ProducerCallback& done) {
76 Hook* h = nullptr;
77 Status providebuf_status;
78 do {
79 mutex_lock l(mu_);
80 if (!status_.ok()) {
81 providebuf_status = status_;
82 break;
83 } else {
84 auto it = hook_table_.find(key);
85 if (it == hook_table_.end()) {
86 h = new Hook;
87 it = hook_table_.insert(std::make_pair(key, h)).first;
88 } else {
89 if (it->second->prod_cb != nullptr) {
90 providebuf_status = errors::Internal(
91 "BufRendezvous::ProvideBuf already called for key ", key);
92 break;
93 }
94 h = it->second;
95 }
96 // Populate Hook with all of the prod values.
97 h->prod_dev = dev;
98 h->prod_ctx = dev_ctx;
99 h->prod_value = v;
100 h->prod_attr = attr;
101 h->prod_cb = done;
102 // If consumer is waiting, kick off right away, removing Hook from table.
103 if (h->cons_cb != nullptr) {
104 hook_table_.erase(it);
105 } else {
106 h = nullptr;
107 }
108 }
109 } while (false);
110 if (h) {
111 h->cons_cb(Status::OK(), h);
112 }
113 if (!providebuf_status.ok()) {
114 done(providebuf_status);
115 }
116 }
117
ConsumeBuf(const string & key,const string & device_name,const uint64 device_incarnation,const ConsumerCallback & done)118 void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
119 const uint64 device_incarnation,
120 const ConsumerCallback& done) {
121 // Check the incarnation in the request matches the current device
122 // incarnation of the producer.
123 Device* device;
124 Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device);
125 if (consumebuf_status.ok() &&
126 device->attributes().incarnation() != device_incarnation) {
127 consumebuf_status = errors::FailedPrecondition(
128 "RecvBuf expects a different device incarnation: ", device_incarnation,
129 " vs. ", device->attributes().incarnation(),
130 ". Your worker job that contains the device (\"", device_name,
131 "\") was probably restarted. Check your "
132 "worker job for the reason why it was restarted.");
133 }
134 if (!consumebuf_status.ok()) {
135 done(consumebuf_status, nullptr);
136 return;
137 }
138
139 Hook* existing_hook = nullptr;
140 do {
141 mutex_lock l(mu_);
142 if (!status_.ok()) {
143 consumebuf_status = status_;
144 break;
145 }
146 auto it = hook_table_.find(key);
147 if (it != hook_table_.end()) {
148 // Prepare to consume immediately.
149 if (it->second->cons_cb) {
150 consumebuf_status =
151 errors::Internal("Second consumer arrived for key ", key);
152 break;
153 }
154 existing_hook = it->second;
155 hook_table_.erase(it);
156 existing_hook->cons_cb = done;
157 } else {
158 // Hang consumer callback on the Hook.
159 Hook* h = new Hook;
160 hook_table_[key] = h;
161 h->cons_cb = done;
162 return;
163 }
164 } while (false);
165 if (existing_hook) {
166 existing_hook->cons_cb(Status::OK(), existing_hook);
167 return;
168 }
169 if (!consumebuf_status.ok()) {
170 done(consumebuf_status, nullptr);
171 return;
172 }
173 }
174
175 /*static*/
DoneWithHook(Hook * h)176 void BufRendezvous::DoneWithHook(Hook* h) {
177 h->prod_cb(Status::OK());
178 delete h;
179 }
180
LogContents()181 void BufRendezvous::LogContents() {
182 mutex_lock l(mu_);
183 LOG(INFO) << strings::StrCat("BufRendezvous ",
184 strings::Hex(reinterpret_cast<uint64>(this)),
185 " step_id=", step_id_, " current contents:");
186 for (auto it : hook_table_) {
187 LOG(INFO) << it.first << ":" << it.second->DebugString();
188 }
189 }
190
191 } // namespace tensorflow
192