1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/rpc/grpc_stub.h"
17 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
18
19 namespace xla {
20
// Defaulted destructor, defined in this file rather than inline in the class
// body. NOTE(review): presumably kept out of line so its definition is emitted
// in exactly one translation unit — confirm against grpc_stub.h.
GRPCStub::~GRPCStub() = default;
22
MakeRPC(const std::function<::grpc::Status (::grpc::ClientContext *)> & rpc_method)23 Status MakeRPC(
24 const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) {
25 ::grpc::ClientContext context;
26 ::grpc::Status s = rpc_method(&context);
27 return tensorflow::FromGrpcStatus(s);
28 }
29
TransferToClient(const TransferToClientRequest * request,TransferToClientResponse * response)30 Status GRPCStub::TransferToClient(const TransferToClientRequest* request,
31 TransferToClientResponse* response) {
32 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
33 return grpc_stub_->TransferToClient(context, *request, response);
34 });
35 }
36
TransferToServer(const TransferToServerRequest * request,TransferToServerResponse * response)37 Status GRPCStub::TransferToServer(const TransferToServerRequest* request,
38 TransferToServerResponse* response) {
39 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
40 return grpc_stub_->TransferToServer(context, *request, response);
41 });
42 }
43
TransferToInfeed(const TransferToInfeedRequest * request,TransferToInfeedResponse * response)44 Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request,
45 TransferToInfeedResponse* response) {
46 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
47 return grpc_stub_->TransferToInfeed(context, *request, response);
48 });
49 }
50
TransferFromOutfeed(const TransferFromOutfeedRequest * request,TransferFromOutfeedResponse * response)51 Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request,
52 TransferFromOutfeedResponse* response) {
53 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
54 return grpc_stub_->TransferFromOutfeed(context, *request, response);
55 });
56 }
57
ResetDevice(const ResetDeviceRequest * request,ResetDeviceResponse * response)58 Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
59 ResetDeviceResponse* response) {
60 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
61 return grpc_stub_->ResetDevice(context, *request, response);
62 });
63 }
64
Compile(const CompileRequest * request,CompileResponse * response)65 Status GRPCStub::Compile(const CompileRequest* request,
66 CompileResponse* response) {
67 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
68 return grpc_stub_->Compile(context, *request, response);
69 });
70 }
71
Execute(const ExecuteRequest * request,ExecuteResponse * response)72 Status GRPCStub::Execute(const ExecuteRequest* request,
73 ExecuteResponse* response) {
74 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
75 return grpc_stub_->Execute(context, *request, response);
76 });
77 }
78
ExecuteGraphParallel(const ExecuteGraphParallelRequest * request,ExecuteParallelResponse * response)79 Status GRPCStub::ExecuteGraphParallel(
80 const ExecuteGraphParallelRequest* request,
81 ExecuteParallelResponse* response) {
82 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
83 return grpc_stub_->ExecuteGraphParallel(context, *request, response);
84 });
85 }
86
WaitForExecution(const WaitForExecutionRequest * request,WaitForExecutionResponse * response)87 Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request,
88 WaitForExecutionResponse* response) {
89 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
90 return grpc_stub_->WaitForExecution(context, *request, response);
91 });
92 }
93
DeconstructTuple(const DeconstructTupleRequest * request,DeconstructTupleResponse * response)94 Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request,
95 DeconstructTupleResponse* response) {
96 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
97 return grpc_stub_->DeconstructTuple(context, *request, response);
98 });
99 }
100
GetComputationGraphStats(const ComputationGraphStatsRequest * request,ComputationStatsResponse * response)101 Status GRPCStub::GetComputationGraphStats(
102 const ComputationGraphStatsRequest* request,
103 ComputationStatsResponse* response) {
104 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
105 return grpc_stub_->GetComputationGraphStats(context, *request, response);
106 });
107 }
108
GetShape(const GetShapeRequest * request,GetShapeResponse * response)109 Status GRPCStub::GetShape(const GetShapeRequest* request,
110 GetShapeResponse* response) {
111 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
112 return grpc_stub_->GetShape(context, *request, response);
113 });
114 }
115
GetDeviceHandles(const GetDeviceHandlesRequest * request,GetDeviceHandlesResponse * response)116 Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request,
117 GetDeviceHandlesResponse* response) {
118 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
119 return grpc_stub_->GetDeviceHandles(context, *request, response);
120 });
121 }
122
CreateChannelHandle(const CreateChannelHandleRequest * request,CreateChannelHandleResponse * response)123 Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request,
124 CreateChannelHandleResponse* response) {
125 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
126 return grpc_stub_->CreateChannelHandle(context, *request, response);
127 });
128 }
129
ComputeConstantGraph(const ComputeConstantGraphRequest * request,ComputeConstantResponse * response)130 Status GRPCStub::ComputeConstantGraph(
131 const ComputeConstantGraphRequest* request,
132 ComputeConstantResponse* response) {
133 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
134 return grpc_stub_->ComputeConstantGraph(context, *request, response);
135 });
136 }
137
138 // Methods used by GlobalData.
Unregister(const UnregisterRequest * request,UnregisterResponse * response)139 Status GRPCStub::Unregister(const UnregisterRequest* request,
140 UnregisterResponse* response) {
141 return MakeRPC([this, request, response](::grpc::ClientContext* context) {
142 return grpc_stub_->Unregister(context, *request, response);
143 });
144 }
145
146 } // namespace xla
147