• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2022 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Unit test for proxy.py"""
16
17import abc
18import asyncio
19from struct import pack
20import time
21import unittest
22
23from pigweed.pw_rpc.internal import packet_pb2
24from pigweed.pw_transfer import transfer_pb2
25from pw_hdlc import encode
26from pw_transfer.chunk import Chunk, ProtocolVersion
27
28import proxy
29
30
class MockRng(abc.ABC):
    """Deterministic replacement for a random number generator in tests.

    Only ``uniform`` is implemented. Canned values are consumed from the END
    of the list given to the constructor, so ``MockRng([a, b])`` yields the
    value derived from ``b`` first and then the one derived from ``a``.
    """

    def __init__(self, results: list[float]):
        # Kept by reference: uniform() pops from this very list.
        self._results = results

    def uniform(self, from_val: float, to_val: float) -> float:
        """Return the next canned value mapped linearly onto the range.

        A stored value of 0.0 maps to ``from_val`` and 1.0 maps to ``to_val``.
        Raises IndexError once all canned values have been consumed.
        """
        width = to_val - from_val
        return self._results.pop() * width + from_val
41
42
class ProxyTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the packet filters implemented in ``proxy``."""

    @staticmethod
    def _make_sender(sink: list[bytes]):
        """Return an async send callback that appends each packet to ``sink``.

        The filters under test await their send callback, so a coroutine
        function is used rather than ``sink.append`` directly.
        """

        async def send(data: bytes) -> None:
            sink.append(data)

        return send

    async def test_transposer_simple(self):
        """DataTransposer swaps two packets for the given RNG draws."""
        sent_packets: list[bytes] = []
        new_packets_event = asyncio.Event()

        # Send callback that also signals that a new packet was "sent".
        async def send(data: bytes) -> None:
            sent_packets.append(data)
            new_packets_event.set()

        transposer = proxy.DataTransposer(
            send,
            name="test",
            rate=0.5,
            timeout=100,
            seed=1234567890,
        )
        # MockRng pops from the end of its list, so 0.4 is drawn first.
        transposer._rng = MockRng([0.6, 0.4])
        await transposer.process(b'aaaaaaaaaa')
        await transposer.process(b'bbbbbbbbbb')

        expected_packets = [b'bbbbbbbbbb', b'aaaaaaaaaa']
        while len(sent_packets) < len(expected_packets):
            # Wait for new packets with a generous timeout. Catch
            # asyncio.TimeoutError: before Python 3.11 it is not the builtin
            # TimeoutError, so `except TimeoutError` would not catch it.
            try:
                await asyncio.wait_for(new_packets_event.wait(), timeout=60.0)
            except asyncio.TimeoutError:
                self.fail(
                    f'Timeout waiting for data.  Packets sent: {sent_packets}'
                )
            # Re-arm the event. A set event makes wait() return immediately,
            # which would turn this loop into a busy spin.
            new_packets_event.clear()

        # Compare as soon as at least the expected number of packets has been
        # sent, so extra packets fail with a diff instead of waiting for the
        # timeout.
        self.assertEqual(sent_packets, expected_packets)

    async def test_transposer_timeout(self):
        """A held-back packet is sent in order once the transposer times out."""
        sent_packets: list[bytes] = []

        transposer = proxy.DataTransposer(
            self._make_sender(sent_packets),
            name="test",
            rate=0.5,
            timeout=0.100,
            seed=1234567890,
        )
        transposer._rng = MockRng([0.4, 0.6])
        await transposer.process(b'aaaaaaaaaa')

        # Even though this should be transposed, there is no following data so
        # the transposer should time out and send this in order.
        await transposer.process(b'bbbbbbbbbb')

        # Give the transposer time to time out.
        await asyncio.sleep(0.5)

        self.assertEqual(sent_packets, [b'aaaaaaaaaa', b'bbbbbbbbbb'])

    async def test_server_failure(self):
        """ServerFailure passes N packets per round, advancing on TRANSFER_START."""
        sent_packets: list[bytes] = []

        packets_before_failure = [1, 2, 3]
        server_failure = proxy.ServerFailure(
            self._make_sender(sent_packets),
            name="test",
            packets_before_failure_list=packets_before_failure.copy(),
            start_immediately=True,
        )

        # ServerFailure got its own copy of the list above. Appending 5 here
        # only extends the expectations: a final round in which all five
        # packets must pass through (no packets dropped).
        packets_before_failure.append(5)

        packets = [b'1', b'2', b'3', b'4', b'5']

        for num_packets in packets_before_failure:
            sent_packets.clear()
            for packet in packets:
                await server_failure.process(packet)
            self.assertEqual(len(sent_packets), num_packets)
            server_failure.handle_event(
                proxy.Event(
                    proxy.EventType.TRANSFER_START,
                    Chunk(ProtocolVersion.VERSION_TWO, Chunk.Type.START),
                )
            )

    async def test_server_failure_transfer_chunks_only(self):
        """Only transfer chunks are counted toward, and dropped by, a failure."""
        sent_packets: list[bytes] = []

        packets_before_failure = [2]
        server_failure = proxy.ServerFailure(
            self._make_sender(sent_packets),
            name="test",
            packets_before_failure_list=packets_before_failure.copy(),
            start_immediately=True,
            only_consider_transfer_chunks=True,
        )

        transfer_chunk = _encode_rpc_frame(
            Chunk(ProtocolVersion.VERSION_TWO, Chunk.Type.DATA, data=b'1')
        )

        packets = [
            b'1',
            b'2',
            transfer_chunk,  # 1
            b'3',
            transfer_chunk,  # 2
            b'4',
            b'5',
            transfer_chunk,  # Transfer chunks should be dropped starting here.
            transfer_chunk,
            b'6',
            b'7',
            transfer_chunk,
        ]

        for packet in packets:
            await server_failure.process(packet)

        expected_result = [
            b'1',
            b'2',
            transfer_chunk,
            b'3',
            transfer_chunk,
            b'4',
            b'5',
            b'6',
            b'7',
        ]
        self.assertEqual(sent_packets, expected_result)

    async def test_keep_drop_queue_loop(self):
        """KeepDropQueue keeps applying its pattern past the end of the queue."""
        sent_packets: list[bytes] = []

        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 3],
        )

        expected_sequence = [b'1', b'2', b'4', b'5', b'6', b'9']
        input_packets = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9']

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_keep_drop_queue(self):
        """A trailing -1 keeps the final (drop) action in effect indefinitely."""
        sent_packets: list[bytes] = []

        # Keep 2, drop 1, keep 1, then drop everything that follows (-1).
        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
        )

        expected_sequence = [b'1', b'2', b'4']
        input_packets = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9']

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_keep_drop_queue_transfer_chunks_only(self):
        """Only transfer chunks are counted and dropped; others pass through."""
        sent_packets: list[bytes] = []

        keep_drop_queue = proxy.KeepDropQueue(
            self._make_sender(sent_packets),
            name="test",
            keep_drop_queue=[2, 1, 1, -1],
            only_consider_transfer_chunks=True,
        )

        transfer_chunk = _encode_rpc_frame(
            Chunk(ProtocolVersion.VERSION_TWO, Chunk.Type.DATA, data=b'1')
        )

        expected_sequence = [
            b'1',
            transfer_chunk,
            b'2',
            transfer_chunk,
            b'3',
            b'4',
            b'5',
            b'6',
            b'7',
            transfer_chunk,
            b'8',
            b'9',
            b'10',
        ]
        input_packets = [
            b'1',
            transfer_chunk,  # keep
            b'2',
            transfer_chunk,  # keep
            b'3',
            b'4',
            b'5',
            transfer_chunk,  # drop
            b'6',
            b'7',
            transfer_chunk,  # keep
            transfer_chunk,  # drop
            b'8',
            transfer_chunk,  # drop
            b'9',
            transfer_chunk,  # drop
            transfer_chunk,  # drop
            b'10',
        ]

        for packet in input_packets:
            await keep_drop_queue.process(packet)
        self.assertEqual(sent_packets, expected_sequence)

    async def test_window_packet_dropper(self):
        """The configured packet index is dropped from every window."""
        sent_packets: list[bytes] = []

        window_packet_dropper = proxy.WindowPacketDropper(
            self._make_sender(sent_packets),
            name="test",
            window_packet_to_drop=0,
        )

        packets = [
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    data=data,
                    session_id=1,
                )
            )
            for data in (b'1', b'2', b'3', b'4', b'5')
        ]

        # Packet 0 of each window is dropped.
        expected_packets = packets[1:]

        # Test each event type twice to ensure the filter does not have issues
        # on new window boundaries.
        events = [
            proxy.Event(
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.PARAMETERS_RETRANSMIT,
                ),
            ),
            proxy.Event(
                proxy.EventType.PARAMETERS_CONTINUE,
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.PARAMETERS_CONTINUE
                ),
            ),
            proxy.Event(
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.PARAMETERS_RETRANSMIT,
                ),
            ),
            proxy.Event(
                proxy.EventType.PARAMETERS_CONTINUE,
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.PARAMETERS_CONTINUE
                ),
            ),
        ]

        for event in events:
            sent_packets.clear()
            for packet in packets:
                await window_packet_dropper.process(packet)
            self.assertEqual(sent_packets, expected_packets)
            window_packet_dropper.handle_event(event)

    async def test_window_packet_dropper_extra_in_flight_packets(self):
        """A retransmit arriving while packets are in flight starts a new window."""
        sent_packets: list[bytes] = []

        window_packet_dropper = proxy.WindowPacketDropper(
            self._make_sender(sent_packets),
            name="test",
            window_packet_to_drop=1,
        )

        packets = [
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    data=data,
                    offset=offset,
                )
            )
            for data, offset in (
                (b'1', 0),
                (b'2', 1),
                (b'3', 2),
                # Offsets 1-2 are sent again after the retransmit below.
                (b'2', 1),
                (b'3', 2),
                (b'4', 3),
            )
        ]

        # events[i] is handled right after packets[i] is processed: a
        # retransmit from offset 1 arrives after the second packet.
        events = [
            None,
            proxy.Event(
                proxy.EventType.PARAMETERS_RETRANSMIT,
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.PARAMETERS_RETRANSMIT,
                    offset=1,
                ),
            ),
            None,
            None,
            None,
            None,
        ]

        for packet, event in zip(packets, events):
            await window_packet_dropper.process(packet)
            if event is not None:
                window_packet_dropper.handle_event(event)

        expected_packets = [packets[0], packets[2], packets[3], packets[5]]
        self.assertEqual(sent_packets, expected_packets)

    async def test_event_filter(self):
        """EventFilter queues a TRANSFER_START event only for start chunks."""
        sent_packets: list[bytes] = []

        queue = asyncio.Queue()

        event_filter = proxy.EventFilter(
            self._make_sender(sent_packets),
            name="test",
            event_queue=queue,
        )

        # A serialized RPC request that is not HDLC-framed.
        request = packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.REQUEST,
            channel_id=101,
            service_id=1001,
            method_id=100001,
        ).SerializeToString()

        packets = [
            request,
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.START, session_id=1
                )
            ),
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    session_id=1,
                    data=b'3',
                )
            ),
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    session_id=1,
                    data=b'3',
                )
            ),
            request,
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO, Chunk.Type.START, session_id=2
                )
            ),
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    session_id=2,
                    data=b'4',
                )
            ),
            _encode_rpc_frame(
                Chunk(
                    ProtocolVersion.VERSION_TWO,
                    Chunk.Type.DATA,
                    session_id=2,
                    data=b'5',
                )
            ),
        ]

        expected_events = [
            None,  # request
            proxy.EventType.TRANSFER_START,
            None,  # data chunk
            None,  # data chunk
            None,  # request
            proxy.EventType.TRANSFER_START,
            None,  # data chunk
            None,  # data chunk
        ]

        for packet, expected_event_type in zip(packets, expected_events):
            await event_filter.process(packet)
            # At most one event is expected per packet; an empty queue means
            # the packet produced no event.
            try:
                event_type = queue.get_nowait().type
            except asyncio.QueueEmpty:
                event_type = None
            self.assertEqual(event_type, expected_event_type)
601
602
def _encode_rpc_frame(chunk: Chunk) -> bytes:
    """Wrap ``chunk`` in a SERVER_STREAM RPC packet and HDLC-encode it.

    The packet uses fixed test IDs (channel 101, service 1001, method 100001)
    and carries the serialized transfer chunk message as its payload.
    """
    chunk_bytes = chunk.to_message().SerializeToString()
    rpc_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.SERVER_STREAM,
        channel_id=101,
        service_id=1001,
        method_id=100001,
        payload=chunk_bytes,
    )
    # NOTE(review): 73 is passed as the first argument of pw_hdlc's
    # ui_frame, presumably the HDLC address; confirm against pw_hdlc.encode.
    return encode.ui_frame(73, rpc_packet.SerializeToString())
612
613
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main()
616