1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
16
17 #include "tensorflow/core/common_runtime/device.h"
18 #include "tensorflow/core/common_runtime/process_util.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/core/notification.h"
21
22 namespace tensorflow {
23
~BufRendezvous()24 BufRendezvous::~BufRendezvous() {
25 mutex_lock l(mu_);
26 if (!hook_table_.empty()) {
27 PurgeTable(errors::Internal("Delete called on non-empty BufRendezvous"),
28 &hook_table_);
29 }
30 }
31
StartAbort(const Status & s)32 void BufRendezvous::StartAbort(const Status& s) {
33 CHECK(!s.ok());
34 HookTable dummy_table;
35 {
36 mutex_lock l(mu_);
37 status_.Update(s);
38 hook_table_.swap(dummy_table);
39 }
40 PurgeTable(s, &dummy_table);
41 }
42
PurgeTable(const Status & s,HookTable * table)43 void BufRendezvous::PurgeTable(const Status& s, HookTable* table) {
44 for (auto& it : *table) {
45 Hook* h = it.second;
46 if (h->cons_cb != nullptr) {
47 h->cons_cb(s, nullptr);
48 }
49 if (h->prod_cb != nullptr) {
50 h->prod_cb(s);
51 }
52 delete h;
53 }
54 table->clear();
55 }
56
DebugString() const57 string BufRendezvous::Hook::DebugString() const {
58 return strings::StrCat("[dev:", (prod_dev ? prod_dev->name() : "none"),
59 ", ctx:", reinterpret_cast<uint64>(prod_ctx),
60 ", val:", reinterpret_cast<uint64>(prod_value),
61 ", pcb:", reinterpret_cast<uint64>(&prod_cb),
62 ", ccb:", reinterpret_cast<uint64>(&cons_cb), "]");
63 }
64
ProvideBuf(const string & key,Device * dev,DeviceContext * dev_ctx,const Tensor * v,const AllocatorAttributes & attr,const ProducerCallback & done)65 void BufRendezvous::ProvideBuf(const string& key, Device* dev,
66 DeviceContext* dev_ctx, const Tensor* v,
67 const AllocatorAttributes& attr,
68 const ProducerCallback& done) {
69 Hook* h = nullptr;
70 Status providebuf_status;
71 do {
72 mutex_lock l(mu_);
73 if (!status_.ok()) {
74 providebuf_status = status_;
75 break;
76 } else {
77 auto it = hook_table_.find(key);
78 if (it == hook_table_.end()) {
79 h = new Hook;
80 it = hook_table_.insert(std::make_pair(key, h)).first;
81 } else {
82 if (it->second->prod_cb != nullptr) {
83 providebuf_status = errors::Internal(
84 "BufRendezvous::ProvideBuf already called for key ", key);
85 break;
86 }
87 h = it->second;
88 }
89 // Populate Hook with all of the prod values.
90 h->prod_dev = dev;
91 h->prod_ctx = dev_ctx;
92 h->prod_value = v;
93 h->prod_attr = attr;
94 h->prod_cb = done;
95 // If consumer is waiting, kick off right away, removing Hook from table.
96 if (h->cons_cb != nullptr) {
97 hook_table_.erase(it);
98 } else {
99 h = nullptr;
100 }
101 }
102 } while (false);
103 if (h) {
104 h->cons_cb(Status::OK(), h);
105 }
106 if (!providebuf_status.ok()) {
107 done(providebuf_status);
108 }
109 }
110
ConsumeBuf(const string & key,const ConsumerCallback & done)111 void BufRendezvous::ConsumeBuf(const string& key,
112 const ConsumerCallback& done) {
113 Hook* existing_hook = nullptr;
114 Status consumebuf_status;
115 do {
116 mutex_lock l(mu_);
117 if (!status_.ok()) {
118 consumebuf_status = status_;
119 break;
120 }
121 auto it = hook_table_.find(key);
122 if (it != hook_table_.end()) {
123 // Prepare to consume immediately.
124 if (it->second->cons_cb) {
125 consumebuf_status =
126 errors::Internal("Second consumer arrived for key ", key);
127 break;
128 }
129 existing_hook = it->second;
130 hook_table_.erase(it);
131 existing_hook->cons_cb = done;
132 } else {
133 // Hang consumer callback on the Hook.
134 Hook* h = new Hook;
135 hook_table_[key] = h;
136 h->cons_cb = done;
137 return;
138 }
139 } while (false);
140 if (existing_hook) {
141 existing_hook->cons_cb(Status::OK(), existing_hook);
142 return;
143 }
144 if (!consumebuf_status.ok()) {
145 done(consumebuf_status, nullptr);
146 return;
147 }
148 }
149
150 /*static*/
DoneWithHook(Hook * h)151 void BufRendezvous::DoneWithHook(Hook* h) {
152 h->prod_cb(Status::OK());
153 delete h;
154 }
155
LogContents()156 void BufRendezvous::LogContents() {
157 mutex_lock l(mu_);
158 LOG(INFO) << strings::StrCat("BufRendezvous ",
159 strings::Hex(reinterpret_cast<uint64>(this)),
160 " step_id=", step_id_, " current contents:");
161 for (auto it : hook_table_) {
162 LOG(INFO) << it.first << ":" << it.second->DebugString();
163 }
164 }
165
166 } // namespace tensorflow
167