• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16import argparse
17import asyncio
18import collections
19import datetime
20import enum
21import inspect
22import json
23import os
24import threading
25import time
26from typing import Any, Optional, Union
27
28import grpc
29from google import auth as google_auth
30from google.auth import environment_vars as google_auth_environment_vars
31from google.auth.transport import grpc as google_auth_transport_grpc
32from google.auth.transport import requests as google_auth_transport_requests
33from grpc.experimental import aio
34
35from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
36
# Metadata key sent by _custom_metadata; its value is expected to come back
# unchanged in the call's initial metadata.
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
# Metadata key sent by _custom_metadata; its value is expected to come back
# unchanged in the call's trailing metadata. Per the gRPC metadata convention
# the "-bin" suffix marks a binary-valued key, and the value sent under it is
# a bytes object.
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
39
40
41async def _expect_status_code(call: aio.Call,
42                              expected_code: grpc.StatusCode) -> None:
43    code = await call.code()
44    if code != expected_code:
45        raise ValueError('expected code %s, got %s' %
46                         (expected_code, await call.code()))
47
48
49async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
50    details = await call.details()
51    if details != expected_details:
52        raise ValueError('expected message %s, got %s' %
53                         (expected_details, await call.details()))
54
55
async def _validate_status_code_and_details(call: aio.Call,
                                            expected_code: grpc.StatusCode,
                                            expected_details: str) -> None:
    """Check the status code of ``call``, then its status details.

    The code is checked first; a mismatch there raises ValueError before the
    details are inspected.
    """
    checks = (
        (_expect_status_code, expected_code),
        (_expect_status_details, expected_details),
    )
    for check, expected in checks:
        await check(call, expected)
61
62
63def _validate_payload_type_and_length(
64        response: Union[messages_pb2.SimpleResponse, messages_pb2.
65                        StreamingOutputCallResponse], expected_type: Any,
66        expected_length: int) -> None:
67    if response.payload.type is not expected_type:
68        raise ValueError('expected payload type %s, got %s' %
69                         (expected_type, type(response.payload.type)))
70    elif len(response.payload.body) != expected_length:
71        raise ValueError('expected payload body size %d, got %d' %
72                         (expected_length, len(response.payload.body)))
73
74
async def _large_unary_common_behavior(
        stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
        fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials]
) -> messages_pb2.SimpleResponse:
    """Send one large unary request and validate the large response.

    The request carries a 271828-byte zero-filled payload and asks for a
    314159-byte COMPRESSABLE response. ``fill_username``/``fill_oauth_scope``
    are forwarded in the request and ``call_credentials`` (may be None) are
    attached to the call.

    Returns:
      The validated SimpleResponse, so callers can inspect its other fields.
    """
    response_size = 314159
    request_body = b'\x00' * 271828
    response = await stub.UnaryCall(
        messages_pb2.SimpleRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_size=response_size,
            payload=messages_pb2.Payload(body=request_body),
            fill_username=fill_username,
            fill_oauth_scope=fill_oauth_scope,
        ),
        credentials=call_credentials,
    )
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                      response_size)
    return response
89
90
async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Call EmptyCall and require an ``empty_pb2.Empty`` response.

    Raises:
      TypeError: If the response is of any other type.
    """
    request = empty_pb2.Empty()
    response = await stub.EmptyCall(request)
    if isinstance(response, empty_pb2.Empty):
        return
    raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                    type(response))
96
97
async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run the large_unary case without username, scope or call credentials."""
    await _large_unary_common_behavior(stub,
                                       fill_username=False,
                                       fill_oauth_scope=False,
                                       call_credentials=None)
100
101
async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream four zero-filled payloads and check the aggregated size.

    Raises:
      ValueError: If the server's ``aggregated_payload_size`` differs from
        the sum of the payload sizes sent.
    """
    sizes = (27182, 8, 1828, 45904)
    expected_total = sum(sizes)

    async def _requests():
        for body_size in sizes:
            payload = messages_pb2.Payload(body=b'\x00' * body_size)
            yield messages_pb2.StreamingInputCallRequest(payload=payload)

    response = await stub.StreamingInputCall(_requests())
    aggregated = response.aggregated_payload_size
    if aggregated != expected_total:
        raise ValueError('incorrect size %d!' % aggregated)
119
120
async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request four streamed responses and validate the whole stream.

    Each response's payload type and size is checked in order. The stream
    must then end with no extra messages, and the call must finish with
    status OK.

    Raises:
      ValueError: On a payload mismatch, an extra response, or a non-OK
        final status.
    """
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    call = stub.StreamingOutputCall(request)
    for size in sizes:
        response = await call.read()
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          size)
    # The server must send exactly one response per requested size.
    if await call.read() is not aio.EOF:
        raise ValueError('expected exactly %d responses, got more' %
                         len(sizes))
    await _expect_status_code(call, grpc.StatusCode.OK)
