1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/rendezvous_util.h"
16
17 namespace tensorflow {
18
SendTensorsToRendezvous(Rendezvous * rendezvous,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,const std::vector<string> & keys,gtl::ArraySlice<Tensor> tensors_to_send)19 Status SendTensorsToRendezvous(
20 Rendezvous* rendezvous, DeviceContext* device_context,
21 const std::vector<AllocatorAttributes>& alloc_attrs,
22 const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) {
23 if (keys.size() != tensors_to_send.size()) {
24 return errors::InvalidArgument(
25 "keys and tensors_to_send are not the same size. keys.size() = ",
26 keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
27 }
28 if (!alloc_attrs.empty() && (keys.size() != alloc_attrs.size())) {
29 return errors::InvalidArgument(
30 "keys and alloc_attrs are not the same size. ",
31 "keys.size() = ", keys.size(),
32 "; alloc_attrs.size() = ", alloc_attrs.size());
33 }
34
35 if (!rendezvous) {
36 return errors::InvalidArgument("Rendezvous is null.");
37 }
38
39 Rendezvous::ParsedKey parsed;
40 for (int i = 0; i < keys.size(); ++i) {
41 Rendezvous::Args rendez_args;
42 rendez_args.device_context = device_context;
43 if (!alloc_attrs.empty()) {
44 rendez_args.alloc_attrs = alloc_attrs[i];
45 }
46 TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
47 TF_RETURN_IF_ERROR(
48 rendezvous->Send(parsed, rendez_args, tensors_to_send[i], false));
49 }
50 return Status::OK();
51 }
52
RecvOutputsFromRendezvousAsync(Rendezvous * rendezvous,DeviceContext * device_context,const std::vector<AllocatorAttributes> & alloc_attrs,const std::vector<string> & keys,std::vector<Tensor> * received_tensors,const StatusCallback & done)53 void RecvOutputsFromRendezvousAsync(
54 Rendezvous* rendezvous, DeviceContext* device_context,
55 const std::vector<AllocatorAttributes>& alloc_attrs,
56 const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
57 const StatusCallback& done) {
58 if (keys.empty()) {
59 done(Status::OK());
60 return;
61 }
62 if (!alloc_attrs.empty() && (keys.size() != alloc_attrs.size())) {
63 done(errors::InvalidArgument(
64 "keys and alloc_attrs are not the same size. ", "keys.size() = ",
65 keys.size(), "; alloc_attrs.size() = ", alloc_attrs.size()));
66 }
67
68 received_tensors->reserve(keys.size());
69 std::vector<
70 std::tuple<string, Tensor*, Rendezvous::ParsedKey, AllocatorAttributes>>
71 arguments;
72 for (int i = 0; i < keys.size(); ++i) {
73 Rendezvous::ParsedKey parsed;
74 Status s = Rendezvous::ParseKey(keys[i], &parsed);
75 received_tensors->push_back(Tensor());
76 if (!s.ok()) {
77 done(s);
78 return;
79 }
80 AllocatorAttributes alloc_attr;
81 if (!alloc_attrs.empty()) {
82 alloc_attr = alloc_attrs[i];
83 }
84 arguments.emplace_back(keys[i], &((*received_tensors)[i]), parsed,
85 alloc_attr);
86 }
87
88 typedef struct {
89 mutex mu;
90 int64 done_counter;
91 Status shared_status = Status::OK();
92 } CallState;
93 CallState* call_state = new CallState;
94 call_state->done_counter = keys.size();
95 for (auto& p : arguments) {
96 const string& key = std::get<0>(p);
97 Tensor* val = std::get<1>(p);
98 Rendezvous::ParsedKey parsed = std::get<2>(p);
99 Rendezvous::Args rendez_args;
100 rendez_args.device_context = device_context;
101 rendez_args.alloc_attrs = std::get<3>(p);
102
103 rendezvous->RecvAsync(
104 parsed, rendez_args,
105 [val, done, key, call_state](const Status& s,
106 const Rendezvous::Args& send_args,
107 const Rendezvous::Args& recv_args,
108 const Tensor& v, const bool is_dead) {
109 Status status = s;
110 if (status.ok()) {
111 *val = v;
112 if (is_dead) {
113 status = errors::InvalidArgument("The tensor returned for ", key,
114 " was not valid.");
115 }
116 }
117 call_state->mu.lock();
118 call_state->shared_status.Update(status);
119 call_state->done_counter--;
120 // If we are the last async call to return, call the done callback.
121 if (call_state->done_counter == 0) {
122 const Status& final_status = call_state->shared_status;
123 call_state->mu.unlock();
124 done(final_status);
125 delete call_state;
126 return;
127 }
128 call_state->mu.unlock();
129 });
130 }
131 }
132
RecvOutputsFromRendezvous(Rendezvous * rendezvous,NamedTensors * out,const Rendezvous::Args & args)133 Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
134 const Rendezvous::Args& args) {
135 // Receives values requested by the caller.
136 Rendezvous::ParsedKey parsed;
137 for (auto& p : *out) {
138 const string& key = p.first;
139 Tensor* val = &p.second;
140 bool is_dead = false;
141 TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
142 TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
143 if (is_dead) {
144 return errors::InvalidArgument("The tensor returned for ", key,
145 " was not valid.");
146 }
147 }
148 return Status::OK();
149 }
150
151 } // namespace tensorflow
152