• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16import argparse
17import asyncio
18import collections
19import datetime
20import enum
21import inspect
22import json
23import os
24import threading
25import time
26from typing import Any, Optional, Union
27
28import grpc
29from google import auth as google_auth
30from google.auth import environment_vars as google_auth_environment_vars
31from google.auth.transport import grpc as google_auth_transport_grpc
32from google.auth.transport import requests as google_auth_transport_requests
33from grpc.experimental import aio
34
35from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc
36
# Metadata keys sent by _custom_metadata, which then requires the server to
# return the first one in the call's initial metadata and the second one in
# its trailing metadata with the same values. The "-bin" suffix marks a
# binary-valued key in gRPC, which is why its test value is a bytes object.
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
39
40
async def _expect_status_code(call: aio.Call,
                              expected_code: grpc.StatusCode) -> None:
    """Raise ValueError unless the call's status code equals expected_code.

    Args:
      call: The RPC whose status code is awaited and checked.
      expected_code: The status code the call must report.

    Raises:
      ValueError: If the reported status code differs from expected_code.
    """
    code = await call.code()
    if code != expected_code:
        # Report the value already fetched rather than awaiting
        # call.code() a second time just to build the message.
        raise ValueError('expected code %s, got %s' % (expected_code, code))
47
48
async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Raise ValueError unless the call's status details equal the expected text.

    Args:
      call: The RPC whose status details are awaited and checked.
      expected_details: The exact details string the call must report.

    Raises:
      ValueError: If the reported details differ from expected_details.
    """
    details = await call.details()
    if details != expected_details:
        # Reuse the fetched value instead of awaiting call.details() again.
        raise ValueError('expected message %s, got %s' %
                         (expected_details, details))
54
55
async def _validate_status_code_and_details(call: aio.Call,
                                            expected_code: grpc.StatusCode,
                                            expected_details: str) -> None:
    """Check both the status code and the status details of a call.

    The code is checked first. If it is wrong, that mismatch is reported
    and the details are not examined.

    Raises:
      ValueError: If the code or the details do not match.
    """
    checks = (
        (_expect_status_code, expected_code),
        (_expect_status_details, expected_details),
    )
    for check, expected in checks:
        await check(call, expected)
61
62
def _validate_payload_type_and_length(
        response: Union[messages_pb2.SimpleResponse,
                        messages_pb2.StreamingOutputCallResponse],
        expected_type: Any, expected_length: int) -> None:
    """Check the payload type and body length of a response message.

    Args:
      response: A response message carrying a ``payload`` field.
      expected_type: The payload type value (e.g. messages_pb2.COMPRESSABLE)
        the response must carry.
      expected_length: Exact number of bytes expected in ``payload.body``.

    Raises:
      ValueError: If the payload type or the body length does not match.
    """
    actual_type = response.payload.type
    # Protobuf enum fields read back as plain ints, so compare by value.
    # Identity (`is`) only works by accident for CPython's cached small ints.
    if actual_type != expected_type:
        # Report the actual value; type(...) would always print <class 'int'>.
        raise ValueError('expected payload type %s, got %s' %
                         (expected_type, actual_type))
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError('expected payload body size %d, got %d' %
                         (expected_length, actual_length))
73
74
async def _large_unary_common_behavior(
    stub: test_pb2_grpc.TestServiceStub, fill_username: bool,
    fill_oauth_scope: bool, call_credentials: Optional[grpc.CallCredentials]
) -> messages_pb2.SimpleResponse:
    """Issue the shared large UnaryCall and validate its payload.

    Sends a 271828-byte zero-filled payload and requests a 314159-byte
    COMPRESSABLE response.

    Args:
      stub: Stub used to issue the call.
      fill_username: Forwarded as SimpleRequest.fill_username.
      fill_oauth_scope: Forwarded as SimpleRequest.fill_oauth_scope.
      call_credentials: Per-call credentials passed to UnaryCall; may be None.

    Returns:
      The validated SimpleResponse, for callers that inspect other fields.

    Raises:
      ValueError: If the response payload type or size is wrong.
    """
    request_body_size = 271828
    response_size = 314159
    outgoing = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=messages_pb2.Payload(body=bytes(request_body_size)),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope,
    )
    reply = await stub.UnaryCall(outgoing, credentials=call_credentials)
    _validate_payload_type_and_length(reply, messages_pb2.COMPRESSABLE,
                                      response_size)
    return reply
89
90
async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Send an Empty request via EmptyCall and require an Empty reply.

    Raises:
      TypeError: If the reply is not an empty_pb2.Empty instance.
    """
    reply = await stub.EmptyCall(empty_pb2.Empty())
    if isinstance(reply, empty_pb2.Empty):
        return
    raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                    type(reply))
