• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
17 
18 #include <functional>
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/mutex.h"
27 
28 namespace tensorflow {
29 class Device;
30 class DeviceContext;
31 class Tensor;
32 
33 // EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local
34 // Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense
35 // numeric types.  Similar to Rendezvous but never owns a Ref on the
36 // tensor, instead it uses an explicit callback to the producer when
37 // the consumer side is finished with the value.  This allows the
38 // producer to perform in-place updates on the source buffer or to take
39 // other actions that depend on knowing the consumer has passed a certain
40 // execution point.
41 class BufRendezvous {
42  public:
BufRendezvous(uint64 step_id,const DeviceMgr * dev_mgr)43   explicit BufRendezvous(uint64 step_id, const DeviceMgr* dev_mgr)
44       : step_id_(step_id), dev_mgr_(dev_mgr) {}
45 
46   ~BufRendezvous();
47 
48   // Inform all all waiting parties that this BufRendezvous is defunct
49   // because of an error Status interrupting the Step.
50   void StartAbort(const Status& s);
51 
52   struct Hook;
53   // Provided by the consumer to be called when access to the buffer
54   // is available.  If the Status arg is not OK, then hook will not
55   // be populated.  Ownership of Hook passes to consumer with the
56   // callback.
57   typedef std::function<void(const Status&, Hook*)> ConsumerCallback;
58   // Provided by the producer to be called when the consumer has finished
59   // reading the buffer and will no longer access it.
60   typedef std::function<void(const Status&)> ProducerCallback;
61 
62   struct Hook {
63     Device* prod_dev;
64     DeviceContext* prod_ctx;
65     const Tensor* prod_value;
66     AllocatorAttributes prod_attr;
67     ProducerCallback prod_cb;
68     ConsumerCallback cons_cb;
HookHook69     Hook()
70         : prod_dev(nullptr),
71           prod_ctx(nullptr),
72           prod_value(nullptr),
73           prod_cb(nullptr),
74           cons_cb(nullptr) {}
75     string DebugString() const;
76   };
77 
78   // Called to advertise availability of a Tensor value corresponding
79   // to key.  That value must stay valid until done is called.
80   void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx,
81                   const Tensor* v, const AllocatorAttributes& attr,
82                   const ProducerCallback& done);
83 
84   // Called to request access to a Tensor value corresponding to key.
85   // Consumer is provided with a Hook as soon as available.
86   //
87   // This function also checks that the current incarnation number of the
88   // `device` that produced this value matches the `incarnation` expected by the
89   // consumer, and invokes `done` with `FailedPrecondition` status and
90   // `nullptr` hook if it does not match.
91   void ConsumeBuf(const string& key, const string& device,
92                   const uint64 incarnation, const ConsumerCallback& done);
93 
94   // Consumer must call this function when it's done reading the Hook provided
95   // by the ConsumerCallback.  This function will invoke the producer callback
96   // and then delete h.
97   static void DoneWithHook(Hook* h);
98 
99   // Write the current contents of the table to the INFO log.
100   void LogContents();
101 
102  protected:
103   const uint64 step_id_;
104   const DeviceMgr* const dev_mgr_;  // Not owned.
105   mutex mu_;
106   Status status_ GUARDED_BY(mu_);
107   typedef absl::flat_hash_map<string, Hook*> HookTable;
108   HookTable hook_table_ GUARDED_BY(mu_);
109 
110   void PurgeTable(const Status& s, HookTable* table);
111 };
112 }  // namespace tensorflow
113 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
114