• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implementations of interoperability test methods."""
15
16# NOTE(lidiz) This module only exists in Bazel BUILD file, for more details
17# please refer to comments in the "bazel_namespace_package_hack" module.
18try:
19    from tests import bazel_namespace_package_hack
20    bazel_namespace_package_hack.sys_path_to_site_dir_hack()
21except ImportError:
22    pass
23
import collections
import enum
import json
import os
import threading
import time
29
30from google import auth as google_auth
31from google.auth import environment_vars as google_auth_environment_vars
32from google.auth.transport import grpc as google_auth_transport_grpc
33from google.auth.transport import requests as google_auth_transport_requests
34import grpc
35
36from src.proto.grpc.testing import empty_pb2
37from src.proto.grpc.testing import messages_pb2
38
# Metadata key whose value the custom_metadata case expects to see echoed back
# in the call's initial metadata.
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
# Metadata key whose value the custom_metadata case expects to see echoed back
# in the call's trailing metadata. In gRPC a "-bin" suffix marks a binary
# metadata key, which is why its test value is bytes.
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
41
42
43def _expect_status_code(call, expected_code):
44    if call.code() != expected_code:
45        raise ValueError('expected code %s, got %s' %
46                         (expected_code, call.code()))
47
48
49def _expect_status_details(call, expected_details):
50    if call.details() != expected_details:
51        raise ValueError('expected message %s, got %s' %
52                         (expected_details, call.details()))
53
54
def _validate_status_code_and_details(call, expected_code, expected_details):
    """Check the status code, then the status details, of ``call``.

    The code is checked first, so a wrong code is the error reported even
    when the details are also wrong.

    Raises:
      ValueError: on the first mismatch found.
    """
    checks = (
        (_expect_status_code, expected_code),
        (_expect_status_details, expected_details),
    )
    for check, expected in checks:
        check(call, expected)
58
59
60def _validate_payload_type_and_length(response, expected_type, expected_length):
61    if response.payload.type is not expected_type:
62        raise ValueError('expected payload type %s, got %s' %
63                         (expected_type, type(response.payload.type)))
64    elif len(response.payload.body) != expected_length:
65        raise ValueError('expected payload body size %d, got %d' %
66                         (expected_length, len(response.payload.body)))
67
68
def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
                                 call_credentials):
    """Send one large UnaryCall and validate the response payload.

    Sends a 271828-byte zero payload and asks for a 314159-byte COMPRESSABLE
    response.

    Args:
      stub: the stub whose ``UnaryCall`` is invoked.
      fill_username: forwarded as ``SimpleRequest.fill_username``.
      fill_oauth_scope: forwarded as ``SimpleRequest.fill_oauth_scope``.
      call_credentials: per-call credentials passed to the RPC, or None.

    Returns:
      The response message, so callers can inspect username/oauth_scope.
    """
    response_size = 314159
    request_payload = messages_pb2.Payload(body=b'\x00' * 271828)
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=response_size,
        payload=request_payload,
        fill_username=fill_username,
        fill_oauth_scope=fill_oauth_scope)
    response = stub.UnaryCall.future(request,
                                     credentials=call_credentials).result()
    _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                      response_size)
    return response
83
84
def _empty_unary(stub):
    """Call EmptyCall with an Empty request and require an Empty response.

    Raises:
      TypeError: if the response is not an ``empty_pb2.Empty`` instance.
    """
    response = stub.EmptyCall(empty_pb2.Empty())
    if isinstance(response, empty_pb2.Empty):
        return
    raise TypeError('response is of type "%s", not empty_pb2.Empty!' %
                    type(response))
90
91
def _large_unary(stub):
    """Run the large-unary exchange without username, scope or call creds."""
    _large_unary_common_behavior(stub,
                                 fill_username=False,
                                 fill_oauth_scope=False,
                                 call_credentials=None)
94
95
def _client_streaming(stub):
    """Stream four zero-filled payloads and check the aggregated size.

    Raises:
      ValueError: if the server's ``aggregated_payload_size`` is not the sum
        of the payload sizes sent (74922).
    """
    body_sizes = (27182, 8, 1828, 45904)
    expected_total = sum(body_sizes)  # == 74922

    # Lazily build each request as the RPC consumes the generator.
    requests = (messages_pb2.StreamingInputCallRequest(
        payload=messages_pb2.Payload(body=b'\x00' * body_size))
                for body_size in body_sizes)
    response = stub.StreamingInputCall(requests)
    if response.aggregated_payload_size == expected_total:
        return
    raise ValueError('incorrect size %d!' % response.aggregated_payload_size)
111
112
def _server_streaming(stub):
    """Request four server-streamed responses and validate every one of them.

    Raises:
      ValueError: if the server sends a different number of responses than
        requested, or if any response has the wrong payload type or size.
    """
    sizes = (31415, 9, 2653, 58979)

    request = messages_pb2.StreamingOutputCallRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_parameters=tuple(
            messages_pb2.ResponseParameters(size=size) for size in sizes))
    responses = list(stub.StreamingOutputCall(request))
    # Check the count explicitly. Otherwise too few responses pass silently
    # and too many fail with an unrelated IndexError.
    if len(responses) != len(sizes):
        raise ValueError('expected %d responses, got %d' %
                         (len(sizes), len(responses)))
    for response, size in zip(responses, sizes):
        _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
                                          size)
133
134
135class _Pipe(object):
136
137    def __init__(self):
138        self._condition = threading.Condition()
139        self._values = []
140        self._open = True
141
142    def __iter__(self):
143        return self
144
145    def __next__(self):
146        return self.next()
147
148    def next(self):
149        with self._condition:
150            while not self._values and self._open:
151                self._condition.wait()
152            if self._values:
153                return self._values.pop(0)
154            else:
155                raise StopIteration()
156
157    def add(self, value):
158        with self._condition:
159            self._values.append(value)
160            self._condition.notify()
161
162    def close(self):
163        with self._condition:
164            self._open = False
165            self._condition.notify()
166
167    def __enter__(self):
168        return self
169
170    def __exit__(self, type, value, traceback):
171        self.close()
172
173
def _ping_pong(stub):
    """Alternate one request and one response over a full-duplex call.

    Each request carries a zero-filled payload and asks for a single
    COMPRESSABLE response of a given size. That response is read and
    validated before the next request is sent.
    """
    response_sizes = (31415, 9, 2653, 58979)
    payload_sizes = (27182, 8, 1828, 45904)

    with _Pipe() as pipe:
        responses = stub.FullDuplexCall(pipe)
        for expected_size, body_size in zip(response_sizes, payload_sizes):
            pipe.add(
                messages_pb2.StreamingOutputCallRequest(
                    response_type=messages_pb2.COMPRESSABLE,
                    response_parameters=(messages_pb2.ResponseParameters(
                        size=expected_size),),
                    payload=messages_pb2.Payload(body=b'\x00' * body_size)))
            _validate_payload_type_and_length(next(responses),
                                              messages_pb2.COMPRESSABLE,
                                              expected_size)
202
203
def _cancel_after_begin(stub):
    """Cancel a client-streaming call before any request has been sent.

    Raises:
      ValueError: if the future does not report itself cancelled or its
        status code is not CANCELLED.
    """
    with _Pipe() as pipe:
        future = stub.StreamingInputCall.future(pipe)
        future.cancel()
        was_cancelled = future.cancelled()
        if not was_cancelled:
            raise ValueError('expected cancelled method to return True')
        final_code = future.code()
        if final_code is not grpc.StatusCode.CANCELLED:
            raise ValueError('expected status code CANCELLED')
212
213
def _cancel_after_first_response(stub):
    """Cancel a full-duplex call right after its first response arrives.

    Raises:
      ValueError: if reading after the cancel does not fail.
      grpc.RpcError: if it fails with a code other than CANCELLED.
    """
    # The same first request as the ping-pong case.
    first_response_size = 31415
    first_payload_size = 27182
    with _Pipe() as pipe:
        responses = stub.FullDuplexCall(pipe)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                response_parameters=(messages_pb2.ResponseParameters(
                    size=first_response_size),),
                payload=messages_pb2.Payload(body=b'\x00' * first_payload_size)))
        # The response contents are validated by the ping-pong case. Here the
        # first response is only awaited so the call is known to be active.
        next(responses)
        responses.cancel()

        try:
            next(responses)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is not grpc.StatusCode.CANCELLED:
                raise
        else:
            raise ValueError('expected call to be cancelled')
250
251
def _timeout_on_sleeping_server(stub):
    """Expect DEADLINE_EXCEEDED from a full-duplex call with a 1 ms timeout.

    Raises:
      ValueError: if the first read does not fail at all.
      grpc.RpcError: if it fails with a code other than DEADLINE_EXCEEDED.
    """
    with _Pipe() as pipe:
        responses = stub.FullDuplexCall(pipe, timeout=0.001)
        pipe.add(
            messages_pb2.StreamingOutputCallRequest(
                response_type=messages_pb2.COMPRESSABLE,
                payload=messages_pb2.Payload(body=b'\x00' * 27182)))
        try:
            next(responses)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() is grpc.StatusCode.DEADLINE_EXCEEDED:
                return
            raise
        raise ValueError('expected call to exceed deadline')
268
269
def _empty_stream(stub):
    """Open a full-duplex call, send nothing, and expect no responses.

    Raises:
      ValueError: if the server sends any response.
    """
    with _Pipe() as pipe:
        responses = stub.FullDuplexCall(pipe)
        pipe.close()  # End the request stream without sending anything.
        try:
            next(responses)
        except StopIteration:
            return
        raise ValueError('expected exactly 0 responses')
279
280
def _status_code_and_message(stub):
    """Ask the server to echo a status and check that code and details match.

    The check is run for both a UnaryCall and a FullDuplexCall.

    Raises:
      ValueError: if the echoed status code or details differ.
      grpc.RpcError: if the full-duplex read fails with an unexpected code.
    """
    details = 'test status message'
    code = 2
    status = grpc.StatusCode.UNKNOWN  # code = 2

    # Test with a UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=messages_pb2.EchoStatus(code=code, message=details))
    response_future = stub.UnaryCall.future(request)
    _validate_status_code_and_details(response_future, status, details)

    # Test with a FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),),
            payload=messages_pb2.Payload(body=b'\x00'),
            response_status=messages_pb2.EchoStatus(code=code, message=details))
        pipe.add(request)  # sends the initial request.
    # Leaving the with block closed the pipe, ending the request stream.
    try:
        next(response_iterator)
    except grpc.RpcError as rpc_error:
        # Check explicitly instead of with `assert`, which `python -O` strips.
        if rpc_error.code() != status:
            raise
    _validate_status_code_and_details(response_iterator, status, details)
310
311
def _unimplemented_method(test_service_stub):
    """Call a method the server does not implement and expect UNIMPLEMENTED."""
    unimplemented_call = test_service_stub.UnimplementedCall
    _expect_status_code(unimplemented_call.future(empty_pb2.Empty()),
                        grpc.StatusCode.UNIMPLEMENTED)
316
317
def _unimplemented_service(unimplemented_service_stub):
    """Call into a service the server does not expose and expect UNIMPLEMENTED."""
    unimplemented_call = unimplemented_service_stub.UnimplementedCall
    _expect_status_code(unimplemented_call.future(empty_pb2.Empty()),
                        grpc.StatusCode.UNIMPLEMENTED)
322
323
def _custom_metadata(stub):
    """Send custom metadata and check that the server echoes it back.

    The value sent under ``_INITIAL_METADATA_KEY`` must come back in the
    call's initial metadata. The value sent under ``_TRAILING_METADATA_KEY``
    must come back in its trailing metadata. Both a UnaryCall and a
    FullDuplexCall are checked.

    Raises:
      ValueError: if either echoed value is missing or differs.
    """
    initial_metadata_value = "test_initial_metadata_value"
    trailing_metadata_value = b"\x0a\x0b\x0a\x0b\x0a\x0b"
    metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
                (_TRAILING_METADATA_KEY, trailing_metadata_value))

    def _validate_metadata(call):
        # .get() makes a missing key a descriptive ValueError ("got None")
        # instead of an opaque KeyError.
        initial_metadata = dict(call.initial_metadata())
        received_initial = initial_metadata.get(_INITIAL_METADATA_KEY)
        if received_initial != initial_metadata_value:
            raise ValueError('expected initial metadata %s, got %s' %
                             (initial_metadata_value, received_initial))
        trailing_metadata = dict(call.trailing_metadata())
        received_trailing = trailing_metadata.get(_TRAILING_METADATA_KEY)
        if received_trailing != trailing_metadata_value:
            raise ValueError('expected trailing metadata %s, got %s' %
                             (trailing_metadata_value, received_trailing))

    # Testing with UnaryCall
    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'))
    response_future = stub.UnaryCall.future(request, metadata=metadata)
    _validate_metadata(response_future)

    # Testing with FullDuplexCall
    with _Pipe() as pipe:
        response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
        request = messages_pb2.StreamingOutputCallRequest(
            response_type=messages_pb2.COMPRESSABLE,
            response_parameters=(messages_pb2.ResponseParameters(size=1),))
        pipe.add(request)  # Sends the request
        next(response_iterator)  # Causes server to send trailing metadata
    # Dropping out of the with block closes the pipe
    _validate_metadata(response_iterator)
360
361
def _compute_engine_creds(stub, args):
    """Check the username the server reports against the expected account.

    Raises:
      ValueError: if ``response.username`` is not
        ``args.default_service_account``.
    """
    response = _large_unary_common_behavior(stub,
                                            fill_username=True,
                                            fill_oauth_scope=True,
                                            call_credentials=None)
    expected_username = args.default_service_account
    if response.username != expected_username:
        raise ValueError('expected username %s, got %s' %
                         (expected_username, response.username))
367
368
def _oauth2_auth_token(stub, args):
    """Check the identity and OAuth scope the server reports for the caller.

    The expected email is the ``client_email`` field of the JSON key file
    named by the environment variable
    ``google_auth_environment_vars.CREDENTIALS``.

    Args:
      stub: the stub used for the large unary call.
      args: parsed flags; the received scope must occur in
        ``args.oauth_scope``.

    Raises:
      KeyError: if the credentials environment variable is unset or the key
        file has no ``client_email``.
      ValueError: if the reported username or OAuth scope is wrong.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a with block so the key file is closed instead of leaked.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, True, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
    # An empty scope is a substring of every string, so reject it explicitly
    # rather than letting the containment check pass vacuously.
    if not response.oauth_scope or response.oauth_scope not in args.oauth_scope:
        raise ValueError(
            'expected to find oauth scope "{}" in received "{}"'.format(
                response.oauth_scope, args.oauth_scope))
