• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests application-provided metadata, status code, and details."""
15
16import logging
17import threading
18import unittest
19
20import grpc
21
22from tests.unit import test_common
23from tests.unit.framework.common import test_constants
24from tests.unit.framework.common import test_control
25
# Fixed wire payloads. The serializers below ignore the object they are given
# and always emit these bytes; the deserializers ignore the bytes they are
# given and return a fresh placeholder object.
_SERIALIZED_REQUEST = b"\x46\x47\x48"
_SERIALIZED_RESPONSE = b"\x49\x50\x51"


def _serialize_request(unused_request):
    """Return the fixed request payload, whatever the request object."""
    return _SERIALIZED_REQUEST


def _deserialize_request(unused_serialized_request):
    """Return a fresh placeholder request object, whatever the bytes."""
    return object()


def _serialize_response(unused_response):
    """Return the fixed response payload, whatever the response object."""
    return _SERIALIZED_RESPONSE


def _deserialize_response(unused_serialized_response):
    """Return a fresh placeholder response object, whatever the bytes."""
    return object()


# Aliases under the original constant names used throughout this module.
_REQUEST_SERIALIZER = _serialize_request
_REQUEST_DESERIALIZER = _deserialize_request
_RESPONSE_SERIALIZER = _serialize_response
_RESPONSE_DESERIALIZER = _deserialize_response
33
# Fully-qualified service name and the method names registered under it.
_SERVICE = "test.TestService"
_UNARY_UNARY = "UnaryUnary"
_UNARY_STREAM = "UnaryStream"
_STREAM_UNARY = "StreamUnary"
_STREAM_STREAM = "StreamStream"

# Metadata attached by the client to every call. Each value differs from its
# key, so a transport bug that echoed keys into values would be detected.
_CLIENT_METADATA = (
    ("client-md-key", "client-md-value"),
    ("client-md-key-bin", b"\x00\x01"),
)

# Metadata every servicer handler sends via send_initial_metadata.
_SERVER_INITIAL_METADATA = (
    ("server-initial-md-key", "server-initial-md-value"),
    ("server-initial-md-key-bin", b"\x00\x02"),
)

# Metadata every servicer handler sets via set_trailing_metadata.
_SERVER_TRAILING_METADATA = (
    ("server-trailing-md-key", "server-trailing-md-value"),
    ("server-trailing-md-key-bin", b"\x00\x03"),
)

# Status code and details the servicer applies in the failure tests.
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = "Test details!"

# Calling abort should always fail an RPC, even for "invalid" codes.
# The three tuples below are zipped together. For each code passed to
# context.abort (with _DETAILS), the client is expected to observe:
#   * NOT_FOUND                       -> NOT_FOUND, details preserved
#   * 3 (a bare int, not a StatusCode) -> UNKNOWN, details preserved
#   * StatusCode.OK                   -> UNKNOWN, details cleared
_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK)
_EXPECTED_CLIENT_CODES = (
    _NON_OK_CODE,
    grpc.StatusCode.UNKNOWN,
    grpc.StatusCode.UNKNOWN,
)
_EXPECTED_DETAILS = (_DETAILS, _DETAILS, "")
66
67
class _Servicer(object):
    """Servicer whose handlers apply test-configured status and metadata.

    Each handler records the client's invocation metadata, sends
    ``_SERVER_INITIAL_METADATA``, sets ``_SERVER_TRAILING_METADATA``, then
    either calls ``context.abort`` or applies the configured code/details.
    Depending on configuration it may also raise ``test_control.Defect``.
    The unary-response handlers may instead return ``None``.

    Configuration and the recorded client metadata are guarded by
    ``self._lock``. The streaming-response handlers are generators and keep
    the lock held while yielding.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._abort_call = False
        self._code = None
        self._details = None
        self._exception = False
        self._return_none = False
        self._received_client_metadata = None

    def _begin_call(self, context):
        """Record client metadata and attach the fixed server metadata.

        The caller must hold ``self._lock``.
        """
        self._received_client_metadata = context.invocation_metadata()
        context.send_initial_metadata(_SERVER_INITIAL_METADATA)
        context.set_trailing_metadata(_SERVER_TRAILING_METADATA)

    def _apply_status(self, context):
        """Abort with the configured code/details, or set them on ``context``.

        The caller must hold ``self._lock``. When not aborting, a code or
        details that was never configured (``None``) is left untouched.
        """
        if self._abort_call:
            context.abort(self._code, self._details)
        else:
            if self._code is not None:
                context.set_code(self._code)
            if self._details is not None:
                context.set_details(self._details)

    def unary_unary(self, request, context):
        with self._lock:
            self._begin_call(context)
            self._apply_status(context)
            if self._exception:
                raise test_control.Defect()
            return None if self._return_none else object()

    def unary_stream(self, request, context):
        with self._lock:
            self._begin_call(context)
            self._apply_status(context)
            for _ in range(test_constants.STREAM_LENGTH // 2):
                yield _SERIALIZED_RESPONSE
            if self._exception:
                raise test_control.Defect()

    def stream_unary(self, request_iterator, context):
        with self._lock:
            self._begin_call(context)
            # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
            # request iterator.
            list(request_iterator)
            self._apply_status(context)
            if self._exception:
                raise test_control.Defect()
            return None if self._return_none else _SERIALIZED_RESPONSE

    def stream_stream(self, request_iterator, context):
        with self._lock:
            self._begin_call(context)
            # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
            # request iterator.
            list(request_iterator)
            self._apply_status(context)
            for _ in range(test_constants.STREAM_LENGTH // 3):
                yield object()
            if self._exception:
                raise test_control.Defect()

    def set_abort_call(self):
        """Make subsequent calls abort instead of setting code/details."""
        with self._lock:
            self._abort_call = True

    def set_code(self, code):
        """Configure the status code applied (or aborted with) per call."""
        with self._lock:
            self._code = code

    def set_details(self, details):
        """Configure the status details applied (or aborted with) per call."""
        with self._lock:
            self._details = details

    def set_exception(self):
        """Make subsequent calls raise ``test_control.Defect``."""
        with self._lock:
            self._exception = True

    def set_return_none(self):
        """Make the unary-response handlers return ``None``."""
        with self._lock:
            self._return_none = True

    def received_client_metadata(self):
        """Return the invocation metadata recorded by the latest call."""
        with self._lock:
            return self._received_client_metadata
175
176
def get_method_handlers(servicer):
    """Map each method name of ``_SERVICE`` to a handler bound to ``servicer``.

    Only the unary-unary and stream-stream handlers are given the module's
    request deserializer and response serializer. The unary-stream and
    stream-unary handlers are registered without any, matching the client
    stubs the tests create for those methods.
    """
    codec_kwargs = {
        "request_deserializer": _REQUEST_DESERIALIZER,
        "response_serializer": _RESPONSE_SERIALIZER,
    }
    handlers = {}
    handlers[_UNARY_UNARY] = grpc.unary_unary_rpc_method_handler(
        servicer.unary_unary, **codec_kwargs
    )
    handlers[_UNARY_STREAM] = grpc.unary_stream_rpc_method_handler(
        servicer.unary_stream
    )
    handlers[_STREAM_UNARY] = grpc.stream_unary_rpc_method_handler(
        servicer.stream_unary
    )
    handlers[_STREAM_STREAM] = grpc.stream_stream_rpc_method_handler(
        servicer.stream_stream, **codec_kwargs
    )
    return handlers
196
197
class MetadataCodeDetailsTest(unittest.TestCase):
    """End-to-end tests of metadata, status codes and details.

    Every test checks three things:
      * the client's metadata reached the servicer;
      * the servicer's initial and trailing metadata reached the client;
      * the client observed the expected status.

    The helpers cover the two client-side shapes. Unary-response calls go
    through ``with_call``, which raises ``grpc.RpcError`` directly.
    Streaming-response calls raise while their responses are iterated.
    """

    def setUp(self):
        self._servicer = _Servicer()
        self._server = test_common.test_server()
        self._server.add_registered_method_handlers(
            _SERVICE, get_method_handlers(self._servicer)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:{}".format(port))
        # Serializers mirror get_method_handlers: only the unary-unary and
        # stream-stream methods use the module's (de)serializers.
        self._unary_unary = self._channel.unary_unary(
            self._method_path(_UNARY_UNARY),
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )
        self._unary_stream = self._channel.unary_stream(
            self._method_path(_UNARY_STREAM),
            _registered_method=True,
        )
        self._stream_unary = self._channel.stream_unary(
            self._method_path(_STREAM_UNARY),
            _registered_method=True,
        )
        self._stream_stream = self._channel.stream_stream(
            self._method_path(_STREAM_STREAM),
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    # -- Helpers ------------------------------------------------------------

    @staticmethod
    def _method_path(method_name):
        """Return the "/<service>/<method>" path for ``method_name``."""
        return "/".join(("", _SERVICE, method_name))

    @staticmethod
    def _abort_cases():
        """Return zipped (abort code, client code, details) test cases."""
        return zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS)

    def _arm_custom_code(self):
        """Have the handlers set _NON_OK_CODE and _DETAILS (no abort)."""
        self._servicer.set_code(_NON_OK_CODE)
        self._servicer.set_details(_DETAILS)

    def _arm_abort(self, abort_code):
        """Have the handlers call ``context.abort(abort_code, _DETAILS)``."""
        self._servicer.set_code(abort_code)
        self._servicer.set_details(_DETAILS)
        self._servicer.set_abort_call()

    def _invoke_unary_unary(self):
        return self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)

    def _invoke_unary_stream(self):
        return self._unary_stream(
            _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA
        )

    def _invoke_stream_unary(self):
        return self._stream_unary.with_call(
            iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_CLIENT_METADATA,
        )

    def _invoke_stream_stream(self):
        return self._stream_stream(
            iter([object()] * test_constants.STREAM_LENGTH),
            metadata=_CLIENT_METADATA,
        )

    def _assert_metadata_transmitted(self, initial_metadata, trailing_metadata):
        """Assert that all three metadata sets crossed the wire."""
        self.assertTrue(
            test_common.metadata_transmitted(
                _CLIENT_METADATA, self._servicer.received_client_metadata()
            ),
            "client metadata not received by the servicer",
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _SERVER_INITIAL_METADATA, initial_metadata
            ),
            "server initial metadata not received by the client",
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _SERVER_TRAILING_METADATA, trailing_metadata
            ),
            "server trailing metadata not received by the client",
        )

    def _assert_status(self, rpc, expected_code, expected_details):
        """Assert ``rpc`` finished with the expected code and details."""
        self.assertIs(expected_code, rpc.code())
        self.assertEqual(expected_details, rpc.details())

    def _assert_unary_response_succeeds(self, invoke):
        """Run the ``with_call`` wrapper ``invoke`` and expect status OK."""
        unused_response, call = invoke()
        self._assert_metadata_transmitted(
            call.initial_metadata(), call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def _assert_streaming_response_succeeds(self, invoke):
        """Drain the call returned by ``invoke`` and expect status OK."""
        call = invoke()
        # Capture initial metadata before draining the response stream.
        received_initial_metadata = call.initial_metadata()
        list(call)
        self._assert_metadata_transmitted(
            received_initial_metadata, call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def _assert_unary_response_fails(
        self, invoke, expected_code, expected_details
    ):
        """Run ``invoke``, expect ``grpc.RpcError`` and check its status."""
        with self.assertRaises(grpc.RpcError) as exception_context:
            invoke()
        rpc_error = exception_context.exception
        self._assert_metadata_transmitted(
            rpc_error.initial_metadata(), rpc_error.trailing_metadata()
        )
        self._assert_status(rpc_error, expected_code, expected_details)

    def _assert_streaming_response_fails(
        self, invoke, expected_code, expected_details
    ):
        """Drain the call from ``invoke``, expect ``grpc.RpcError`` while
        iterating, then check the call's metadata and status."""
        call = invoke()
        # Capture initial metadata before draining the response stream.
        received_initial_metadata = call.initial_metadata()
        with self.assertRaises(grpc.RpcError):
            list(call)
        self._assert_metadata_transmitted(
            received_initial_metadata, call.trailing_metadata()
        )
        self._assert_status(call, expected_code, expected_details)

    # -- Successful RPCs ----------------------------------------------------

    def testSuccessfulUnaryUnary(self):
        self._servicer.set_details(_DETAILS)
        self._assert_unary_response_succeeds(self._invoke_unary_unary)

    def testSuccessfulUnaryStream(self):
        self._servicer.set_details(_DETAILS)
        self._assert_streaming_response_succeeds(self._invoke_unary_stream)

    def testSuccessfulStreamUnary(self):
        self._servicer.set_details(_DETAILS)
        self._assert_unary_response_succeeds(self._invoke_stream_unary)

    def testSuccessfulStreamStream(self):
        self._servicer.set_details(_DETAILS)
        self._assert_streaming_response_succeeds(self._invoke_stream_stream)

    # -- Aborted RPCs: subTest reports which abort code failed --------------

    def testAbortedUnaryUnary(self):
        for abort_code, expected_code, expected_details in self._abort_cases():
            with self.subTest(abort_code=abort_code):
                self._arm_abort(abort_code)
                self._assert_unary_response_fails(
                    self._invoke_unary_unary,
                    expected_code,
                    expected_details,
                )

    def testAbortedUnaryStream(self):
        for abort_code, expected_code, expected_details in self._abort_cases():
            with self.subTest(abort_code=abort_code):
                self._arm_abort(abort_code)
                self._assert_streaming_response_fails(
                    self._invoke_unary_stream,
                    expected_code,
                    expected_details,
                )

    def testAbortedStreamUnary(self):
        for abort_code, expected_code, expected_details in self._abort_cases():
            with self.subTest(abort_code=abort_code):
                self._arm_abort(abort_code)
                self._assert_unary_response_fails(
                    self._invoke_stream_unary,
                    expected_code,
                    expected_details,
                )

    def testAbortedStreamStream(self):
        for abort_code, expected_code, expected_details in self._abort_cases():
            with self.subTest(abort_code=abort_code):
                self._arm_abort(abort_code)
                self._assert_streaming_response_fails(
                    self._invoke_stream_stream,
                    expected_code,
                    expected_details,
                )

    # -- Custom (non-OK) status code set via set_code/set_details -----------

    def testCustomCodeUnaryUnary(self):
        self._arm_custom_code()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeUnaryStream(self):
        self._arm_custom_code()
        self._assert_streaming_response_fails(
            self._invoke_unary_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeStreamUnary(self):
        self._arm_custom_code()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeStreamStream(self):
        self._arm_custom_code()
        self._assert_streaming_response_fails(
            self._invoke_stream_stream, _NON_OK_CODE, _DETAILS
        )

    # -- Custom code, then the handler raises: the custom code must win ------

    def testCustomCodeExceptionUnaryUnary(self):
        self._arm_custom_code()
        self._servicer.set_exception()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionUnaryStream(self):
        self._arm_custom_code()
        self._servicer.set_exception()
        self._assert_streaming_response_fails(
            self._invoke_unary_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionStreamUnary(self):
        self._arm_custom_code()
        self._servicer.set_exception()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionStreamStream(self):
        self._arm_custom_code()
        self._servicer.set_exception()
        self._assert_streaming_response_fails(
            self._invoke_stream_stream, _NON_OK_CODE, _DETAILS
        )

    # -- Custom code, then the handler returns None: the custom code must win -

    def testCustomCodeReturnNoneUnaryUnary(self):
        self._arm_custom_code()
        self._servicer.set_return_none()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeReturnNoneStreamUnary(self):
        self._arm_custom_code()
        self._servicer.set_return_none()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )
810
811
class _InspectServicer(_Servicer):
    """``_Servicer`` that snapshots the context's status after unary-unary.

    After the base handler has configured the context, ``unary_unary``
    records ``context.code()``, ``context.details()`` and
    ``context.trailing_metadata()``. A test can then check the server-side
    getters.
    """

    def __init__(self):
        super(_InspectServicer, self).__init__()
        # Filled in by unary_unary; they stay None if the base handler raises
        # (e.g. on abort) before the snapshot is taken.
        self.actual_code = None
        self.actual_details = None
        self.actual_trailing_metadata = None

    def unary_unary(self, request, context):
        response = super(_InspectServicer, self).unary_unary(request, context)

        self.actual_code = context.code()
        self.actual_details = context.details()
        self.actual_trailing_metadata = context.trailing_metadata()
        # Forward the base handler's response instead of silently replacing
        # it with None.
        return response
825
826
class InspectContextTest(unittest.TestCase):
    """Checks that ``ServicerContext`` getters reflect what the handler set.

    The servicer is an ``_InspectServicer``, which snapshots the context's
    code, details and trailing metadata after its unary-unary handler runs.
    """

    def setUp(self):
        self._servicer = _InspectServicer()
        server = self._server = test_common.test_server()
        server.add_registered_method_handlers(
            _SERVICE, get_method_handlers(self._servicer)
        )
        port = server.add_insecure_port("[::]:0")
        server.start()

        self._channel = grpc.insecure_channel("localhost:{}".format(port))
        self._unary_unary = self._channel.unary_unary(
            "/{}/{}".format(_SERVICE, _UNARY_UNARY),
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def testCodeDetailsInContext(self):
        servicer = self._servicer
        servicer.set_code(_NON_OK_CODE)
        servicer.set_details(_DETAILS)

        with self.assertRaises(grpc.RpcError) as raised:
            self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)

        self.assertEqual(_NON_OK_CODE, raised.exception.code())

        # Server-side snapshots taken by _InspectServicer.unary_unary.
        # details() is decoded because the snapshot is compared as text.
        self.assertEqual(_NON_OK_CODE, servicer.actual_code)
        self.assertEqual(_DETAILS, servicer.actual_details.decode("utf-8"))
        self.assertEqual(
            _SERVER_TRAILING_METADATA, servicer.actual_trailing_metadata
        )
873
874
if __name__ == "__main__":
    # Attach a default stderr handler to the root logger so log records
    # emitted during the run are shown, then run this module's tests.
    logging.basicConfig()
    unittest.main(verbosity=2)
878