96
97
async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run the large unary call without username, scope, or credentials."""
    await _large_unary_common_behavior(stub,
                                       fill_username=False,
                                       fill_oauth_scope=False,
                                       call_credentials=None)
100
101
async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream four zero-filled payloads and check the aggregated size.

    Raises:
      ValueError: If the server's aggregated_payload_size differs from the
        total number of bytes sent.
    """
    body_sizes = (27182, 8, 1828, 45904)

    async def _outgoing_requests():
        for body_size in body_sizes:
            payload = messages_pb2.Payload(body=b'\x00' * body_size)
            yield messages_pb2.StreamingInputCallRequest(payload=payload)

    reply = await stub.StreamingInputCall(_outgoing_requests())
    reported_total = reply.aggregated_payload_size
    if reported_total == sum(body_sizes):
        return
    raise ValueError('incorrect size %d!' % reported_total)
119
120
async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request four sized responses on one server stream and validate them.

    Checks, in order, that:
      * each response has a COMPRESSABLE payload of the requested size,
      * the stream then ends with no extra responses, and
      * the call finishes with status OK.

    Raises:
      ValueError: If a response is missing or has the wrong payload, if the
        server sends more responses than requested, or if the final status
        code is not OK.
      aio.AioRpcError: If a read fails because the call ended with an error.
    """
    sizes = (
        31415,
        9,
        2653,
        58979,
    )

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    call = stub.StreamingOutputCall(request)
    for index, size in enumerate(sizes):
        response = await call.read()
        # A short stream yields aio.EOF here. Fail clearly instead of raising
        # an AttributeError inside the payload validator.
        if response is aio.EOF:
            raise ValueError('expected %d responses, got %d' %
                             (len(sizes), index))
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          size)
    # Exactly len(sizes) responses were requested, so the next read must
    # report end of stream.
    if await call.read() is not aio.EOF:
        raise ValueError('expected %d responses, got more' % len(sizes))
    await _expect_status_code(call, grpc.StatusCode.OK)
142
143
async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Exchange four request/response pairs on a single full-duplex call.

    For each pair, one request asks for a single COMPRESSABLE response of a
    given size. That response is read and validated before the next request
    is written. After the last pair the client half-closes, and the call must
    finish with status OK and empty details.
    """
    response_sizes = (31415, 9, 2653, 58979)
    payload_sizes = (27182, 8, 1828, 45904)

    call = stub.FullDuplexCall()
    for expected_size, body_size in zip(response_sizes, payload_sizes):
        outgoing = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=expected_size),),
            payload=messages_pb2.Payload(body=bytes(body_size)))
        await call.write(outgoing)
        incoming = await call.read()
        _validate_payload_type_and_length(incoming, messages_pb2.COMPRESSABLE,
                                          expected_size)
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, '')
173
174
async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub):
    """Start a client-streaming call, cancel it at once, and check the outcome.

    Raises:
      ValueError: If the call does not report itself as cancelled, or if its
        status code is not CANCELLED.
    """
    streaming_call = stub.StreamingInputCall()
    streaming_call.cancel()
    was_cancelled = streaming_call.cancelled()
    if not was_cancelled:
        raise ValueError('expected cancelled method to return True')
    final_code = await streaming_call.code()
    if final_code is not grpc.StatusCode.CANCELLED:
        raise ValueError('expected status code CANCELLED')