142
143
async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Exchange four request/response pairs over one full-duplex stream.

    Each request asks for one response of a given size and carries a
    zero-filled payload. Each response is read and validated before the next
    request is written. The call must finish OK with empty details.
    """
    exchanges = (
        # (requested response size, request payload size)
        (31415, 27182),
        (9, 8),
        (2653, 1828),
        (58979, 45904),
    )

    call = stub.FullDuplexCall()
    for response_size, payload_size in exchanges:
        await call.write(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(
                    messages_pb2.ResponseParameters(size=response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * payload_size)))
        _validate_payload_type_and_length(await call.read(),
                                          messages_pb2.COMPRESSABLE,
                                          response_size)
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')
173
174
async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a client-streaming call before sending anything.

    Raises:
      ValueError: If the call does not report itself cancelled or its status
        code is not CANCELLED.
    """
    streaming_call = stub.StreamingInputCall()
    streaming_call.cancel()
    if not streaming_call.cancelled():
        raise ValueError('expected cancelled method to return True')
    if await streaming_call.code() is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED')
183
184
async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a full-duplex call after one round trip; expect CANCELLED.

    Raises:
      ValueError: If reading after cancellation does not raise
        ``asyncio.CancelledError``, or if the final status code is not
        CANCELLED.
    """
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=response_size),),
        payload=messages_pb2.Payload(body=b'\x00' * payload_size))

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and the check would silently disappear.
        code = await call.code()
        if code is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED, got %s' % code)
    else:
        raise ValueError('expected call to be cancelled')
220
221
async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub):
    """Request a response slower than the deadline; expect DEADLINE_EXCEEDED.

    The call gets a 1 second timeout. Its single response is requested with
    an ``interval_us`` of twice that, so reading is expected to fail.

    Raises:
      aio.AioRpcError: If the read fails with any other status code.
      ValueError: If the read succeeds.
    """
    deadline = datetime.timedelta(seconds=1)
    deadline_seconds = deadline.total_seconds()
    call = stub.FullDuplexCall(timeout=deadline_seconds)

    interval_us = int(deadline_seconds * 2 * 10**6)
    await call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b'\x00' * 27182),
            response_parameters=(
                messages_pb2.ResponseParameters(interval_us=interval_us),)))
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        if rpc_error.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
            return
        raise
    raise ValueError('expected call to exceed deadline')
242
243
async def _empty_stream(stub: test_pb2_grpc.TestServiceStub):
    """Half-close a full-duplex call without sending; expect no responses.

    Raises:
      ValueError: If the server sends any message before end-of-stream.
    """
    call = stub.FullDuplexCall()
    await call.done_writing()
    response = await call.read()
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if response is not aio.EOF:
        raise ValueError('expected empty stream, got %s' % response)
248
249
async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub):
    """Ask the server to echo a status; verify it on unary and duplex calls.

    Both calls carry an ``EchoStatus`` with code UNKNOWN and a fixed message,
    and each call must finish with exactly that code and details.
    """
    details = 'test status message'
    status = grpc.StatusCode.UNKNOWN  # code = 2

    def _echo_status():
        return messages_pb2.EchoStatus(code=status.value[0], message=details)

    # Unary variant.
    unary_call = stub.UnaryCall(
        messages_pb2.SimpleRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_size=1,
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=_echo_status()))
    await _validate_status_code_and_details(unary_call, status, details)

    # Full-duplex variant: send the single request, then half-close.
    duplex_call = stub.FullDuplexCall()
    await duplex_call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=_echo_status()))
    await duplex_call.done_writing()
    await _validate_status_code_and_details(duplex_call, status, details)
275
276
async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub):
    """Invoke UnimplementedCall on TestService; expect UNIMPLEMENTED."""
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
280
281
async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
    """Invoke UnimplementedCall on UnimplementedService; expect UNIMPLEMENTED."""
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
285
286
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
    """Send metadata under the echo keys and check it comes back.

    The value under ``_INITIAL_METADATA_KEY`` must reappear in the call's
    initial metadata. The value under ``_TRAILING_METADATA_KEY`` must
    reappear in its trailing metadata. Both a unary and a full-duplex call
    are checked.

    Raises:
      ValueError: If either echoed value differs from what was sent.
    """
    initial_value = "test_initial_metadata_value"
    trailing_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, initial_value),
        (_TRAILING_METADATA_KEY, trailing_value),
    )

    async def _validate_metadata(call):
        received_initial = (
            await call.initial_metadata())[_INITIAL_METADATA_KEY]
        if received_initial != initial_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_value, received_initial))

        received_trailing = (
            await call.trailing_metadata())[_TRAILING_METADATA_KEY]
        if received_trailing != trailing_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_value, received_trailing))

    # Unary variant.
    unary_call = stub.UnaryCall(
        messages_pb2.SimpleRequest(response_type=messages_pb2.COMPRESSABLE,
                                   response_size=1,
                                   payload=messages_pb2.Payload(body=b'\x00')),
        metadata=metadata)
    await _validate_metadata(unary_call)

    # Full-duplex variant: one round trip, then half-close.
    duplex_call = stub.FullDuplexCall(metadata=metadata)
    await duplex_call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),)))
    await duplex_call.read()
    await duplex_call.done_writing()
    await _validate_metadata(duplex_call)
325
326
async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub,
                                args: argparse.Namespace):
    """Check the server-reported username equals the default service account.

    Raises:
      ValueError: If ``response.username`` differs from
        ``args.default_service_account``.
    """
    response = await _large_unary_common_behavior(stub,
                                                  fill_username=True,
                                                  fill_oauth_scope=True,
                                                  call_credentials=None)
    expected_account = args.default_service_account
    if response.username != expected_account:
        raise ValueError('expected username %s, got %s' %
                         (expected_account, response.username))
333
334
async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub,
                             args: argparse.Namespace):
    """Check the identity and OAuth scope reported by the server.

    The expected username is the ``client_email`` field of the JSON key file
    named by the google-auth credentials environment variable. The reported
    OAuth scope must occur within ``args.oauth_scope``.

    Raises:
      ValueError: If the username or the OAuth scope does not match.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed deterministically.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))
