• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16import argparse
17import asyncio
18import collections
19import datetime
20import enum
21import inspect
22import json
23import os
24import threading
25import time
26from typing import Any, Optional, Union
27
28from google import auth as google_auth
29from google.auth import environment_vars as google_auth_environment_vars
30from google.auth.transport import grpc as google_auth_transport_grpc
31from google.auth.transport import requests as google_auth_transport_requests
32import grpc
33from grpc.experimental import aio
34
35from src.proto.grpc.testing import empty_pb2
36from src.proto.grpc.testing import messages_pb2
37from src.proto.grpc.testing import test_pb2_grpc
38
# Request metadata key whose value _custom_metadata expects to receive back
# in the call's initial metadata.
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
# Request metadata key whose value _custom_metadata expects to receive back
# in the call's trailing metadata. gRPC treats keys ending in "-bin" as
# binary-valued, which is why _custom_metadata sends bytes under this key.
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
41
42
async def _expect_status_code(
    call: aio.Call, expected_code: grpc.StatusCode
) -> None:
    """Raise if the status code of ``call`` differs from ``expected_code``.

    Args:
      call: The RPC whose status code is awaited and inspected.
      expected_code: The status code the RPC is required to end with.

    Raises:
      ValueError: If the observed status code is not ``expected_code``.
    """
    code = await call.code()
    if code != expected_code:
        # Report the value already fetched rather than awaiting the call again.
        raise ValueError("expected code %s, got %s" % (expected_code, code))
51
52
async def _expect_status_details(call: aio.Call, expected_details: str) -> None:
    """Raise if the status details of ``call`` differ from ``expected_details``.

    Args:
      call: The RPC whose status details string is awaited and inspected.
      expected_details: The exact details string the RPC must report.

    Raises:
      ValueError: If the observed details differ from ``expected_details``.
    """
    details = await call.details()
    if details != expected_details:
        # Report the value already fetched rather than awaiting the call again.
        raise ValueError(
            "expected message %s, got %s" % (expected_details, details)
        )
60
61
async def _validate_status_code_and_details(
    call: aio.Call, expected_code: grpc.StatusCode, expected_details: str
) -> None:
    """Check both the final status code and the details string of ``call``.

    The code is checked first. If it mismatches, the details are not examined.

    Raises:
      ValueError: On the first mismatch found.
    """
    checks = (
        (_expect_status_code, expected_code),
        (_expect_status_details, expected_details),
    )
    for check, expected in checks:
        await check(call, expected)
67
68
def _validate_payload_type_and_length(
    response: Union[
        messages_pb2.SimpleResponse, messages_pb2.StreamingOutputCallResponse
    ],
    expected_type: Any,
    expected_length: int,
) -> None:
    """Check the payload type and body length of a test-service response.

    Args:
      response: A response message carrying a ``payload`` field.
      expected_type: The payload type value the response must carry.
      expected_length: The exact number of bytes expected in the payload body.

    Raises:
      ValueError: If the payload type or the body length does not match.
    """
    actual_type = response.payload.type
    # Compare by value: the payload type is a plain int, so an identity check
    # only works by accident of small-int caching.
    if actual_type != expected_type:
        raise ValueError(
            "expected payload type %s, got %s" % (expected_type, actual_type)
        )
    actual_length = len(response.payload.body)
    if actual_length != expected_length:
        raise ValueError(
            "expected payload body size %d, got %d"
            % (expected_length, actual_length)
        )
86
87
async def _large_unary_common_behavior(
    stub: test_pb2_grpc.TestServiceStub,
    fill_username: bool,
    fill_oauth_scope: bool,
    call_credentials: Optional[grpc.CallCredentials],
) -> messages_pb2.SimpleResponse:
    """Send the standard large_unary request and validate the reply payload.

    Sends a 271828-byte zero payload, asks for a 314159-byte COMPRESSABLE
    response, and checks the payload type and length of what comes back.

    Args:
      stub: Client stub for the test service.
      fill_username: Forwarded as ``SimpleRequest.fill_username``.
      fill_oauth_scope: Forwarded as ``SimpleRequest.fill_oauth_scope``.
      call_credentials: Optional per-call credentials for the UnaryCall.

    Returns:
      The validated response, so callers can inspect further fields.
    """
    response_size = 314159
    request_body = b"\x00" * 271828
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=messages_pb2.Payload(body=request_body),
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope,
    )
    reply = await stub.UnaryCall(request, credentials=call_credentials)
    _validate_payload_type_and_length(
        reply, messages_pb2.COMPRESSABLE, response_size
    )
    return reply