183
184
async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a full-duplex call after its first response and verify it.

    Writes one request and reads one response (its contents are not checked
    here). It then cancels the call and requires the next read to raise
    asyncio.CancelledError and the status code to be CANCELLED.

    Raises:
      ValueError: If the follow-up read does not raise CancelledError, or if
        the status code afterwards is not CANCELLED.
    """
    # Only the first request/response size pair is needed for this case.
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(
            size=response_size),),
        payload=messages_pb2.Payload(body=b'\x00' * payload_size))

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        # Explicit check instead of `assert`: `python -O` strips asserts,
        # which would silently turn this verification into a no-op.
        code = await call.code()
        if code is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED, got %s' % code)
    else:
        raise ValueError('expected call to be cancelled')
220
221
async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub):
    """Require DEADLINE_EXCEEDED when the response is slower than the timeout.

    The call is given a 1 second timeout. Its single request sets the
    response's interval_us to twice that duration.

    Raises:
      ValueError: If the read completes without an error.
      aio.AioRpcError: Re-raised if the call fails with any other code.
    """
    timeout_seconds = datetime.timedelta(seconds=1).total_seconds()
    interval_us = int(timeout_seconds * 2 * 10**6)

    call = stub.FullDuplexCall(timeout=timeout_seconds)
    await call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            payload=messages_pb2.Payload(body=b'\x00' * 27182),
            response_parameters=(messages_pb2.ResponseParameters(
                interval_us=interval_us),)))
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
            raise
        return
    raise ValueError('expected call to exceed deadline')
242
243
async def _empty_stream(stub: test_pb2_grpc.TestServiceStub):
    """Half-close a full-duplex call immediately and expect no responses.

    Raises:
      ValueError: If the first read returns anything other than aio.EOF.
    """
    call = stub.FullDuplexCall()
    await call.done_writing()
    response = await call.read()
    # Explicit check instead of `assert`, which `python -O` strips. aio.EOF
    # is a sentinel object, so compare by identity.
    if response is not aio.EOF:
        raise ValueError('expected end of stream, got %r' % (response,))
248
249
async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub):
    """Check that a status code and message requested via EchoStatus come back.

    Sets ``response_status`` (UNKNOWN plus a fixed message) on a UnaryCall
    and then on a FullDuplexCall. Each call must end with exactly that code
    and those details.

    Raises:
      ValueError: If either call ends with a different code or details.
    """
    details = 'test status message'
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)

    # Test with a FullDuplexCall
    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=status.value[0],
                                                message=details))
    await call.write(request)  # sends the initial request.
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Explicit check instead of `assert`, which `python -O` strips.
        if rpc_error.code() != status:
            raise ValueError('expected code %s, got %s' %
                             (status, rpc_error.code())) from rpc_error
    await _validate_status_code_and_details(call, status, details)
279
280
async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub):
    """Invoke UnimplementedCall on the test service and expect UNIMPLEMENTED."""
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
284
285
async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
    """Invoke UnimplementedCall on an UnimplementedService stub.

    The call must finish with status UNIMPLEMENTED.
    """
    await _expect_status_code(stub.UnimplementedCall(empty_pb2.Empty()),
                              grpc.StatusCode.UNIMPLEMENTED)
289
290
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
    """Check that custom metadata is echoed on a unary and a full-duplex call.

    Sends a string value under _INITIAL_METADATA_KEY and a bytes value under
    _TRAILING_METADATA_KEY. Each call must return those values in its
    initial and trailing metadata respectively.

    Raises:
      ValueError: If either returned value differs from what was sent.
    """
    sent_initial = "test_initial_metadata_value"
    sent_trailing = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    outgoing_metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, sent_initial),
        (_TRAILING_METADATA_KEY, sent_trailing),
    )

    async def _check_echo(rpc):
        echoed_initial = (await rpc.initial_metadata())[_INITIAL_METADATA_KEY]
        if echoed_initial != sent_initial:
            raise ValueError('expected initial metadata %s, got %s' %
                             (sent_initial, echoed_initial))
        echoed_trailing = (
            await rpc.trailing_metadata())[_TRAILING_METADATA_KEY]
        if echoed_trailing != sent_trailing:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (sent_trailing, echoed_trailing))

    # Unary call.
    unary_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    await _check_echo(stub.UnaryCall(unary_request, metadata=outgoing_metadata))

    # Full-duplex call: one request/response exchange, then half-close.
    duplex_call = stub.FullDuplexCall(metadata=outgoing_metadata)
    await duplex_call.write(
        messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),)))
    await duplex_call.read()
    await duplex_call.done_writing()
    await _check_echo(duplex_call)
329
330
async def _compute_engine_creds(stub: test_pb2_grpc.TestServiceStub,
                                args: argparse.Namespace):
    """Check that the reported username equals args.default_service_account.

    Runs the large unary call with fill_username and fill_oauth_scope set.

    Raises:
      ValueError: If response.username differs from
        args.default_service_account.
    """
    reply = await _large_unary_common_behavior(stub,
                                               fill_username=True,
                                               fill_oauth_scope=True,
                                               call_credentials=None)
    expected_account = args.default_service_account
    if expected_account == reply.username:
        return
    raise ValueError('expected username %s, got %s' %
                     (expected_account, reply.username))
337
338
async def _oauth2_auth_token(stub: test_pb2_grpc.TestServiceStub,
                             args: argparse.Namespace):
    """Check the username and OAuth scope returned for a large unary call.

    The expected username is the ``client_email`` field of the JSON key file
    named by the google-auth credentials environment variable.

    Raises:
      KeyError: If that environment variable or the ``client_email`` field
        is missing.
      ValueError: If the username differs, or if the returned OAuth scope is
        not a substring of args.oauth_scope.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # A context manager closes the key file deterministically; the previous
    # json.load(open(...)) left the handle to the garbage collector.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))
