1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/cancellation.h" 22 #include "tensorflow/core/framework/control_flow.h" 23 #include "tensorflow/core/framework/device_base.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/refcount.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/util/device_name_utils.h" 28 29 namespace tensorflow { 30 31 class DeviceMgr; 32 33 // A Rendezvous is an abstraction for passing tensors from producers 34 // to consumers. A rendezvous is a table of channels. Each channel is 35 // keyed by a rendezvous key. The key encodes a pair of <producer, 36 // consumer>, where the producer and the consumer are tensorflow 37 // devices. 38 // 39 // The producer calls the Send() method to send one tensor over one 40 // named channel. The consumer calls the Recv() method to receive one 41 // tensor from a named channel. A sequence of tensors can be passed 42 // from the producer to the consumer. The consumer receives them in 43 // the order as the producer sends them. 44 // 45 // A consumer may safely request the tensor before or after it has 46 // been produced. A consumer has the choice of making a blocking call 47 // or providing a callback: in either case, the consumer receives the 48 // Tensor as soon as it is available. A producer never blocks. 49 class RendezvousInterface { 50 public: 51 struct Args { 52 DeviceContext* device_context = nullptr; 53 AllocatorAttributes alloc_attrs; 54 CancellationManager* cancellation_manager = nullptr; // not owned. 55 }; 56 57 // Parses the key constructed by CreateKey and parse src/dst device 58 // names into structures respectively. 59 struct ParsedKey { 60 StringPiece src_device; 61 DeviceNameUtils::ParsedName src; 62 uint64 src_incarnation = 0; 63 StringPiece dst_device; 64 DeviceNameUtils::ParsedName dst; 65 StringPiece edge_name; 66 ParsedKeyParsedKey67 ParsedKey() {} ParsedKeyParsedKey68 ParsedKey(const ParsedKey& b) { *this = b; } 69 70 ParsedKey& operator=(const ParsedKey& b); FullKeyParsedKey71 StringPiece FullKey() const { return buf_; } 72 73 private: 74 friend class Rendezvous; 75 friend class SendOp; 76 friend class RecvOp; 77 std::string buf_; 78 }; 79 80 // The caller is a tensor producer and it sends a message (a tensor 81 // "val" and a bool "is_dead") under the given "key". 82 // 83 // {val, is_dead} is bundled as a message sent and received. 84 // Typically, is_dead is set by some control flow nodes 85 // (e.g., a not-taken branch). args is passed by Send to the 86 // Recv function to communicate any information that the Recv 87 // function might need. This is typically only necessary for 88 // Send/Recv on the same worker. 89 // 90 // Send() never blocks. 91 virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val, 92 const bool is_dead) = 0; 93 94 // Callback provided by a tensor consumer waiting on the rendezvous. 95 // It will be invoked when the tensor is available, or when a non-OK 96 // status arises in the production of that tensor. It also gets 97 // two Rendezvous::Args, one provided by the sender, the other by the 98 // receiver, which may be needed when a non-CPU device is in use 99 // by either side. 100 typedef std::function<void(const Status&, const Args&, const Args&, 101 const Tensor&, const bool)> 102 DoneCallback; 103 104 virtual void RecvAsync(const ParsedKey& key, const Args& args, 105 DoneCallback done) = 0; 106 107 // Synchronous wrapper for RecvAsync. 108 Status Recv(const ParsedKey& key, const Args& args, Tensor* val, 109 bool* is_dead, int64 timeout_ms); 110 Status Recv(const ParsedKey& key, const Args& args, Tensor* val, 111 bool* is_dead); 112 113 // Aborts all pending and future Send/Recv with the given "status". 114 // 115 // StartAbort() does not wait for ongoing calls to finish. 116 // REQUIRES: !status.ok() 117 virtual void StartAbort(const Status& status) = 0; 118 119 protected: 120 virtual ~RendezvousInterface(); 121 is_cross_process()122 virtual bool is_cross_process() { return false; } 123 friend class ProcessFunctionLibraryRuntime; 124 }; 125 126 // A reference-counted implementation of RendezvousInterface. 127 // 128 // This class is used in cases where a rendezvous may be shared between multiple 129 // threads with no clear owner. 130 class Rendezvous : public RendezvousInterface, public core::RefCounted { 131 public: 132 class Factory { 133 public: 134 // Default to a factory that evaluates to false. Factory()135 Factory() : valid_(false) {} 136 Factory(std::function<Status (const int64,const DeviceMgr *,Rendezvous **)> create_fn,std::function<Status (const int64)> cleanup_fn)137 Factory(std::function<Status(const int64, const DeviceMgr*, Rendezvous**)> 138 create_fn, 139 std::function<Status(const int64)> cleanup_fn) 140 : valid_(true), 141 create_fn_(std::move(create_fn)), 142 cleanup_fn_(std::move(cleanup_fn)) {} 143 144 // If no clean up fn is provided, just put in a dummy. 145 // For backwards compatibility. Factory(std::function<Status (const int64,const DeviceMgr *,Rendezvous **)> create_fn)146 explicit Factory( 147 std::function<Status(const int64, const DeviceMgr*, Rendezvous**)> 148 create_fn) 149 : valid_(true), 150 create_fn_(std::move(create_fn)), 151 cleanup_fn_([](const int64 step_id) { return Status::OK(); }) {} 152 153 explicit operator bool() const { return valid_; } 154 operator()155 Status operator()(const int64 step_id, const DeviceMgr* device_mgr, 156 Rendezvous** rendez) const { 157 return create_fn_(step_id, device_mgr, rendez); 158 } 159 CleanUp(const int64 step_id)160 Status CleanUp(const int64 step_id) const { return cleanup_fn_(step_id); } 161 162 private: 163 bool valid_; 164 std::function<Status(const int64, const DeviceMgr*, Rendezvous**)> 165 create_fn_; 166 std::function<Status(const int64)> cleanup_fn_; 167 }; 168 169 // Constructs a rendezvous key for the tensor of "name" sent from 170 // "src_device" to "dst_device". The tensor is generated in the frame 171 // and iteration specified by "frame_iter". 172 static std::string CreateKey(const std::string& src_device, 173 uint64 src_incarnation, 174 const std::string& dst_device, 175 const std::string& name, 176 const FrameAndIter& frame_iter); 177 178 static Status ParseKey(StringPiece key, ParsedKey* out); 179 }; 180 181 // Returns a Rendezvous instance that is limited to use only by 182 // producers and consumers in the local process. The caller assumes 183 // ownership of one Ref() on the returned object. 184 Rendezvous* NewLocalRendezvous(); 185 186 } // end namespace tensorflow 187 188 #endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ 189