# Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import logging import random import time from typing import List from unittest.mock import patch import proto import pytest import requests from google.api_core import rest_streaming from google.api import http_pb2 from google.api import httpbody_pb2 from ..helpers import Composer, Song, EchoResponse, parse_responses __protobuf__ = proto.module(package=__name__) SEED = int(time.time()) logging.info(f"Starting sync rest streaming tests with random seed: {SEED}") random.seed(SEED) class ResponseMock(requests.Response): class _ResponseItr: def __init__(self, _response_bytes: bytes, random_split=False): self._responses_bytes = _response_bytes self._i = 0 self._random_split = random_split def __next__(self): if self._i == len(self._responses_bytes): raise StopIteration if self._random_split: n = random.randint(1, len(self._responses_bytes[self._i :])) else: n = 1 x = self._responses_bytes[self._i : self._i + n] self._i += n return x.decode("utf-8") def __init__( self, responses: List[proto.Message], response_cls, random_split=False, ): super().__init__() self._responses = responses self._random_split = random_split self._response_message_cls = response_cls def _parse_responses(self): return parse_responses(self._response_message_cls, self._responses) def close(self): raise NotImplementedError() def iter_content(self, *args, **kwargs): return self._ResponseItr( self._parse_responses(), random_split=self._random_split, ) @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [(False, True), (False, False)], ) def test_next_simple(random_split, resp_message_is_proto_plus): if resp_message_is_proto_plus: response_type = EchoResponse responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")] else: response_type = httpbody_pb2.HttpBody responses = [ httpbody_pb2.HttpBody(content_type="hello world"), httpbody_pb2.HttpBody(content_type="yes"), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ (True, True), (False, True), (True, False), (False, False), ], ) def test_next_nested(random_split, resp_message_is_proto_plus): if resp_message_is_proto_plus: response_type = Song responses = [ Song(title="some song", composer=Composer(given_name="some name")), Song(title="another song", date_added=datetime.datetime(2021, 12, 17)), ] else: # Although `http_pb2.HttpRule`` is used in the response, any response message # can be used which meets this criteria for the test of having a nested field. response_type = http_pb2.HttpRule responses = [ http_pb2.HttpRule( selector="some selector", custom=http_pb2.CustomHttpPattern(kind="some kind"), ), http_pb2.HttpRule( selector="another selector", custom=http_pb2.CustomHttpPattern(path="some path"), ), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ (True, True), (False, True), (True, False), (False, False), ], ) def test_next_stress(random_split, resp_message_is_proto_plus): n = 50 if resp_message_is_proto_plus: response_type = Song responses = [ Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i)) for i in range(n) ] else: response_type = http_pb2.HttpRule responses = [ http_pb2.HttpRule( selector="selector_%d" % i, custom=http_pb2.CustomHttpPattern(path="path_%d" % i), ) for i in range(n) ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @pytest.mark.parametrize( "random_split,resp_message_is_proto_plus", [ (True, True), (False, True), (True, False), (False, False), ], ) def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus): if resp_message_is_proto_plus: response_type = Song composer_with_relateds = Composer() relateds = ["Artist A", "Artist B"] composer_with_relateds.relateds = relateds responses = [ Song( title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n") ), Song( title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\"), ), Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds), ] else: response_type = http_pb2.Http responses = [ http_pb2.Http( rules=[ http_pb2.HttpRule( selector='ti"tle\nfoo\tbar{}', custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"), ) ] ), http_pb2.Http( rules=[ http_pb2.HttpRule( selector='{"this is weird": "totally"}', custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), ) ] ), http_pb2.Http( rules=[ http_pb2.HttpRule( selector='\\{"key": ["value",]}\\', custom=http_pb2.CustomHttpPattern(kind="\\{}\\"), ) ] ), ] resp = ResponseMock( responses=responses, random_split=random_split, response_cls=response_type ) itr = rest_streaming.ResponseIterator(resp, response_type) assert list(itr) == responses @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) def test_next_not_array(response_type): with patch.object( ResponseMock, "iter_content", return_value=iter('{"hello": 0}') ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) def test_cancel(response_type): with patch.object(ResponseMock, "close", return_value=None) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) itr = rest_streaming.ResponseIterator(resp, response_type) itr.cancel() mock_method.assert_called_once() @pytest.mark.parametrize( "response_type,return_value", [ (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")), (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")), ], ) def test_check_buffer(response_type, return_value): with patch.object( ResponseMock, "_parse_responses", return_value=return_value, ): resp = ResponseMock(responses=[], response_cls=response_type) itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) next(itr) @pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody]) def test_next_html(response_type): with patch.object( ResponseMock, "iter_content", return_value=iter("") ) as mock_method: resp = ResponseMock(responses=[], response_cls=response_type) itr = rest_streaming.ResponseIterator(resp, response_type) with pytest.raises(ValueError): next(itr) mock_method.assert_called_once() def test_invalid_response_class(): class SomeClass: pass resp = ResponseMock(responses=[], response_cls=SomeClass) with pytest.raises( ValueError, match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message", ): rest_streaming.ResponseIterator(resp, SomeClass)