105
106
async def _empty_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Call EmptyCall with an Empty message and require an Empty reply.

    Raises:
      TypeError: If the reply is not an ``empty_pb2.Empty`` instance.
    """
    reply = await stub.EmptyCall(empty_pb2.Empty())
    if isinstance(reply, empty_pb2.Empty):
        return
    raise TypeError(
        'response is of type "%s", not empty_pb2.Empty!' % type(reply)
    )
113
114
async def _large_unary(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Run large_unary with no username/scope echo and no call credentials."""
    await _large_unary_common_behavior(
        stub,
        fill_username=False,
        fill_oauth_scope=False,
        call_credentials=None,
    )
117
118
async def _client_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Stream four zero-filled payloads and check the server's byte total.

    Raises:
      ValueError: If ``aggregated_payload_size`` differs from the number of
        bytes sent.
    """
    chunk_sizes = (27182, 8, 1828, 45904)

    async def _requests():
        for chunk_size in chunk_sizes:
            body = b"\x00" * chunk_size
            yield messages_pb2.StreamingInputCallRequest(
                payload=messages_pb2.Payload(body=body)
            )

    expected_total = sum(chunk_sizes)
    response = await stub.StreamingInputCall(_requests())
    if response.aggregated_payload_size == expected_total:
        return
    raise ValueError("incorrect size %d!" % response.aggregated_payload_size)
138
139
async def _server_streaming(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Request four COMPRESSABLE responses and validate each one in order.

    Raises:
      ValueError: If any response has the wrong payload type or size.
    """
    expected_sizes = (31415, 9, 2653, 58979)
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=expected_size)
            for expected_size in expected_sizes
        ),
    )
    call = stub.StreamingOutputCall(request)
    for expected_size in expected_sizes:
        _validate_payload_type_and_length(
            await call.read(), messages_pb2.COMPRESSABLE, expected_size
        )
163
164
async def _ping_pong(stub: test_pb2_grpc.TestServiceStub) -> None:
    """Exchange four request/response pairs on one full-duplex stream.

    Each round writes one request and reads one response. After the last
    round the client half-closes and requires an OK status with empty details.

    Raises:
      ValueError: On a payload mismatch or an unexpected final status.
    """
    # (requested response size, request payload size) for each round trip.
    rounds = (
        (31415, 27182),
        (9, 8),
        (2653, 1828),
        (58979, 45904),
    )

    call = stub.FullDuplexCall()
    for response_size, payload_size in rounds:
        outgoing = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(
                messages_pb2.ResponseParameters(size=response_size),
            ),
            payload=messages_pb2.Payload(body=b"\x00" * payload_size),
        )
        await call.write(outgoing)
        incoming = await call.read()
        _validate_payload_type_and_length(
            incoming, messages_pb2.COMPRESSABLE, response_size
        )
    await call.done_writing()
    await _validate_status_code_and_details(call, grpc.StatusCode.OK, "")
