# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import AsyncIterable
from typing import Optional

import grpc
from grpc.aio._metadata import Metadata
from grpc.aio._typing import MetadataKey
from grpc.aio._typing import MetadataValue
from grpc.aio._typing import MetadatumType
from grpc.experimental import aio

from tests.unit.framework.common import test_constants

ADHOC_METHOD = "/test/AdHoc"


def seen_metadata(expected: Metadata, actual: Metadata) -> bool:
    """Return True when every entry of `expected` also appears in `actual`.

    Both arguments are iterated and their entries collected into sets, so
    ordering and duplicate entries are ignored; `actual` may contain extra
    entries.
    """
    return set(expected) <= set(actual)


def seen_metadatum(
    expected_key: MetadataKey, expected_value: MetadataValue, actual: Metadata
) -> bool:
    """Return True when `actual[expected_key]` equals `expected_value`.

    A key that is missing from `actual` yields False rather than letting the
    KeyError escape, so the function always honours its `bool` contract.
    """
    try:
        obtained = actual[expected_key]
    except KeyError:
        return False
    return obtained == expected_value


async def block_until_certain_state(
    channel: aio.Channel, expected_state: grpc.ChannelConnectivity
) -> None:
    """Wait until `channel.get_state()` reports `expected_state`.

    Returns immediately if the channel is already in the expected state;
    otherwise awaits `wait_for_state_change` on the last observed state and
    re-checks after each change.
    """
    state = channel.get_state()
    while state != expected_state:
        await channel.wait_for_state_change(state)
        state = channel.get_state()


def inject_callbacks(call: aio.Call):
    """Register two done-callbacks on `call` and return a validation coroutine.

    Each callback asserts that the call it receives is done and then sets its
    own event. The returned coroutine, when awaited, waits for both events and
    is bounded by `test_constants.SHORT_TIMEOUT` via `asyncio.wait_for`.
    """
    first_callback_ran = asyncio.Event()

    def first_callback(call):
        # Validate that all responses have been received
        # and the call is in an end state.
        assert call.done()
        first_callback_ran.set()

    second_callback_ran = asyncio.Event()

    def second_callback(call):
        # Validate that all responses have been received
        # and the call is in an end state.
        assert call.done()
        second_callback_ran.set()

    call.add_done_callback(first_callback)
    call.add_done_callback(second_callback)

    async def validation():
        await asyncio.wait_for(
            asyncio.gather(
                first_callback_ran.wait(), second_callback_ran.wait()
            ),
            test_constants.SHORT_TIMEOUT,
        )

    return validation()


class CountingRequestIterator:
    """Async-iterable wrapper that counts the requests it forwards.

    `request_cnt` is incremented once per item just before the item is
    yielded to the consumer.
    """

    def __init__(self, request_iterator: AsyncIterable):
        self.request_cnt = 0
        self._request_iterator = request_iterator

    async def _forward_requests(self):
        """Yield each item of the wrapped iterator, counting as it goes."""
        async for request in self._request_iterator:
            self.request_cnt += 1
            yield request

    def __aiter__(self):
        return self._forward_requests()


class CountingResponseIterator:
    """Async-iterable wrapper that counts the responses it forwards.

    `response_cnt` is incremented once per item just before the item is
    yielded to the consumer.
    """

    def __init__(self, response_iterator: AsyncIterable):
        self.response_cnt = 0
        self._response_iterator = response_iterator

    async def _forward_responses(self):
        """Yield each item of the wrapped iterator, counting as it goes."""
        async for response in self._response_iterator:
            self.response_cnt += 1
            yield response

    def __aiter__(self):
        return self._forward_responses()


class AdhocGenericHandler(grpc.GenericRpcHandler):
    """A generic handler to plugin testing server methods on the fly."""

    # None until `set_adhoc_handler` is called.
    _handler: Optional[grpc.RpcMethodHandler]

    def __init__(self):
        self._handler = None

    def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
        """Install the handler returned for calls to `ADHOC_METHOD`."""
        self._handler = handler

    def service(self, handler_call_details):
        """Return the installed handler for `ADHOC_METHOD`, else None.

        Also returns None for `ADHOC_METHOD` when no handler has been set.
        """
        if handler_call_details.method == ADHOC_METHOD:
            return self._handler
        return None