• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16 
17 #include "absl/strings/numbers.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/string_view.h"
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/notification.h"
24 
25 namespace tensorflow {
26 
~BufRendezvous()27 BufRendezvous::~BufRendezvous() {
28   mutex_lock l(mu_);
29   if (!hook_table_.empty()) {
30     PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
31                &hook_table_);
32   }
33 }
34 
StartAbort(const Status & s)35 void BufRendezvous::StartAbort(const Status& s) {
36   CHECK(!s.ok());
37   HookTable dummy_table;
38   {
39     mutex_lock l(mu_);
40     // Use a "derived" status as the status for the rendezvous. Derived
41     // status messages are ignored when aggregating errors across devices: this
42     // allows us to prefer our original status message over any cancellation
43     // related errors.
44     status_.Update(StatusGroup::MakeDerived(s));
45     hook_table_.swap(dummy_table);
46   }
47   PurgeTable(s, &dummy_table);
48 }
49 
PurgeTable(const Status & s,HookTable * table)50 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
51   for (auto& it : *table) {
52     Hook* h = it.second;
53     if (h->cons_cb != nullptr) {
54       h->cons_cb(s, nullptr);
55     }
56     if (h->prod_cb != nullptr) {
57       h->prod_cb(s);
58     }
59     delete h;
60   }
61   table->clear();
62 }
63 
DebugString() const64 string BufRendezvous::Hook::DebugString() const {
65   return absl::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
66                       ", ctx:", reinterpret_cast<uint64>(prod_ctx),
67                       ", val:", reinterpret_cast<uint64>(prod_value),
68                       ", pcb:", reinterpret_cast<uint64>(&prod_cb),
69                       ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
70 }
71 
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done)72 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
73                                DeviceContext* dev_ctx, const Tensor* v,
74                                const AllocatorAttributes& attr,
75                                const ProducerCallback& done) {
76   Hook* h = nullptr;
77   Status providebuf_status;
78   do {
79     mutex_lock l(mu_);
80     if (!status_.ok()) {
81       providebuf_status = status_;
82       break;
83     } else {
84       auto it = hook_table_.find(key);
85       if (it == hook_table_.end()) {
86         h = new Hook;
87         it = hook_table_.insert(std::make_pair(key, h)).first;
88       } else {
89         if (it->second->prod_cb != nullptr) {
90           providebuf_status = errors::Internal(
91               "BufRendezvous::ProvideBuf already called for key ", key);
92           break;
93         }
94         h = it->second;
95       }
96       // Populate Hook with all of the prod values.
97       h->prod_dev = dev;
98       h->prod_ctx = dev_ctx;
99       h->prod_value = v;
100       h->prod_attr = attr;
101       h->prod_cb = done;
102       // If consumer is waiting, kick off right away, removing Hook from table.
103       if (h->cons_cb != nullptr) {
104         hook_table_.erase(it);
105       } else {
106         h = nullptr;
107       }
108     }
109   } while (false);
110   if (h) {
111     h->cons_cb(Status::OK(), h);
112   }
113   if (!providebuf_status.ok()) {
114     done(providebuf_status);
115   }
116 }
117 
ConsumeBuf(const string & key,const string & device_name,const uint64 device_incarnation,const ConsumerCallback & done)118 void BufRendezvous::ConsumeBuf(const string& key, const string& device_name,
119                                const uint64 device_incarnation,
120                                const ConsumerCallback& done) {
121   // Check the incarnation in the request matches the current device
122   // incarnation of the producer.
123   Device* device;
124   Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device);
125   if (consumebuf_status.ok() &&
126       device->attributes().incarnation() != device_incarnation) {
127     consumebuf_status = errors::FailedPrecondition(
128         "RecvBuf expects a different device incarnation: ", device_incarnation,
129         " vs. ", device->attributes().incarnation(),
130         ". Your worker job that contains the device (\"", device_name,
131         "\") was probably restarted. Check your "
132         "worker job for the reason why it was restarted.");
133   }
134   if (!consumebuf_status.ok()) {
135     done(consumebuf_status, nullptr);
136     return;
137   }
138 
139   Hook* existing_hook = nullptr;
140   do {
141     mutex_lock l(mu_);
142     if (!status_.ok()) {
143       consumebuf_status = status_;
144       break;
145     }
146     auto it = hook_table_.find(key);
147     if (it != hook_table_.end()) {
148       // Prepare to consume immediately.
149       if (it->second->cons_cb) {
150         consumebuf_status =
151             errors::Internal("Second consumer arrived for key ", key);
152         break;
153       }
154       existing_hook = it->second;
155       hook_table_.erase(it);
156       existing_hook->cons_cb = done;
157     } else {
158       // Hang consumer callback on the Hook.
159       Hook* h = new Hook;
160       hook_table_[key] = h;
161       h->cons_cb = done;
162       return;
163     }
164   } while (false);
165   if (existing_hook) {
166     existing_hook->cons_cb(Status::OK(), existing_hook);
167     return;
168   }
169   if (!consumebuf_status.ok()) {
170     done(consumebuf_status, nullptr);
171     return;
172   }
173 }
174 
175 /*static*/
DoneWithHook(Hook * h)176 void BufRendezvous::DoneWithHook(Hook* h) {
177   h->prod_cb(Status::OK());
178   delete h;
179 }
180 
LogContents()181 void BufRendezvous::LogContents() {
182   mutex_lock l(mu_);
183   LOG(INFO) << strings::StrCat("BufRendezvous ",
184                                strings::Hex(reinterpret_cast<uint64>(this)),
185                                " step_id=", step_id_, " current contents:");
186   for (auto it : hook_table_) {
187     LOG(INFO) << it.first << ":" << it.second->DebugString();
188   }
189 }
190 
191 }  // namespace tensorflow
192