198
199
async def _cancel_after_begin(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a client-streaming call before sending any request.

    Raises:
      ValueError: If the call does not report itself cancelled, or its final
        status code is not CANCELLED.
    """
    call = stub.StreamingInputCall()
    call.cancel()
    if not call.cancelled():
        raise ValueError("expected cancelled method to return True")
    final_code = await call.code()
    if final_code is grpc.StatusCode.CANCELLED:
        return
    raise ValueError("expected status code CANCELLED")
208
209
async def _cancel_after_first_response(stub: test_pb2_grpc.TestServiceStub):
    """Cancel a full-duplex call after receiving its first response.

    Writes one request, reads one response, cancels the call, and then
    requires the next read to raise ``asyncio.CancelledError`` with the call
    reporting a CANCELLED status code.

    Raises:
      ValueError: If the read after cancelling does not raise, or the final
        status code is not CANCELLED.
    """
    # Only one round trip happens before cancelling, so only the first pair
    # of the canonical interop sizes is needed.
    response_size = 31415
    payload_size = 27182

    call = stub.FullDuplexCall()

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(
            messages_pb2.ResponseParameters(size=response_size),
        ),
        payload=messages_pb2.Payload(body=b"\x00" * payload_size),
    )

    await call.write(request)
    await call.read()

    call.cancel()

    try:
        await call.read()
    except asyncio.CancelledError:
        # Explicit check instead of ``assert`` so it still runs under -O.
        code = await call.code()
        if code is not grpc.StatusCode.CANCELLED:
            raise ValueError("expected status code CANCELLED, got %s" % code)
    else:
        raise ValueError("expected call to be cancelled")
247
248
async def _timeout_on_sleeping_server(stub: test_pb2_grpc.TestServiceStub):
    """Require DEADLINE_EXCEEDED when the server replies after the deadline.

    The call gets a 1 second deadline, and its single request asks for a
    response interval of twice that.

    Raises:
      ValueError: If the read completes without any RPC error.
      aio.AioRpcError: Re-raised if the RPC fails with any other status code.
    """
    deadline = datetime.timedelta(seconds=1)
    deadline_seconds = deadline.total_seconds()

    call = stub.FullDuplexCall(timeout=deadline_seconds)

    response_interval_us = int(deadline_seconds * 2 * 10**6)
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        payload=messages_pb2.Payload(body=b"\x00" * 27182),
        response_parameters=(
            messages_pb2.ResponseParameters(interval_us=response_interval_us),
        ),
    )
    await call.write(request)
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        if rpc_error.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
            return
        raise
    raise ValueError("expected call to exceed deadline")
273
274
async def _empty_stream(stub: test_pb2_grpc.TestServiceStub):
    """Half-close a full-duplex call immediately and expect no responses.

    Raises:
      ValueError: If the first read returns anything other than ``aio.EOF``.
    """
    call = stub.FullDuplexCall()
    await call.done_writing()
    response = await call.read()
    # Explicit check instead of ``assert`` so the test still verifies
    # something when Python runs with -O.
    if response != aio.EOF:
        raise ValueError("expected end of stream, got %r" % (response,))
279
280
async def _status_code_and_message(stub: test_pb2_grpc.TestServiceStub):
    """Check that a requested status code and message are echoed back.

    A UnaryCall and a FullDuplexCall each set ``response_status`` to ask the
    server to finish with UNKNOWN and a fixed message. Each call's final code
    and details must match.

    Raises:
      ValueError: If the streaming read fails with a different status code,
        or either call's final code or details differ from the request.
    """
    details = "test status message"
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(
            code=status.value[0], message=details
        ),
    )
    call = stub.UnaryCall(request)
    await _validate_status_code_and_details(call, status, details)

    # Test with a FullDuplexCall
    call = stub.FullDuplexCall()
    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=messages_pb2.EchoStatus(
            code=status.value[0], message=details
        ),
    )
    await call.write(request)  # sends the initial request.
    await call.done_writing()
    try:
        await call.read()
    except aio.AioRpcError as rpc_error:
        # Explicit check instead of ``assert`` so it still runs under -O.
        if rpc_error.code() != status:
            raise ValueError(
                "expected code %s, got %s" % (status, rpc_error.code())
            ) from rpc_error
    await _validate_status_code_and_details(call, status, details)
314
315
async def _unimplemented_method(stub: test_pb2_grpc.TestServiceStub):
    """Call TestService.UnimplementedCall and require UNIMPLEMENTED."""
    await _expect_status_code(
        stub.UnimplementedCall(empty_pb2.Empty()),
        grpc.StatusCode.UNIMPLEMENTED,
    )
319
320
async def _unimplemented_service(stub: test_pb2_grpc.UnimplementedServiceStub):
    """Call UnimplementedService.UnimplementedCall and require UNIMPLEMENTED.

    Note that ``stub`` is an UnimplementedServiceStub, not a TestServiceStub.
    """
    await _expect_status_code(
        stub.UnimplementedCall(empty_pb2.Empty()),
        grpc.StatusCode.UNIMPLEMENTED,
    )
324
325
async def _custom_metadata(stub: test_pb2_grpc.TestServiceStub):
    """Check that custom request metadata is echoed as initial/trailing metadata.

    Sends a text value under ``_INITIAL_METADATA_KEY`` and a bytes value under
    ``_TRAILING_METADATA_KEY``. It then requires them back in the call's
    initial and trailing metadata, for both a UnaryCall and a FullDuplexCall.

    Raises:
      ValueError: If an echoed value differs from the value that was sent.
    """
    sent_initial = "test_initial_metadata_value"
    sent_trailing = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    outgoing_metadata = aio.Metadata(
        (_INITIAL_METADATA_KEY, sent_initial),
        (_TRAILING_METADATA_KEY, sent_trailing),
    )

    async def _check_echo(call):
        initial = await call.initial_metadata()
        received_initial = initial[_INITIAL_METADATA_KEY]
        if received_initial != sent_initial:
            raise ValueError(
                "expected initial metadata %s, got %s"
                % (sent_initial, received_initial)
            )

        trailing = await call.trailing_metadata()
        received_trailing = trailing[_TRAILING_METADATA_KEY]
        if received_trailing != sent_trailing:
            raise ValueError(
                "expected trailing metadata %s, got %s"
                % (sent_trailing, received_trailing)
            )

    # Unary variant.
    unary_request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
    )
    unary_call = stub.UnaryCall(unary_request, metadata=outgoing_metadata)
    await _check_echo(unary_call)

    # Full-duplex variant: one round trip, then half-close before checking.
    duplex_call = stub.FullDuplexCall(metadata=outgoing_metadata)
    duplex_request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=(messages_pb2.ResponseParameters(size=1),),
    )
    await duplex_call.write(duplex_request)
    await duplex_call.read()
    await duplex_call.done_writing()
    await _check_echo(duplex_call)
374
375
async def _compute_engine_creds(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
):
    """Require the server to see ``args.default_service_account`` as username.

    Raises:
      ValueError: If the echoed username differs from the expected account.
    """
    response = await _large_unary_common_behavior(
        stub, fill_username=True, fill_oauth_scope=True, call_credentials=None
    )
    expected_account = args.default_service_account
    if response.username == expected_account:
        return
    raise ValueError(
        "expected username %s, got %s" % (expected_account, response.username)
    )
385
386
async def _oauth2_auth_token(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
):
    """Check the username and OAuth scope the server echoes for this client.

    The expected username is the ``client_email`` field of the JSON key file
    whose path is in the environment variable named by
    ``google_auth_environment_vars.CREDENTIALS``.

    Raises:
      KeyError: If that environment variable or ``client_email`` is missing.
      ValueError: If the echoed username differs from the expected email, or
        the echoed OAuth scope is not a substring of ``args.oauth_scope``.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed instead of leaked; JSON text
    # is UTF-8, so don't depend on the locale's default encoding.
    with open(json_key_filename, "r", encoding="utf-8") as json_key_file:
        wanted_email = json.load(json_key_file)["client_email"]
    response = await _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )
    if response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope
            )
        )
