• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import asyncio
from typing import AsyncIterable
from typing import Optional

import grpc
from grpc.aio._metadata import Metadata
from grpc.aio._typing import MetadataKey
from grpc.aio._typing import MetadataValue
from grpc.aio._typing import MetadatumType
from grpc.experimental import aio

from tests.unit.framework.common import test_constants
26
# Full method path that AdhocGenericHandler.service() routes to the handler
# installed via set_adhoc_handler(); any other method path gets None.
ADHOC_METHOD = "/test/AdHoc"
28
29
def seen_metadata(expected: Metadata, actual: Metadata) -> bool:
    """Return True if every item of ``expected`` also occurs in ``actual``.

    Both arguments are iterated and compared as sets of hashable items.
    Presumably those items are ``(key, value)`` pairs when a ``Metadata`` is
    passed; confirm against ``Metadata.__iter__``. Extra items in ``actual``
    are ignored, and repeated items in ``expected`` count only once.
    """
    # set.issubset() accepts any iterable, so only ``expected`` needs an
    # explicit set() conversion; no intermediate tuples or set difference.
    return set(expected).issubset(actual)
32
33
def seen_metadatum(
    expected_key: MetadataKey, expected_value: MetadataValue, actual: Metadata
) -> bool:
    """Report whether ``actual[expected_key]`` compares equal to ``expected_value``.

    Only the single value produced by indexing ``actual`` is checked. For a
    multi-valued key on a ``Metadata`` object this is presumably the first
    value; confirm against ``Metadata.__getitem__``. Lookup errors are not
    caught here: whatever ``actual[expected_key]`` raises (e.g. for a missing
    key) propagates to the caller.
    """
    return actual[expected_key] == expected_value
39
40
async def block_until_certain_state(
    channel: aio.Channel, expected_state: grpc.ChannelConnectivity
):
    """Suspend until ``channel`` reports ``expected_state``.

    Reads ``channel.get_state()``. While the reading differs from the target,
    it awaits ``channel.wait_for_state_change`` with the state just observed,
    then reads again. It returns immediately if the channel is already in the
    target state. No timeout is applied here; callers that need one must
    wrap this coroutine themselves.
    """
    while True:
        current = channel.get_state()
        if current == expected_state:
            return
        await channel.wait_for_state_change(current)
48
49
def inject_callbacks(call: aio.Call):
    """Attach two done-callbacks to ``call`` and return a checker coroutine.

    Each callback asserts that the call it receives reports ``done()`` (all
    responses received, call in a terminal state) and then sets its own
    ``asyncio.Event``. If that assertion fails, the event stays unset.

    The returned object is an un-awaited coroutine. Awaiting it waits for
    both events via ``asyncio.wait_for`` with ``test_constants.SHORT_TIMEOUT``,
    so it fails if either callback did not complete in time.
    """

    def _make_done_callback(ran: asyncio.Event):
        # Factory so both callbacks share one body but flag separate events.
        def _done_callback(finished_call):
            # Validate that all responses have been received
            # and the call is an end state.
            assert finished_call.done()
            ran.set()

        return _done_callback

    ran_events = (asyncio.Event(), asyncio.Event())
    for ran in ran_events:
        call.add_done_callback(_make_done_callback(ran))

    async def validation():
        await asyncio.wait_for(
            asyncio.gather(*(ran.wait() for ran in ran_events)),
            test_constants.SHORT_TIMEOUT,
        )

    return validation()
79
80
class CountingRequestIterator:
    """Async-iterable wrapper that tallies the requests flowing through it.

    An ``async for`` over this object re-yields every item of the wrapped
    ``request_iterator`` unchanged and increments ``request_cnt`` once per
    item. ``__aiter__`` builds a new forwarding generator on each call, and
    the counter is never reset.
    """

    def __init__(self, request_iterator):
        self._request_iterator = request_iterator
        # Running total of items pulled from the wrapped iterable.
        self.request_cnt = 0

    def __aiter__(self):
        return self._forward_requests()

    async def _forward_requests(self):
        # Count before yielding, so the tally already includes an item by
        # the time the consumer receives it.
        async for item in self._request_iterator:
            self.request_cnt += 1
            yield item
93
94
class CountingResponseIterator:
    """Async-iterable wrapper that tallies the responses flowing through it.

    An ``async for`` over this object re-yields every item of the wrapped
    ``response_iterator`` unchanged and increments ``response_cnt`` once per
    item. ``__aiter__`` builds a new forwarding generator on each call, and
    the counter is never reset.
    """

    def __init__(self, response_iterator):
        self._response_iterator = response_iterator
        # Running total of items pulled from the wrapped iterable.
        self.response_cnt = 0

    def __aiter__(self):
        return self._forward_responses()

    async def _forward_responses(self):
        # Count before yielding, so the tally already includes an item by
        # the time the consumer receives it.
        async for item in self._response_iterator:
            self.response_cnt += 1
            yield item
107
108
class AdhocGenericHandler(grpc.GenericRpcHandler):
    """A generic handler to plug in testing server methods on the fly.

    Requests for ``ADHOC_METHOD`` are served by whatever handler was most
    recently installed with ``set_adhoc_handler``. ``service`` returns
    ``None`` for every other method.
    """

    # None until set_adhoc_handler() is called, hence Optional rather than a
    # bare grpc.RpcMethodHandler.
    _handler: Optional[grpc.RpcMethodHandler]

    def __init__(self):
        super().__init__()
        self._handler = None

    def set_adhoc_handler(self, handler: grpc.RpcMethodHandler):
        """Install ``handler`` as the implementation of ``ADHOC_METHOD``."""
        self._handler = handler

    def service(
        self, handler_call_details
    ) -> Optional[grpc.RpcMethodHandler]:
        """Return the ad-hoc handler for ``ADHOC_METHOD``, otherwise ``None``.

        This also returns ``None`` for ``ADHOC_METHOD`` itself when no handler
        has been installed yet.
        """
        if handler_call_details.method != ADHOC_METHOD:
            return None
        return self._handler
125