351
352
async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub):
    """Check the username returned for a large unary call against the key file.

    The expected username is the ``client_email`` field of the JSON key file
    named by the google-auth credentials environment variable. The call sets
    fill_username but not fill_oauth_scope.

    Raises:
      KeyError: If that environment variable or the ``client_email`` field
        is missing.
      ValueError: If the returned username differs from ``client_email``.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
360
361
async def _per_rpc_creds(stub: test_pb2_grpc.TestServiceStub,
                         args: argparse.Namespace):
    """Check the username for a call carrying per-call credentials.

    Builds call credentials from google.auth default credentials scoped to
    args.oauth_scope and attaches them to the large unary call. The expected
    username is the ``client_email`` field of the JSON key file named by the
    google-auth credentials environment variable.

    Raises:
      KeyError: If that environment variable or the ``client_email`` field
        is missing.
      ValueError: If the returned username differs from ``client_email``.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Close the key file deterministically instead of leaking the handle.
    with open(json_key_filename, 'r', encoding='utf-8') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = await _large_unary_common_behavior(stub, True, False,
                                                  call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
377
378
async def _special_status_message(stub: test_pb2_grpc.TestServiceStub):
    """Round-trip status details containing whitespace and non-ASCII text.

    Requests, via EchoStatus, that a UnaryCall end with status UNKNOWN and
    details containing tabs, CR/LF, a BMP character and a non-BMP character.
    Both the code and the details must come back unchanged.

    Raises:
      ValueError: If the returned code or details differ.
    """
    status = grpc.StatusCode.UNKNOWN  # code = 2
    raw_details = b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP \xf0\x9f\x98\x88\t\n'
    details = raw_details.decode('utf-8')

    echo_status = messages_pb2.EchoStatus(code=status.value[0],
                                          message=details)
    unary_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=echo_status)
    await _validate_status_code_and_details(stub.UnaryCall(unary_request),
                                            status, details)
393
394
@enum.unique
class TestCase(enum.Enum):
    """Interop test cases implemented by this module.

    Each member's value is the case's snake_case name string. ``@enum.unique``
    makes class creation fail if two members share a value. Every member has
    an entry in _TEST_CASE_IMPLEMENTATION_MAPPING.
    """
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    # The credential cases below read configuration from `args` and/or the
    # google-auth credentials environment variable.
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'
415
416
# Dispatch table used by test_interoperability: each TestCase member maps to
# the coroutine function implementing it. Insertion order follows TestCase.
_TEST_CASE_IMPLEMENTATION_MAPPING = dict((
    (TestCase.EMPTY_UNARY, _empty_unary),
    (TestCase.LARGE_UNARY, _large_unary),
    (TestCase.SERVER_STREAMING, _server_streaming),
    (TestCase.CLIENT_STREAMING, _client_streaming),
    (TestCase.PING_PONG, _ping_pong),
    (TestCase.CANCEL_AFTER_BEGIN, _cancel_after_begin),
    (TestCase.CANCEL_AFTER_FIRST_RESPONSE, _cancel_after_first_response),
    (TestCase.TIMEOUT_ON_SLEEPING_SERVER, _timeout_on_sleeping_server),
    (TestCase.EMPTY_STREAM, _empty_stream),
    (TestCase.STATUS_CODE_AND_MESSAGE, _status_code_and_message),
    (TestCase.UNIMPLEMENTED_METHOD, _unimplemented_method),
    (TestCase.UNIMPLEMENTED_SERVICE, _unimplemented_service),
    (TestCase.CUSTOM_METADATA, _custom_metadata),
    (TestCase.COMPUTE_ENGINE_CREDS, _compute_engine_creds),
    (TestCase.OAUTH2_AUTH_TOKEN, _oauth2_auth_token),
    (TestCase.JWT_TOKEN_CREDS, _jwt_token_creds),
    (TestCase.PER_RPC_CREDS, _per_rpc_creds),
    (TestCase.SPECIAL_STATUS_MESSAGE, _special_status_message),
))
437
438
async def test_interoperability(
        case: TestCase,
        stub: test_pb2_grpc.TestServiceStub,
        args: Optional[argparse.Namespace] = None) -> None:
    """Run the implementation registered for ``case``.

    Implementations take either ``(stub)`` or ``(stub, args)``. The arity is
    read from the implementation's signature.

    Raises:
      NotImplementedError: If no implementation is registered for ``case``.
      ValueError: If the implementation needs ``args`` but none was given,
        or if it takes an unsupported number of parameters.
    """
    impl = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if impl is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    arity = len(inspect.signature(impl).parameters)
    if arity == 1:
        await impl(stub)
        return
    if arity != 2:
        raise ValueError(f'Invalid number of parameters [{arity}]')
    if args is None:
        raise ValueError(f'Failed to run case [{case}]: args is None')
    await impl(stub, args)
457