403
404
async def _jwt_token_creds(stub: test_pb2_grpc.TestServiceStub):
    """Check the server echoes the service-account email from the key file.

    The expected username is the ``client_email`` field of the JSON key file
    whose path is in the environment variable named by
    ``google_auth_environment_vars.CREDENTIALS``.

    Raises:
      KeyError: If that environment variable or ``client_email`` is missing.
      ValueError: If the echoed username differs from the expected email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed instead of leaked.
    with open(json_key_filename, "r", encoding="utf-8") as json_key_file:
        wanted_email = json.load(json_key_file)["client_email"]
    response = await _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )
413
414
async def _per_rpc_creds(
    stub: test_pb2_grpc.TestServiceStub, args: argparse.Namespace
):
    """Authenticate a single RPC with per-call Google credentials.

    Default Google credentials scoped to ``args.oauth_scope`` are wrapped as
    gRPC call credentials and attached to one large_unary call. The server
    must then echo the ``client_email`` of the JSON key file whose path is in
    the environment variable named by ``google_auth_environment_vars.CREDENTIALS``.

    Raises:
      KeyError: If that environment variable or ``client_email`` is missing.
      ValueError: If the echoed username differs from the expected email.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Context manager so the key file is closed instead of leaked.
    with open(json_key_filename, "r", encoding="utf-8") as json_key_file:
        wanted_email = json.load(json_key_file)["client_email"]
    google_credentials, _ = google_auth.default(scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request(),
        )
    )
    response = await _large_unary_common_behavior(
        stub, True, False, call_credentials
    )
    if wanted_email != response.username:
        raise ValueError(
            "expected username %s, got %s" % (wanted_email, response.username)
        )
