# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Functions for working with pw_rpc packets."""

import dataclasses
from typing import TypeVar

from google.protobuf import message
from pw_status import Status

from pw_rpc.internal import packet_pb2

# The concrete message class passed to decode_payload; binding it lets type
# checkers infer that decode_payload returns an instance of that same class.
_PayloadT = TypeVar('_PayloadT', bound=message.Message)


def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses serialized bytes into an RpcPacket.

    Args:
        data: The encoded packet.

    Returns:
        A new RpcPacket populated from data.

    Raises:
        google.protobuf.message.DecodeError: If data cannot be parsed as an
            RpcPacket (raised by MergeFromString).
    """
    packet = packet_pb2.RpcPacket()
    packet.MergeFromString(data)
    return packet


def decode_payload(
    packet: packet_pb2.RpcPacket, payload_type: type[_PayloadT]
) -> _PayloadT:
    """Decodes a packet's payload bytes as a message of payload_type.

    Args:
        packet: The packet whose payload field holds the serialized message.
        payload_type: The protobuf message class to decode the payload as.

    Returns:
        A new payload_type instance populated from packet.payload.
    """
    payload = payload_type()
    payload.MergeFromString(packet.payload)
    return payload


@dataclasses.dataclass(eq=True, frozen=True)
class RpcIds:
    """Integer IDs that uniquely identify a remote procedure call.

    As a frozen dataclass with eq=True, instances are immutable and hashable.
    """

    channel_id: int
    service_id: int
    method_id: int
    call_id: int


def _encode(
    packet_type: int,
    ids: 'RpcIds | packet_pb2.RpcPacket',
    **fields: int | bytes,
) -> bytes:
    """Serializes an RpcPacket addressed to the call identified by ids.

    Every encoder in this module sets the same four ID fields. This helper
    copies them from ids and adds any packet-specific fields.

    Args:
        packet_type: A packet_pb2.PacketType value.
        ids: An object with channel_id, service_id, method_id and call_id
            attributes (an RpcIds or an RpcPacket).
        **fields: Additional RpcPacket fields, e.g. payload or status.

    Returns:
        The serialized packet.
    """
    return packet_pb2.RpcPacket(
        type=packet_type,
        channel_id=ids.channel_id,
        service_id=ids.service_id,
        method_id=ids.method_id,
        call_id=ids.call_id,
        **fields,
    ).SerializeToString()


def encode_request(rpc: RpcIds, request: message.Message | None) -> bytes:
    """Encodes a REQUEST packet for the call identified by rpc.

    Args:
        rpc: IDs of the call.
        request: The request message, or None to send an empty payload.

    Returns:
        The serialized packet.
    """
    payload = request.SerializeToString() if request is not None else b''
    return _encode(packet_pb2.PacketType.REQUEST, rpc, payload=payload)


def encode_response(rpc: RpcIds, response: message.Message) -> bytes:
    """Encodes a RESPONSE packet whose payload is the serialized response.

    Args:
        rpc: IDs of the call.
        response: The response message.

    Returns:
        The serialized packet.
    """
    return _encode(
        packet_pb2.PacketType.RESPONSE,
        rpc,
        payload=response.SerializeToString(),
    )


def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes:
    """Encodes a CLIENT_STREAM packet whose payload is the serialized request.

    Args:
        rpc: IDs of the call.
        request: The streamed request message.

    Returns:
        The serialized packet.
    """
    return _encode(
        packet_pb2.PacketType.CLIENT_STREAM,
        rpc,
        payload=request.SerializeToString(),
    )


def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Encodes a CLIENT_ERROR packet carrying status.

    The channel, service, method and call IDs are copied from packet
    (presumably the packet the client is reporting an error about).

    Args:
        packet: The packet whose IDs address the error.
        status: The error status to send.

    Returns:
        The serialized packet.
    """
    return _encode(
        packet_pb2.PacketType.CLIENT_ERROR, packet, status=status.value
    )


def encode_cancel(rpc: RpcIds) -> bytes:
    """Encodes a cancellation as a CLIENT_ERROR packet with Status.CANCELLED.

    Args:
        rpc: IDs of the call to cancel.

    Returns:
        The serialized packet.
    """
    return _encode(
        packet_pb2.PacketType.CLIENT_ERROR,
        rpc,
        status=Status.CANCELLED.value,
    )


def encode_client_stream_end(rpc: RpcIds) -> bytes:
    """Encodes a CLIENT_REQUEST_COMPLETION packet with no payload or status.

    Args:
        rpc: IDs of the call whose client stream is ending.

    Returns:
        The serialized packet.
    """
    return _encode(packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION, rpc)


def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True if packet.type is even, i.e. the packet is server-bound.

    NOTE(review): this relies on the PacketType numbering in packet.proto,
    where packet types sent to a server presumably all have even values.
    Confirm this convention holds whenever packet types are added.
    """
    return packet.type % 2 == 0