• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
18 
19 #include <cstddef>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
24 #include "tensorflow/core/common_runtime/eager/shape_inference.h"
25 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
26 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 #include "tensorflow/core/protobuf/eager_service.pb.h"
32 
33 namespace tensorflow {
34 namespace eager {
35 
36 // RemoteExecuteNode is an implementation of EagerNode which enqueues
37 // an operation via RPC in a remote EagerService.
38 class RemoteExecuteNode : public AsyncRemoteExecuteNode {
39  public:
RemoteExecuteNode(EagerContext * eager_context,std::unique_ptr<EnqueueRequest> request,Device * device,uint64 context_view_id,EagerClient * eager_client,CancellationManager * cancellation_manager,const NodeDef & ndef,FunctionLibraryDefinition * lib_def,const gtl::InlinedVector<TensorHandle *,4> & inputs,absl::Span<TensorHandle * > retvals)40   RemoteExecuteNode(EagerContext* eager_context,
41                     std::unique_ptr<EnqueueRequest> request, Device* device,
42                     uint64 context_view_id, EagerClient* eager_client,
43                     CancellationManager* cancellation_manager,
44                     const NodeDef& ndef, FunctionLibraryDefinition* lib_def,
45                     const gtl::InlinedVector<TensorHandle*, 4>& inputs,
46                     absl::Span<TensorHandle*> retvals)
47       : AsyncRemoteExecuteNode(),
48         eager_context_(eager_context),
49         request_(std::move(request)),
50         device_(device),
51         context_view_id_(context_view_id),
52         eager_client_(eager_client),
53         cancellation_manager_(cancellation_manager),
54         ndef_(ndef),
55         lib_def_(lib_def),
56         inputs_(inputs) {
57     // Copy the output handles, since the container for them might get
58     // destroyed.
59     for (auto handle : retvals) {
60       handle->Ref();
61       retvals_.push_back(handle);
62     }
63 
64     // This is required to ensure that the tensor handles stay alive across the
65     // execution.
66     for (auto handle : inputs_) {
67       handle->Ref();
68     }
69     eager_client_->Ref();
70 
71     needs_remote_inputs_ = false;
72     for (const TensorHandle* input : inputs_) {
73       // TODO(bramandia): Should this be op_device() instead?
74       if (input->resource_device() != nullptr &&
75           input->resource_device() != device_) {
76         needs_remote_inputs_ = true;
77         break;
78       }
79     }
80   }
81 
~RemoteExecuteNode()82   ~RemoteExecuteNode() override {
83     for (auto handle : retvals_) {
84       handle->Unref();
85     }
86 
87     for (auto handle : inputs_) {
88       handle->Unref();
89     }
90     eager_client_->Unref();
91   }
92 
Prepare()93   Status Prepare() override {
94     return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_);
95   }
96 
97   void RunAsync(StatusCallback done) override;
98 
SyncExecutors()99   Status SyncExecutors() override { return eager_context_->SyncExecutors(); }
100 
Abort(Status status)101   void Abort(Status status) override {
102     int i = 0;
103     for (auto handle : retvals_) {
104       handle->PoisonRemote(status, device_, context_view_id_);
105       ++i;
106     }
107   }
108 
eager_client()109   const EagerClient* eager_client() const override { return eager_client_; }
110 
needs_remote_inputs()111   bool needs_remote_inputs() const override { return needs_remote_inputs_; }
112 
allow_multiple_pending_requests()113   bool allow_multiple_pending_requests() const override {
114     return eager_client_->allow_multiple_pending_requests();
115   }
116 
DebugString()117   string DebugString() const override {
118     string out = "[RemoteExecuteNode]";
119     strings::StrAppend(&out, " request: ", request_->DebugString());
120     strings::StrAppend(&out, ", target_device: ", device_->name());
121     return out;
122   }
123 
124  private:
125   EagerContext* eager_context_;  // Not owned, and must outlive this node.
126   std::unique_ptr<EnqueueRequest> request_;
127   Device* device_;             // Not owned
128   uint64 context_view_id_;
129   bool needs_remote_inputs_;
130   EagerClient* eager_client_;  // Not owned, and must outlive this node.
131   CancellationManager* cancellation_manager_;
132   const NodeDef ndef_;
133   const FunctionLibraryDefinition* lib_def_;
134   gtl::InlinedVector<TensorHandle*, 4> inputs_;
135   gtl::InlinedVector<TensorHandle*, 2> retvals_;
136 };
137 
138 }  // namespace eager
139 }  // namespace tensorflow
140 
141 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
142