1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Copied from auto-generated gRPC code in order to enable using grpc_call.h 16 // for raw message handling. 17 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ 18 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ 19 20 #include <functional> 21 22 #include "grpcpp/impl/codegen/async_generic_service.h" 23 #include "grpcpp/impl/codegen/async_stream.h" 24 #include "grpcpp/impl/codegen/async_unary_call.h" 25 #include "grpcpp/impl/codegen/client_callback.h" 26 #include "grpcpp/impl/codegen/client_context.h" 27 #include "grpcpp/impl/codegen/completion_queue.h" 28 #include "grpcpp/impl/codegen/method_handler.h" 29 #include "grpcpp/impl/codegen/proto_utils.h" 30 #include "grpcpp/impl/codegen/rpc_method.h" 31 #include "grpcpp/impl/codegen/server_callback.h" 32 #include "grpcpp/impl/codegen/server_context.h" 33 #include "grpcpp/impl/codegen/service_type.h" 34 #include "grpcpp/impl/codegen/status.h" 35 #include "grpcpp/impl/codegen/stub_options.h" 36 #include "grpcpp/impl/codegen/sync_stream.h" 37 38 #if defined(LIBTPU_ON_GCE) 39 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" 40 #else 41 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara" 42 #endif 43 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" 44 45 namespace tensorflow { 46 namespace tpu { 47 namespace grpc { 48 class TpuCompilationCacheService final { 49 public: 50 using RequestType = ::tensorflow::tpu::GetTpuProgramRequest; 51 #if defined(LIBTPU_ON_GCE) 52 using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal; 53 #else 54 using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse; 55 #endif 56 57 // N.B. This must be synchronized with the method order in 58 // tpu_compilation_cache.proto. 59 enum class MethodId { kGetTpuProgram = 0 }; 60 service_full_name()61 static constexpr char const* service_full_name() { 62 #if defined(LIBTPU_ON_GCE) 63 return "tensorflow.tpu.TpuCompilationCacheServiceExternal"; 64 #else 65 return "tensorflow.tpu.TpuCompilationCacheService"; 66 #endif 67 } 68 class StubInterface { 69 public: ~StubInterface()70 virtual ~StubInterface() {} 71 // This method requests the cached proto that the TPU execute op has 72 // been instructed to execute. 73 virtual ::grpc::Status GetTpuProgram(::grpc::ClientContext* context, 74 const RequestType& request, 75 ResponseType* response) = 0; 76 std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>> AsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)77 AsyncGetTpuProgram(::grpc::ClientContext* context, 78 const RequestType& request, 79 ::grpc::CompletionQueue* cq) { 80 return std::unique_ptr< 81 ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>( 82 AsyncGetTpuProgramRaw(context, request, cq)); 83 } 84 std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface<ResponseType>> PrepareAsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)85 PrepareAsyncGetTpuProgram(::grpc::ClientContext* context, 86 const RequestType& request, 87 ::grpc::CompletionQueue* cq) { 88 return std::unique_ptr< 89 ::grpc::ClientAsyncResponseReaderInterface<ResponseType>>( 90 PrepareAsyncGetTpuProgramRaw(context, request, cq)); 91 } 92 93 private: 94 virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>* 95 AsyncGetTpuProgramRaw(::grpc::ClientContext* context, 96 const RequestType& request, 97 ::grpc::CompletionQueue* cq) = 0; 98 virtual ::grpc::ClientAsyncResponseReaderInterface<ResponseType>* 99 PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context, 100 const RequestType& request, 101 ::grpc::CompletionQueue* cq) = 0; 102 }; 103 class Stub final : public StubInterface { 104 public: 105 explicit Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); 106 ::grpc::Status GetTpuProgram(::grpc::ClientContext* context, 107 const RequestType& request, 108 ResponseType* response) override; 109 std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>> AsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)110 AsyncGetTpuProgram(::grpc::ClientContext* context, 111 const RequestType& request, 112 ::grpc::CompletionQueue* cq) { 113 return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>( 114 AsyncGetTpuProgramRaw(context, request, cq)); 115 } 116 std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>> PrepareAsyncGetTpuProgram(::grpc::ClientContext * context,const RequestType & request,::grpc::CompletionQueue * cq)117 PrepareAsyncGetTpuProgram(::grpc::ClientContext* context, 118 const RequestType& request, 119 ::grpc::CompletionQueue* cq) { 120 return std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseType>>( 121 PrepareAsyncGetTpuProgramRaw(context, request, cq)); 122 } 123 124 private: 125 std::shared_ptr<::grpc::ChannelInterface> channel_; 126 ::grpc::ClientAsyncResponseReader<ResponseType>* AsyncGetTpuProgramRaw( 127 ::grpc::ClientContext* context, const RequestType& request, 128 ::grpc::CompletionQueue* cq) override; 129 ::grpc::ClientAsyncResponseReader<ResponseType>* 130 PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context, 131 const RequestType& request, 132 ::grpc::CompletionQueue* cq) override; 133 const ::grpc::internal::RpcMethod rpcmethod_get_tpu_program_; 134 }; 135 static std::unique_ptr<Stub> NewStub( 136 const std::shared_ptr<::grpc::ChannelInterface>& channel, 137 const ::grpc::StubOptions& options = ::grpc::StubOptions()); 138 139 class Service : public ::grpc::Service { 140 public: 141 Service(); 142 ~Service() override; 143 // This method requests the cached proto that the TPU execute op has 144 // been instructed to execute. 145 virtual ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, 146 const RequestType* request, 147 ResponseType* response); 148 }; 149 template <class BaseClass> 150 class WithAsyncMethod_GetTpuProgram : public BaseClass { 151 private: BaseClassMustBeDerivedFromService(const Service * service)152 void BaseClassMustBeDerivedFromService(const Service* service) {} 153 154 public: WithAsyncMethod_GetTpuProgram()155 WithAsyncMethod_GetTpuProgram() { ::grpc::Service::MarkMethodAsync(0); } ~WithAsyncMethod_GetTpuProgram()156 ~WithAsyncMethod_GetTpuProgram() override { 157 BaseClassMustBeDerivedFromService(this); 158 } 159 // disable synchronous version of this method GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)160 ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, 161 const RequestType* request, 162 ResponseType* response) override { 163 abort(); 164 return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 165 } RequestGetTpuProgram(::grpc::ServerContext * context,RequestType * request,::grpc::ServerAsyncResponseWriter<ResponseType> * response,::grpc::CompletionQueue * new_call_cq,::grpc::ServerCompletionQueue * notification_cq,void * tag)166 void RequestGetTpuProgram( 167 ::grpc::ServerContext* context, RequestType* request, 168 ::grpc::ServerAsyncResponseWriter<ResponseType>* response, 169 ::grpc::CompletionQueue* new_call_cq, 170 ::grpc::ServerCompletionQueue* notification_cq, void* tag) { 171 ::grpc::Service::RequestAsyncUnary(0, context, request, response, 172 new_call_cq, notification_cq, tag); 173 } 174 175 // Make RequestAsyncUnary accessible to grpc_call.h 176 using ::grpc::Service::RequestAsyncUnary; 177 }; 178 typedef WithAsyncMethod_GetTpuProgram<Service> AsyncService; 179 template <class BaseClass> 180 class WithGenericMethod_GetTpuProgram : public BaseClass { 181 private: BaseClassMustBeDerivedFromService(const Service * service)182 void BaseClassMustBeDerivedFromService(const Service* service) {} 183 184 public: WithGenericMethod_GetTpuProgram()185 WithGenericMethod_GetTpuProgram() { ::grpc::Service::MarkMethodGeneric(0); } ~WithGenericMethod_GetTpuProgram()186 ~WithGenericMethod_GetTpuProgram() override { 187 BaseClassMustBeDerivedFromService(this); 188 } 189 // disable synchronous version of this method GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)190 ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, 191 const RequestType* request, 192 ResponseType* response) override { 193 abort(); 194 return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 195 } 196 }; 197 template <class BaseClass> 198 class WithStreamedUnaryMethod_GetTpuProgram : public BaseClass { 199 private: BaseClassMustBeDerivedFromService(const Service * service)200 void BaseClassMustBeDerivedFromService(const Service* service) {} 201 202 public: WithStreamedUnaryMethod_GetTpuProgram()203 WithStreamedUnaryMethod_GetTpuProgram() { 204 ::grpc::Service::MarkMethodStreamed( 205 0, 206 new ::grpc::internal::StreamedUnaryHandler<RequestType, ResponseType>( 207 std::bind(&WithStreamedUnaryMethod_GetTpuProgram< 208 BaseClass>::StreamedGetTpuProgram, 209 this, std::placeholders::_1, std::placeholders::_2))); 210 } ~WithStreamedUnaryMethod_GetTpuProgram()211 ~WithStreamedUnaryMethod_GetTpuProgram() override { 212 BaseClassMustBeDerivedFromService(this); 213 } 214 // disable regular version of this method GetTpuProgram(::grpc::ServerContext * context,const RequestType * request,ResponseType * response)215 ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, 216 const RequestType* request, 217 ResponseType* response) override { 218 abort(); 219 return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); 220 } 221 // replace default version of method with streamed unary 222 virtual ::grpc::Status StreamedGetTpuProgram( 223 ::grpc::ServerContext* context, 224 ::grpc::ServerUnaryStreamer<RequestType, ResponseType>* 225 server_unary_streamer) = 0; 226 }; 227 typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedUnaryService; 228 typedef Service SplitStreamedService; 229 typedef WithStreamedUnaryMethod_GetTpuProgram<Service> StreamedService; 230 }; 231 } // namespace grpc 232 } // namespace tpu 233 } // namespace tensorflow 234 235 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ 236