1 // Copyright 2020 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 #pragma once 15 16 #include <new> 17 18 #include "pw_bytes/span.h" 19 #include "pw_rpc/internal/base_client_call.h" 20 #include "pw_rpc/internal/method_type.h" 21 #include "pw_rpc/internal/nanopb_common.h" 22 #include "pw_status/status.h" 23 24 namespace pw::rpc { 25 26 // Response handler callback for unary RPC methods. 27 template <typename Response> 28 class UnaryResponseHandler { 29 public: 30 virtual ~UnaryResponseHandler() = default; 31 32 // Called when the response is received from the server with the method's 33 // status and the deserialized response struct. 34 virtual void ReceivedResponse(Status status, const Response& response) = 0; 35 36 // Called when an error occurs internally in the RPC client or server. RpcError(Status)37 virtual void RpcError(Status) {} 38 }; 39 40 // Response handler callbacks for server streaming RPC methods. 41 template <typename Response> 42 class ServerStreamingResponseHandler { 43 public: 44 virtual ~ServerStreamingResponseHandler() = default; 45 46 // Called on every response received from the server with the deserialized 47 // response struct. 48 virtual void ReceivedResponse(const Response& response) = 0; 49 50 // Called when the server ends the stream with the overall RPC status. 51 virtual void Complete(Status status) = 0; 52 53 // Called when an error occurs internally in the RPC client or server. RpcError(Status)54 virtual void RpcError(Status) {} 55 }; 56 57 namespace internal { 58 59 // Non-templated nanopb base class providing protobuf encoding and decoding. 60 class BaseNanopbClientCall : public BaseClientCall { 61 public: 62 Status SendRequest(const void* request_struct); 63 64 protected: BaseNanopbClientCall(rpc::Channel * channel,uint32_t service_id,uint32_t method_id,ResponseHandler handler,internal::NanopbMessageDescriptor request_fields,internal::NanopbMessageDescriptor response_fields)65 constexpr BaseNanopbClientCall( 66 rpc::Channel* channel, 67 uint32_t service_id, 68 uint32_t method_id, 69 ResponseHandler handler, 70 internal::NanopbMessageDescriptor request_fields, 71 internal::NanopbMessageDescriptor response_fields) 72 : BaseClientCall(channel, service_id, method_id, handler), 73 serde_(request_fields, response_fields) {} 74 serde()75 constexpr const internal::NanopbMethodSerde& serde() const { return serde_; } 76 77 private: 78 internal::NanopbMethodSerde serde_; 79 }; 80 81 template <typename Callback> 82 struct CallbackTraits {}; 83 84 template <typename ResponseType> 85 struct CallbackTraits<UnaryResponseHandler<ResponseType>> { 86 using Response = ResponseType; 87 88 static constexpr MethodType kType = MethodType::kUnary; 89 }; 90 91 template <typename ResponseType> 92 struct CallbackTraits<ServerStreamingResponseHandler<ResponseType>> { 93 using Response = ResponseType; 94 95 static constexpr MethodType kType = MethodType::kServerStreaming; 96 }; 97 98 } // namespace internal 99 100 template <typename Callback> 101 class NanopbClientCall : public internal::BaseNanopbClientCall { 102 public: 103 constexpr NanopbClientCall(Channel* channel, 104 uint32_t service_id, 105 uint32_t method_id, 106 Callback& callback, 107 internal::NanopbMessageDescriptor request_fields, 108 internal::NanopbMessageDescriptor response_fields) 109 : BaseNanopbClientCall(channel, 110 service_id, 111 method_id, 112 &ResponseHandler, 113 request_fields, 114 response_fields), 115 callback_(callback) {} 116 117 private: 118 using Traits = internal::CallbackTraits<Callback>; 119 using Response = typename Traits::Response; 120 121 // Buffer into which the nanopb struct is decoded. Its contents are unknown, 122 // so it is aligned to maximum alignment to accommodate any type. 123 using ResponseBuffer = 124 std::aligned_storage_t<sizeof(Response), alignof(std::max_align_t)>; 125 126 friend class Client; 127 128 static void ResponseHandler(internal::BaseClientCall& call, 129 const internal::Packet& packet) { 130 static_cast<NanopbClientCall<Callback>&>(call).HandleResponse(packet); 131 } 132 133 void HandleResponse(const internal::Packet& packet) { 134 if constexpr (Traits::kType == internal::MethodType::kUnary) { 135 InvokeUnaryCallback(packet); 136 } 137 if constexpr (Traits::kType == internal::MethodType::kServerStreaming) { 138 InvokeServerStreamingCallback(packet); 139 } 140 } 141 142 void InvokeUnaryCallback(const internal::Packet& packet) { 143 if (packet.type() == internal::PacketType::SERVER_ERROR) { 144 callback_.RpcError(packet.status()); 145 return; 146 } 147 148 ResponseBuffer response_struct{}; 149 150 if (serde().DecodeResponse(&response_struct, packet.payload())) { 151 callback_.ReceivedResponse( 152 packet.status(), 153 *std::launder(reinterpret_cast<Response*>(&response_struct))); 154 } else { 155 callback_.RpcError(Status::DataLoss()); 156 } 157 158 Unregister(); 159 } 160 161 void InvokeServerStreamingCallback(const internal::Packet& packet) { 162 if (packet.type() == internal::PacketType::SERVER_ERROR) { 163 callback_.RpcError(packet.status()); 164 return; 165 } 166 167 if (packet.type() == internal::PacketType::SERVER_STREAM_END) { 168 callback_.Complete(packet.status()); 169 return; 170 } 171 172 ResponseBuffer response_struct{}; 173 174 if (serde().DecodeResponse(&response_struct, packet.payload())) { 175 callback_.ReceivedResponse( 176 *std::launder(reinterpret_cast<Response*>(&response_struct))); 177 } else { 178 callback_.RpcError(Status::DataLoss()); 179 } 180 } 181 182 Callback& callback_; 183 }; 184 185 } // namespace pw::rpc 186