380
381
def _jwt_token_creds(stub, args):
    """Check that the server reports the key file's client email as username.

    The expected email is the ``client_email`` field of the JSON key file
    named by the environment variable
    ``google_auth_environment_vars.CREDENTIALS``. ``args`` is accepted for a
    uniform signature but is not used.

    Raises:
      KeyError: if the credentials environment variable is unset or the key
        file has no ``client_email``.
      ValueError: if the reported username differs.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a with block so the key file is closed instead of leaked.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    response = _large_unary_common_behavior(stub, True, False, None)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
389
390
def _per_rpc_creds(stub, args):
    """Attach Google credentials to a single call and check the reported user.

    Application-default credentials are obtained for ``args.oauth_scope`` and
    passed as per-call credentials. The expected email is the
    ``client_email`` field of the JSON key file named by the environment
    variable ``google_auth_environment_vars.CREDENTIALS``.

    Raises:
      KeyError: if the credentials environment variable is unset or the key
        file has no ``client_email``.
      ValueError: if the reported username differs.
    """
    json_key_filename = os.environ[google_auth_environment_vars.CREDENTIALS]
    # Use a with block so the key file is closed instead of leaked.
    with open(json_key_filename, 'r') as json_key_file:
        wanted_email = json.load(json_key_file)['client_email']
    google_credentials, unused_project_id = google_auth.default(
        scopes=[args.oauth_scope])
    call_credentials = grpc.metadata_call_credentials(
        google_auth_transport_grpc.AuthMetadataPlugin(
            credentials=google_credentials,
            request=google_auth_transport_requests.Request()))
    response = _large_unary_common_behavior(stub, True, False, call_credentials)
    if wanted_email != response.username:
        raise ValueError('expected username %s, got %s' %
                         (wanted_email, response.username))
404
405
def _special_status_message(stub, args):
    """Echo a status message with whitespace and non-ASCII characters.

    ``args`` is accepted for a uniform signature but is not used.

    Raises:
      ValueError: if the echoed status code or details differ.
    """
    # UTF-8 bytes: tab/newline/CR whitespace, U+263A (BMP) and U+1F608
    # (outside the BMP).
    details = (b'\t\ntest with whitespace\r\nand Unicode BMP \xe2\x98\xba '
               b'and non-BMP \xf0\x9f\x98\x88\t\n').decode('utf-8')
    status = grpc.StatusCode.UNKNOWN
    echo_status = messages_pb2.EchoStatus(code=2, message=details)  # 2 == UNKNOWN

    request = messages_pb2.SimpleRequest(
        response_type=messages_pb2.COMPRESSABLE,
        response_size=1,
        payload=messages_pb2.Payload(body=b'\x00'),
        response_status=echo_status)
    _validate_status_code_and_details(stub.UnaryCall.future(request), status,
                                      details)
420
421
@enum.unique
class TestCase(enum.Enum):
    EMPTY_UNARY = 'empty_unary'
    LARGE_UNARY = 'large_unary'
    SERVER_STREAMING = 'server_streaming'
    CLIENT_STREAMING = 'client_streaming'
    PING_PONG = 'ping_pong'
    CANCEL_AFTER_BEGIN = 'cancel_after_begin'
    CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
    EMPTY_STREAM = 'empty_stream'
    STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
    UNIMPLEMENTED_METHOD = 'unimplemented_method'
    UNIMPLEMENTED_SERVICE = 'unimplemented_service'
    CUSTOM_METADATA = "custom_metadata"
    COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
    OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
    JWT_TOKEN_CREDS = 'jwt_token_creds'
    PER_RPC_CREDS = 'per_rpc_creds'
    TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
    SPECIAL_STATUS_MESSAGE = 'special_status_message'

    def test_interoperability(self, stub, args):
        """Run the interop test case this member names.

        Args:
          stub: the stub to exercise. For UNIMPLEMENTED_SERVICE this is passed
            to ``_unimplemented_service`` as its unimplemented-service stub.
          args: parsed flags. Only the credential cases and
            SPECIAL_STATUS_MESSAGE receive it.

        Raises:
          NotImplementedError: if no implementation is registered for this
            member.
        """
        # Cases that only need the stub.
        stub_only_cases = {
            TestCase.EMPTY_UNARY: _empty_unary,
            TestCase.LARGE_UNARY: _large_unary,
            TestCase.SERVER_STREAMING: _server_streaming,
            TestCase.CLIENT_STREAMING: _client_streaming,
            TestCase.PING_PONG: _ping_pong,
            TestCase.CANCEL_AFTER_BEGIN: _cancel_after_begin,
            TestCase.CANCEL_AFTER_FIRST_RESPONSE: _cancel_after_first_response,
            TestCase.TIMEOUT_ON_SLEEPING_SERVER: _timeout_on_sleeping_server,
            TestCase.EMPTY_STREAM: _empty_stream,
            TestCase.STATUS_CODE_AND_MESSAGE: _status_code_and_message,
            TestCase.UNIMPLEMENTED_METHOD: _unimplemented_method,
            TestCase.UNIMPLEMENTED_SERVICE: _unimplemented_service,
            TestCase.CUSTOM_METADATA: _custom_metadata,
        }
        # Cases that also receive the parsed flags.
        stub_and_args_cases = {
            TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds,
            TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token,
            TestCase.JWT_TOKEN_CREDS: _jwt_token_creds,
            TestCase.PER_RPC_CREDS: _per_rpc_creds,
            TestCase.SPECIAL_STATUS_MESSAGE: _special_status_message,
        }
        if self in stub_only_cases:
            stub_only_cases[self](stub)
        elif self in stub_and_args_cases:
            stub_and_args_cases[self](stub, args)
        else:
            raise NotImplementedError('Test case "%s" not implemented!' %
                                      self.name)
483