# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import grpc
from typing import AsyncIterable
from grpc.experimental import aio
from grpc.aio._typing import MetadatumType, MetadataKey, MetadataValue
from grpc.aio._metadata import Metadata

from tests.unit.framework.common import test_constants


def seen_metadata(expected: Metadata, actual: Metadata) -> bool:
    """Return True when every entry of ``expected`` also appears in ``actual``.

    Both arguments are iterated and their entries compared as set members, so
    ordering and duplicates are ignored and ``actual`` may contain extra
    entries.
    """
    return set(expected) <= set(actual)


def seen_metadatum(expected_key: MetadataKey, expected_value: MetadataValue,
                   actual: Metadata) -> bool:
    """Return True when ``actual[expected_key]`` equals ``expected_value``.

    A key that is absent from ``actual`` yields False rather than letting the
    ``KeyError`` from the lookup escape, so callers always get the boolean the
    signature promises.
    """
    try:
        obtained = actual[expected_key]
    except KeyError:
        return False
    return obtained == expected_value


async def block_until_certain_state(
        channel: aio.Channel,
        expected_state: grpc.ChannelConnectivity) -> None:
    """Wait until ``channel.get_state()`` reports ``expected_state``.

    While the current state differs from the expected one, this awaits
    ``channel.wait_for_state_change`` on the current state and then re-reads
    the state. No timeout is applied here.
    """
    state = channel.get_state()
    while state != expected_state:
        await channel.wait_for_state_change(state)
        state = channel.get_state()


def inject_callbacks(call: aio.Call):
    """Register two done-callbacks on ``call`` and return a validation coroutine.

    Each callback asserts that ``call.done()`` is true and then sets its own
    ``asyncio.Event``. The returned (not yet awaited) coroutine waits for both
    events and raises if they are not both set within
    ``test_constants.SHORT_TIMEOUT``.
    """
    first_callback_ran = asyncio.Event()

    def first_callback(call):
        # Validate that all responses have been received
        # and the call is an end state.
        assert call.done()
        first_callback_ran.set()

    second_callback_ran = asyncio.Event()

    def second_callback(call):
        # Validate that all responses have been received
        # and the call is an end state.
        assert call.done()
        second_callback_ran.set()

    call.add_done_callback(first_callback)
    call.add_done_callback(second_callback)

    async def validation():
        await asyncio.wait_for(
            asyncio.gather(first_callback_ran.wait(),
                           second_callback_ran.wait()),
            test_constants.SHORT_TIMEOUT)

    return validation()


class CountingRequestIterator:
    """Async-iterable wrapper that counts the requests it forwards.

    ``request_cnt`` is incremented once for every item pulled from the wrapped
    iterator, just before that item is yielded.
    """

    def __init__(self, request_iterator: AsyncIterable):
        # Number of requests forwarded so far.
        self.request_cnt = 0
        self._request_iterator = request_iterator

    async def _forward_requests(self):
        """Yield each wrapped request, bumping ``request_cnt`` first."""
        async for request in self._request_iterator:
            self.request_cnt += 1
            yield request

    def __aiter__(self):
        return self._forward_requests()


class CountingResponseIterator:
    """Async-iterable wrapper that counts the responses it forwards.

    ``response_cnt`` is incremented once for every item pulled from the
    wrapped iterator, just before that item is yielded.
    """

    def __init__(self, response_iterator: AsyncIterable):
        # Number of responses forwarded so far.
        self.response_cnt = 0
        self._response_iterator = response_iterator

    async def _forward_responses(self):
        """Yield each wrapped response, bumping ``response_cnt`` first."""
        async for response in self._response_iterator:
            self.response_cnt += 1
            yield response

    def __aiter__(self):
        return self._forward_responses()