347
348
async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub):
    """Check the server-reported username matches the JSON key's client_email.

    The JSON key file is the one named by the google-auth credentials
    environment variable.

    Raises:
      ValueError: If ``response.username`` differs from ``client_email``.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed deterministically.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
356
357
async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
                         args: argparse.Namespace):
    """Authenticate one RPC with per-call credentials and check the identity.

    Call credentials are built from ``google_auth.default`` scoped to
    ``args.oauth_scope`` and attached to this call only. The server must
    report the ``client_email`` of the JSON key file named by the google-auth
    credentials environment variable.

    Raises:
      ValueError: If ``response.username`` differs from ``client_email``.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed deterministically.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = await _large_unary_common_behavior(stub, True, False,
                                                  call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
373
374
async def _special_status_message(stub: test_pb2_grpc.TestServiceStub):
    """Echo a status message containing whitespace and non-ASCII characters.

    The details contain tabs, CR/LF, a BMP character (U+263A) and a non-BMP
    character (U+1F608). The unary call must finish with code UNKNOWN and
    exactly those details.
    """
    details = (b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba '
               b'and non-BMP \xf0\x9f\x98\x88\t\n').decode('utf-8')
    status = grpc.StatusCode.UNKNOWN  # code = 2

    echo_status = messages_pb2.EchoStatus(code=status.value[0],
                                          message=details)
    unary_call = stub.UnaryCall(
        messages_pb2.SimpleRequest(response_type=messages_pb2.COMPRESSABLE,
                                   response_size=1,
                                   payload=messages_pb2.Payload(body=b'\x00'),
                                   response_status=echo_status))
    await _validate_status_code_and_details(unary_call, status, details)
389
390
@enum.unique
class TestCase(enum.Enum):
    """Interop test cases that ``test_interoperability`` can run.

    Each member's value is the lowercase name of the case. ``@enum.unique``
    rejects duplicate values when the class is created.
    """
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    # Cases below read credentials from args and/or the environment.
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'
411
412
# Dispatch table from each TestCase to the coroutine function implementing it.
# Implementations take either (stub) or (stub, args); test_interoperability
# inspects each function's signature to decide which arguments to pass.
_TEST_CASE_IMPLEMENTATION_MAPPING = {
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}
433
434
async def test_interoperability(case: TestCase,
                                stub: test_pb2_grpc.TestServiceStub,
                                args: Optional[argparse.Namespace] = None
                               ) -> None:
    """Run the interop method registered for ``case``.

    Methods with one parameter receive only ``stub``. Methods with two
    parameters also receive ``args``, which must then be provided.

    Raises:
      NotImplementedError: If no method is registered for ``case``.
      ValueError: If ``args`` is needed but None, or the registered method
        has neither one nor two parameters.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    num_params = len(inspect.signature(method).parameters)
    if num_params == 1:
        await method(stub)
        return
    if num_params != 2:
        raise ValueError(f'Invalid number of parameters [{num_params}]')
    if args is None:
        raise ValueError(f'Failed to run case [{case}]: args is None')
    await method(stub, args)
453