# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Functions for working with pw_rpc packets."""

from typing import Optional, Type, TypeVar

from google.protobuf import message
from pw_status import Status

from pw_rpc.internal import packet_pb2

# Lets decode_payload report the concrete message class it was given.
_MessageT = TypeVar('_MessageT', bound=message.Message)


def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses serialized bytes into a new RpcPacket.

    Args:
      data: the wire-format encoding of an RpcPacket.

    Returns:
      A freshly constructed RpcPacket populated from data.
    """
    packet = packet_pb2.RpcPacket()
    packet.MergeFromString(data)
    return packet


def decode_payload(packet: packet_pb2.RpcPacket,
                   payload_type: Type[_MessageT]) -> _MessageT:
    """Parses a packet's payload bytes as a new payload_type message.

    Args:
      packet: the packet whose payload field is decoded.
      payload_type: the protobuf message class to instantiate.

    Returns:
      A new payload_type instance populated from packet.payload.
    """
    payload = payload_type()
    payload.MergeFromString(packet.payload)
    return payload


def _ids(rpc: tuple) -> tuple:
    """Converts each item of rpc to an integer ID.

    Items that are already ints pass through unchanged; every other item is
    replaced by its ``id`` attribute.
    """
    return tuple(item if isinstance(item, int) else item.id for item in rpc)


def _encode_packet(packet_type, rpc: tuple, **fields) -> bytes:
    """Builds and serializes an RpcPacket addressed to rpc.

    Args:
      packet_type: the packet_pb2.PacketType value for the packet.
      rpc: a (channel, service, method) tuple; each item is an int ID or an
        object with an ``id`` attribute (see _ids).
      **fields: any additional RpcPacket fields (e.g. payload, status). Only
        the fields given are set on the packet.

    Returns:
      The serialized packet.

    Raises:
      ValueError: if rpc does not contain exactly three items.
    """
    channel, service, method = _ids(rpc)
    return packet_pb2.RpcPacket(type=packet_type,
                                channel_id=channel,
                                service_id=service,
                                method_id=method,
                                **fields).SerializeToString()


def encode_request(rpc: tuple, request: Optional[message.Message]) -> bytes:
    """Encodes a REQUEST packet; a None request yields an empty payload."""
    payload = request.SerializeToString() if request is not None else bytes()
    return _encode_packet(packet_pb2.PacketType.REQUEST, rpc, payload=payload)


def encode_response(rpc: tuple, response: message.Message) -> bytes:
    """Encodes a RESPONSE packet carrying the serialized response."""
    return _encode_packet(packet_pb2.PacketType.RESPONSE,
                          rpc,
                          payload=response.SerializeToString())


def encode_client_stream(rpc: tuple, request: message.Message) -> bytes:
    """Encodes a CLIENT_STREAM packet carrying the serialized request."""
    return _encode_packet(packet_pb2.PacketType.CLIENT_STREAM,
                          rpc,
                          payload=request.SerializeToString())


def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Encodes a CLIENT_ERROR packet that replies to an incoming packet.

    The channel, service, and method IDs are copied from packet; status is
    sent as its integer value.
    """
    return _encode_packet(packet_pb2.PacketType.CLIENT_ERROR,
                          (packet.channel_id, packet.service_id,
                           packet.method_id),
                          status=status.value)


def encode_cancel(rpc: tuple) -> bytes:
    """Encodes a cancellation for rpc.

    Cancellation is expressed as a CLIENT_ERROR packet whose status is
    Status.CANCELLED.
    """
    return _encode_packet(packet_pb2.PacketType.CLIENT_ERROR,
                          rpc,
                          status=Status.CANCELLED.value)


def encode_client_stream_end(rpc: tuple) -> bytes:
    """Encodes a CLIENT_STREAM_END packet (no payload) for rpc."""
    return _encode_packet(packet_pb2.PacketType.CLIENT_STREAM_END, rpc)


def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True if packet's type is an even PacketType value.

    Even-valued packet types are treated as packets destined for the server;
    odd-valued types are treated as destined for the client.
    """
    # NOTE(review): this relies on packet.proto numbering every
    # client-to-server PacketType with an even value -- confirm this
    # invariant holds whenever new packet types are added.
    return packet.type % 2 == 0