436
437
async def _special_status_message(stub: test_pb2_grpc.TestServiceStub):
    """Round-trip a status message containing whitespace and non-ASCII text.

    The message mixes tabs, CR/LF, a BMP character and a non-BMP character.
    A UnaryCall asks the server to finish with UNKNOWN and this message, and
    the call's final code and details must match exactly.

    Raises:
      ValueError: If the final code or details differ from the request.
    """
    status = grpc.StatusCode.UNKNOWN  # code = 2
    details = (
        b"\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba and non-BMP"
        b" \xf0\x9f\x98\x88\t\n".decode("utf-8")
    )

    echo_status = messages_pb2.EchoStatus(code=status.value[0], message=details)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b"\x00"),
        response_status=echo_status,
    )
    await _validate_status_code_and_details(
        stub.UnaryCall(request), status, details
    )
456
457
@enum.unique
class TestCase(enum.Enum):
    """Names of the interop test cases runnable via ``test_interoperability``.

    Each member's value is the test case's string name. ``@enum.unique``
    rejects accidental duplicate values at class-creation time.
    """

    EMPTY_UNARY = "empty_unary"
    LARGE_UNARY = "large_unary"
    SERVER_STREAMING = "server_streaming"
    CLIENT_STREAMING = "client_streaming"
    PING_PONG = "ping_pong"
    CANCEL_AFTER_BEGIN = "cancel_after_begin"
    CANCEL_AFTER_FIRST_RESPONSE = "cancel_after_first_response"
    TIMEOUT_ON_SLEEPING_SERVER = "timeout_on_sleeping_server"
    EMPTY_STREAM = "empty_stream"
    STATUS_CODE_AND_MESSAGE = "status_code_and_message"
    UNIMPLEMENTED_METHOD = "unimplemented_method"
    UNIMPLEMENTED_SERVICE = "unimplemented_service"
    CUSTOM_METADATA = "custom_metadata"
    # The following cases exercise authentication.
    COMPUTE_ENGINE_CREDS = "compute_engine_creds"
    OAUTH2_AUTH_TOKEN = "oauth2_auth_token"
    JWT_TOKEN_CREDS = "jwt_token_creds"
    PER_RPC_CREDS = "per_rpc_creds"
    SPECIAL_STATUS_MESSAGE = "special_status_message"
478
479
# Dispatch table from TestCase to the coroutine implementing it.
# test_interoperability chooses the call form by counting parameters:
# one-parameter implementations get (stub); two-parameter ones get
# (stub, args). NOTE(review): _unimplemented_service expects an
# UnimplementedServiceStub, so the caller presumably passes that stub type
# for UNIMPLEMENTED_SERVICE — confirm at the call site.
_TEST_CASE_IMPLEMENTATION_MAPPING = {
    TestCase.EMPTY_UNARY: _empty_unary,
    TestCase.LARGE_UNARY: _large_unary,
    TestCase.SERVER_STREAMING: _server_streaming,
    TestCase.CLIENT_STREAMING: _client_streaming,
    TestCase.PING_PONG: _ping_pong,
    TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
    TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
    TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
    TestCase.EMPTY_STREAM: _empty_stream,
    TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
    TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
    TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
    TestCase.CUSTOM_METADATA: _custom_metadata,
    # These three read command-line options, so they take (stub, args).
    TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
    TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
    TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
    TestCase.PER_RPC_CREDS: _per_rpc_creds,
    TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
}
500
501
async def test_interoperability(
    case: TestCase,
    stub: test_pb2_grpc.TestServiceStub,
    args: Optional[argparse.Namespace] = None,
) -> None:
    """Run the implementation registered for ``case``.

    Implementations taking one parameter are called with ``stub`` only.
    Implementations taking two are called with ``stub`` and ``args``.

    Raises:
      NotImplementedError: If ``case`` has no registered implementation.
      ValueError: If a two-parameter implementation is selected while
        ``args`` is None, or an implementation has another parameter count.
    """
    method = _TEST_CASE_IMPLEMENTATION_MAPPING.get(case)
    if method is None:
        raise NotImplementedError(f'Test case "{case}" not implemented!')

    num_params = len(inspect.signature(method).parameters)
    if num_params == 1:
        await method(stub)
        return
    if num_params != 2:
        raise ValueError(f"Invalid number of parameters [{num_params}]")
    if args is None:
        raise ValueError(f"Failed to run case [{case}]: args is None")
    await method(stub, args)
521