• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
14 #pragma once
15 
#include <cstddef>
#include <cstdint>
#include <new>
#include <type_traits>

#include "pw_bytes/span.h"
#include "pw_rpc/internal/base_client_call.h"
#include "pw_rpc/internal/method_type.h"
#include "pw_rpc/internal/nanopb_common.h"
#include "pw_status/status.h"
23 
24 namespace pw::rpc {
25 
26 // Response handler callback for unary RPC methods.
27 template <typename Response>
28 class UnaryResponseHandler {
29  public:
30   virtual ~UnaryResponseHandler() = default;
31 
32   // Called when the response is received from the server with the method's
33   // status and the deserialized response struct.
34   virtual void ReceivedResponse(Status status, const Response& response) = 0;
35 
36   // Called when an error occurs internally in the RPC client or server.
RpcError(Status)37   virtual void RpcError(Status) {}
38 };
39 
40 // Response handler callbacks for server streaming RPC methods.
41 template <typename Response>
42 class ServerStreamingResponseHandler {
43  public:
44   virtual ~ServerStreamingResponseHandler() = default;
45 
46   // Called on every response received from the server with the deserialized
47   // response struct.
48   virtual void ReceivedResponse(const Response& response) = 0;
49 
50   // Called when the server ends the stream with the overall RPC status.
51   virtual void Complete(Status status) = 0;
52 
53   // Called when an error occurs internally in the RPC client or server.
RpcError(Status)54   virtual void RpcError(Status) {}
55 };
56 
57 namespace internal {
58 
59 // Non-templated nanopb base class providing protobuf encoding and decoding.
60 class BaseNanopbClientCall : public BaseClientCall {
61  public:
62   Status SendRequest(const void* request_struct);
63 
64  protected:
BaseNanopbClientCall(rpc::Channel * channel,uint32_t service_id,uint32_t method_id,ResponseHandler handler,internal::NanopbMessageDescriptor request_fields,internal::NanopbMessageDescriptor response_fields)65   constexpr BaseNanopbClientCall(
66       rpc::Channel* channel,
67       uint32_t service_id,
68       uint32_t method_id,
69       ResponseHandler handler,
70       internal::NanopbMessageDescriptor request_fields,
71       internal::NanopbMessageDescriptor response_fields)
72       : BaseClientCall(channel, service_id, method_id, handler),
73         serde_(request_fields, response_fields) {}
74 
serde()75   constexpr const internal::NanopbMethodSerde& serde() const { return serde_; }
76 
77  private:
78   internal::NanopbMethodSerde serde_;
79 };
80 
81 template <typename Callback>
82 struct CallbackTraits {};
83 
84 template <typename ResponseType>
85 struct CallbackTraits<UnaryResponseHandler<ResponseType>> {
86   using Response = ResponseType;
87 
88   static constexpr MethodType kType = MethodType::kUnary;
89 };
90 
91 template <typename ResponseType>
92 struct CallbackTraits<ServerStreamingResponseHandler<ResponseType>> {
93   using Response = ResponseType;
94 
95   static constexpr MethodType kType = MethodType::kServerStreaming;
96 };
97 
98 }  // namespace internal
99 
100 template <typename Callback>
101 class NanopbClientCall : public internal::BaseNanopbClientCall {
102  public:
103   constexpr NanopbClientCall(Channel* channel,
104                              uint32_t service_id,
105                              uint32_t method_id,
106                              Callback& callback,
107                              internal::NanopbMessageDescriptor request_fields,
108                              internal::NanopbMessageDescriptor response_fields)
109       : BaseNanopbClientCall(channel,
110                              service_id,
111                              method_id,
112                              &ResponseHandler,
113                              request_fields,
114                              response_fields),
115         callback_(callback) {}
116 
117  private:
118   using Traits = internal::CallbackTraits<Callback>;
119   using Response = typename Traits::Response;
120 
121   // Buffer into which the nanopb struct is decoded. Its contents are unknown,
122   // so it is aligned to maximum alignment to accommodate any type.
123   using ResponseBuffer =
124       std::aligned_storage_t<sizeof(Response), alignof(std::max_align_t)>;
125 
126   friend class Client;
127 
128   static void ResponseHandler(internal::BaseClientCall& call,
129                               const internal::Packet& packet) {
130     static_cast<NanopbClientCall<Callback>&>(call).HandleResponse(packet);
131   }
132 
133   void HandleResponse(const internal::Packet& packet) {
134     if constexpr (Traits::kType == internal::MethodType::kUnary) {
135       InvokeUnaryCallback(packet);
136     }
137     if constexpr (Traits::kType == internal::MethodType::kServerStreaming) {
138       InvokeServerStreamingCallback(packet);
139     }
140   }
141 
142   void InvokeUnaryCallback(const internal::Packet& packet) {
143     if (packet.type() == internal::PacketType::SERVER_ERROR) {
144       callback_.RpcError(packet.status());
145       return;
146     }
147 
148     ResponseBuffer response_struct{};
149 
150     if (serde().DecodeResponse(&response_struct, packet.payload())) {
151       callback_.ReceivedResponse(
152           packet.status(),
153           *std::launder(reinterpret_cast<Response*>(&response_struct)));
154     } else {
155       callback_.RpcError(Status::DataLoss());
156     }
157 
158     Unregister();
159   }
160 
161   void InvokeServerStreamingCallback(const internal::Packet& packet) {
162     if (packet.type() == internal::PacketType::SERVER_ERROR) {
163       callback_.RpcError(packet.status());
164       return;
165     }
166 
167     if (packet.type() == internal::PacketType::SERVER_STREAM_END) {
168       callback_.Complete(packet.status());
169       return;
170     }
171 
172     ResponseBuffer response_struct{};
173 
174     if (serde().DecodeResponse(&response_struct, packet.payload())) {
175       callback_.ReceivedResponse(
176           *std::launder(reinterpret_cast<Response*>(&response_struct)));
177     } else {
178       callback_.RpcError(Status::DataLoss());
179     }
180   }
181 
182   Callback& callback_;
183 };
184 
185 }  // namespace pw::rpc
186