• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/rpc/grpc_stub.h"
17 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
18 
19 namespace xla {
20 
21 GRPCStub::~GRPCStub() = default;
22 
MakeRPC(const std::function<::grpc::Status (::grpc::ClientContext *)> & rpc_method)23 Status MakeRPC(
24     const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) {
25   ::grpc::ClientContext context;
26   ::grpc::Status s = rpc_method(&context);
27   return tensorflow::FromGrpcStatus(s);
28 }
29 
TransferToClient(const TransferToClientRequest * request,TransferToClientResponse * response)30 Status GRPCStub::TransferToClient(const TransferToClientRequest* request,
31                                   TransferToClientResponse* response) {
32   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
33     return grpc_stub_->TransferToClient(context, *request, response);
34   });
35 }
36 
TransferToServer(const TransferToServerRequest * request,TransferToServerResponse * response)37 Status GRPCStub::TransferToServer(const TransferToServerRequest* request,
38                                   TransferToServerResponse* response) {
39   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
40     return grpc_stub_->TransferToServer(context, *request, response);
41   });
42 }
43 
TransferToInfeed(const TransferToInfeedRequest * request,TransferToInfeedResponse * response)44 Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request,
45                                   TransferToInfeedResponse* response) {
46   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
47     return grpc_stub_->TransferToInfeed(context, *request, response);
48   });
49 }
50 
TransferFromOutfeed(const TransferFromOutfeedRequest * request,TransferFromOutfeedResponse * response)51 Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request,
52                                      TransferFromOutfeedResponse* response) {
53   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
54     return grpc_stub_->TransferFromOutfeed(context, *request, response);
55   });
56 }
57 
ResetDevice(const ResetDeviceRequest * request,ResetDeviceResponse * response)58 Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
59                              ResetDeviceResponse* response) {
60   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
61     return grpc_stub_->ResetDevice(context, *request, response);
62   });
63 }
64 
Compile(const CompileRequest * request,CompileResponse * response)65 Status GRPCStub::Compile(const CompileRequest* request,
66                          CompileResponse* response) {
67   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
68     return grpc_stub_->Compile(context, *request, response);
69   });
70 }
71 
Execute(const ExecuteRequest * request,ExecuteResponse * response)72 Status GRPCStub::Execute(const ExecuteRequest* request,
73                          ExecuteResponse* response) {
74   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
75     return grpc_stub_->Execute(context, *request, response);
76   });
77 }
78 
ExecuteGraphParallel(const ExecuteGraphParallelRequest * request,ExecuteParallelResponse * response)79 Status GRPCStub::ExecuteGraphParallel(
80     const ExecuteGraphParallelRequest* request,
81     ExecuteParallelResponse* response) {
82   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
83     return grpc_stub_->ExecuteGraphParallel(context, *request, response);
84   });
85 }
86 
WaitForExecution(const WaitForExecutionRequest * request,WaitForExecutionResponse * response)87 Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request,
88                                   WaitForExecutionResponse* response) {
89   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
90     return grpc_stub_->WaitForExecution(context, *request, response);
91   });
92 }
93 
DeconstructTuple(const DeconstructTupleRequest * request,DeconstructTupleResponse * response)94 Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request,
95                                   DeconstructTupleResponse* response) {
96   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
97     return grpc_stub_->DeconstructTuple(context, *request, response);
98   });
99 }
100 
GetComputationGraphStats(const ComputationGraphStatsRequest * request,ComputationStatsResponse * response)101 Status GRPCStub::GetComputationGraphStats(
102     const ComputationGraphStatsRequest* request,
103     ComputationStatsResponse* response) {
104   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
105     return grpc_stub_->GetComputationGraphStats(context, *request, response);
106   });
107 }
108 
GetShape(const GetShapeRequest * request,GetShapeResponse * response)109 Status GRPCStub::GetShape(const GetShapeRequest* request,
110                           GetShapeResponse* response) {
111   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
112     return grpc_stub_->GetShape(context, *request, response);
113   });
114 }
115 
GetDeviceHandles(const GetDeviceHandlesRequest * request,GetDeviceHandlesResponse * response)116 Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request,
117                                   GetDeviceHandlesResponse* response) {
118   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
119     return grpc_stub_->GetDeviceHandles(context, *request, response);
120   });
121 }
122 
CreateChannelHandle(const CreateChannelHandleRequest * request,CreateChannelHandleResponse * response)123 Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request,
124                                      CreateChannelHandleResponse* response) {
125   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
126     return grpc_stub_->CreateChannelHandle(context, *request, response);
127   });
128 }
129 
ComputeConstantGraph(const ComputeConstantGraphRequest * request,ComputeConstantResponse * response)130 Status GRPCStub::ComputeConstantGraph(
131     const ComputeConstantGraphRequest* request,
132     ComputeConstantResponse* response) {
133   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
134     return grpc_stub_->ComputeConstantGraph(context, *request, response);
135   });
136 }
137 
138 // Methods used by GlobalData.
Unregister(const UnregisterRequest * request,UnregisterResponse * response)139 Status GRPCStub::Unregister(const UnregisterRequest* request,
140                             UnregisterResponse* response) {
141   return MakeRPC([this, request, response](::grpc::ClientContext* context) {
142     return grpc_stub_->Unregister(context, *request, response);
143   });
144 }
145 
146 }  // namespace xla
147