• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Copied from auto-generated gRPC code in order to enable using grpc_call.h
16 // for raw message handling.
17 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
18 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
19 
20 #include <functional>
21 
22 #include "grpcpp/impl/codegen/async_generic_service.h"
23 #include "grpcpp/impl/codegen/async_stream.h"
24 #include "grpcpp/impl/codegen/async_unary_call.h"
25 #include "grpcpp/impl/codegen/client_callback.h"
26 #include "grpcpp/impl/codegen/client_context.h"
27 #include "grpcpp/impl/codegen/completion_queue.h"
28 #include "grpcpp/impl/codegen/method_handler.h"
29 #include "grpcpp/impl/codegen/proto_utils.h"
30 #include "grpcpp/impl/codegen/rpc_method.h"
31 #include "grpcpp/impl/codegen/server_callback.h"
32 #include "grpcpp/impl/codegen/server_context.h"
33 #include "grpcpp/impl/codegen/service_type.h"
34 #include "grpcpp/impl/codegen/status.h"
35 #include "grpcpp/impl/codegen/stub_options.h"
36 #include "grpcpp/impl/codegen/sync_stream.h"
37 
38 #if defined(LIBTPU_ON_GCE)
39 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
40 #else
41 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"  // copybara"
42 #endif
43 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
44 
45 namespace tensorflow {
46 namespace tpu {
47 namespace grpc {
48 class TpuCompilationCacheService final {
49  public:
50   using RequestType = ::tensorflow::tpu::GetTpuProgramRequest;
51 #if defined(LIBTPU_ON_GCE)
52   using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal;
53 #else
54   using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse;
55 #endif
56 
57   // N.B. This must be synchronized with the method order in
58   // tpu_compilation_cache.proto.
59   enum class MethodId { kGetTpuProgram = 0 };
60 
service_full_name()61   static constexpr char const* service_full_name() {
62 #if defined(LIBTPU_ON_GCE)
63     return "tensorflow.tpu.TpuCompilationCacheServiceExternal";
64 #else
65     return "tensorflow.tpu.TpuCompilationCacheService";
66 #endif
67   }
68   class StubInterface {
69    public:
~StubInterface()70     virtual ~StubInterface() {}
71     // This method requests the cached proto that the TPU execute op has
72     // been instructed to execute.
73     virtual ::grpc::Status GetTpuProgram(::grpc::ClientContext* context,
74                                          const RequestType& request,
75                                          ResponseType* response) = 0;
76     std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>>
AsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)77     AsyncGetTpuProgram(::grpc::ClientContext* context,
78                        const RequestType& request,
79                        ::grpc::CompletionQueue* cq) {
80       return std::unique_ptr<
81           ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>(
82           AsyncGetTpuProgramRaw(context, request, cq));
83     }
84     std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>>
PrepareAsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)85     PrepareAsyncGetTpuProgram(::grpc::ClientContext* context,
86                               const RequestType& request,
87                               ::grpc::CompletionQueue* cq) {
88       return std::unique_ptr<
89           ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>(
90           PrepareAsyncGetTpuProgramRaw(context, request, cq));
91     }
92 
93    private:
94     virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>*
95     AsyncGetTpuProgramRaw(::grpc::ClientContext* context,
96                           const RequestType& request,
97                           ::grpc::CompletionQueue* cq) = 0;
98     virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>*
99     PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context,
100                                  const RequestType& request,
101                                  ::grpc::CompletionQueue* cq) = 0;
102   };
103   class Stub final : public StubInterface {
104    public:
105     explicit Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel);
106     ::grpc::Status GetTpuProgram(::grpc::ClientContext* context,
107                                  const RequestType& request,
108                                  ResponseType* response) override;
109     std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>
AsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)110     AsyncGetTpuProgram(::grpc::ClientContext* context,
111                        const RequestType& request,
112                        ::grpc::CompletionQueue* cq) {
113       return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>(
114           AsyncGetTpuProgramRaw(context, request, cq));
115     }
116     std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>
PrepareAsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)117     PrepareAsyncGetTpuProgram(::grpc::ClientContext* context,
118                               const RequestType& request,
119                               ::grpc::CompletionQueue* cq) {
120       return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>(
121           PrepareAsyncGetTpuProgramRaw(context, request, cq));
122     }
123 
124    private:
125     std::shared_ptr<::grpc::ChannelInterface> channel_;
126     ::grpc::ClientAsyncResponseReader<ResponseType>* AsyncGetTpuProgramRaw(
127         ::grpc::ClientContext* context, const RequestType& request,
128         ::grpc::CompletionQueue* cq) override;
129     ::grpc::ClientAsyncResponseReader<ResponseType>*
130     PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context,
131                                  const RequestType& request,
132                                  ::grpc::CompletionQueue* cq) override;
133     const ::grpc::internal::RpcMethod rpcmethod_get_tpu_program_;
134   };
135   static std::unique_ptr<Stub> NewStub(
136       const std::shared_ptr<::grpc::ChannelInterface>& channel,
137       const ::grpc::StubOptions& options = ::grpc::StubOptions());
138 
139   class Service : public ::grpc::Service {
140    public:
141     Service();
142     ~Service() override;
143     // This method requests the cached proto that the TPU execute op has
144     // been instructed to execute.
145     virtual ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
146                                          const RequestType* request,
147                                          ResponseType* response);
148   };
149   template <class BaseClass>
150   class WithAsyncMethod_GetTpuProgram : public BaseClass {
151    private:
BaseClassMustBeDerivedFromService(const Service * service)152     void BaseClassMustBeDerivedFromService(const Service* service) {}
153 
154    public:
WithAsyncMethod_GetTpuProgram()155     WithAsyncMethod_GetTpuProgram() { ::grpc::Service::MarkMethodAsync(0); }
~WithAsyncMethod_GetTpuProgram()156     ~WithAsyncMethod_GetTpuProgram() override {
157       BaseClassMustBeDerivedFromService(this);
158     }
159     // disable synchronous version of this method
GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)160     ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
161                                  const RequestType* request,
162                                  ResponseType* response) override {
163       abort();
164       return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
165     }
RequestGetTpuProgram(::grpc::ServerContext * context,RequestType * request,::grpc::ServerAsyncResponseWriter<ResponseType> * response,::grpc::CompletionQueue * new_call_cq,::grpc::ServerCompletionQueue * notification_cq,void * tag)166     void RequestGetTpuProgram(
167         ::grpc::ServerContext* context, RequestType* request,
168         ::grpc::ServerAsyncResponseWriter<ResponseType>* response,
169         ::grpc::CompletionQueue* new_call_cq,
170         ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
171       ::grpc::Service::RequestAsyncUnary(0, context, request, response,
172                                          new_call_cq, notification_cq, tag);
173     }
174 
175     // Make RequestAsyncUnary accessible to grpc_call.h
176     using ::grpc::Service::RequestAsyncUnary;
177   };
178   typedef WithAsyncMethod_GetTpuProgram<Service> AsyncService;
179   template <class BaseClass>
180   class WithGenericMethod_GetTpuProgram : public BaseClass {
181    private:
BaseClassMustBeDerivedFromService(const Service * service)182     void BaseClassMustBeDerivedFromService(const Service* service) {}
183 
184    public:
WithGenericMethod_GetTpuProgram()185     WithGenericMethod_GetTpuProgram() { ::grpc::Service::MarkMethodGeneric(0); }
~WithGenericMethod_GetTpuProgram()186     ~WithGenericMethod_GetTpuProgram() override {
187       BaseClassMustBeDerivedFromService(this);
188     }
189     // disable synchronous version of this method
GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)190     ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
191                                  const RequestType* request,
192                                  ResponseType* response) override {
193       abort();
194       return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
195     }
196   };
197   template <class BaseClass>
198   class WithStreamedUnaryMethod_GetTpuProgram : public BaseClass {
199    private:
BaseClassMustBeDerivedFromService(const Service * service)200     void BaseClassMustBeDerivedFromService(const Service* service) {}
201 
202    public:
WithStreamedUnaryMethod_GetTpuProgram()203     WithStreamedUnaryMethod_GetTpuProgram() {
204       ::grpc::Service::MarkMethodStreamed(
205           0,
206           new ::grpc::internal::StreamedUnaryHandler<RequestType, ResponseType>(
207               std::bind(&WithStreamedUnaryMethod_GetTpuProgram<
208                             BaseClass>::StreamedGetTpuProgram,
209                         this, std::placeholders::_1, std::placeholders::_2)));
210     }
~WithStreamedUnaryMethod_GetTpuProgram()211     ~WithStreamedUnaryMethod_GetTpuProgram() override {
212       BaseClassMustBeDerivedFromService(this);
213     }
214     // disable regular version of this method
GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)215     ::grpc::Status GetTpuProgram(::grpc::ServerContext* context,
216                                  const RequestType* request,
217                                  ResponseType* response) override {
218       abort();
219       return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, "");
220     }
221     // replace default version of method with streamed unary
222     virtual ::grpc::Status StreamedGetTpuProgram(
223         ::grpc::ServerContext* context,
224         ::grpc::ServerUnaryStreamer<RequestType, ResponseType>*
225             server_unary_streamer) = 0;
226   };
227   typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedUnaryService;
228   typedef Service SplitStreamedService;
229   typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedService;
230 };
231 }  // namespace grpc
232 }  // namespace tpu
233 }  // namespace tensorflow
234 
235 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_
236