1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15 16"""Base class for RpcOp tests.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22 23import numpy as np 24 25from tensorflow.contrib.proto.python.ops import decode_proto_op 26from tensorflow.contrib.proto.python.ops import encode_proto_op 27from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2 28from tensorflow.contrib.rpc.python.ops import rpc_op 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors 32 33__all__ = ['I_WARNED_YOU', 'RpcOpTestBase'] 34 35I_WARNED_YOU = 'I warned you!' 
class RpcOpTestBase(object):
  # pylint: disable=missing-docstring,invalid-name
  """Base class for RpcOp tests.

  Holds test cases shared by every protocol-specific RpcOp test. This class
  does not define the following; a concrete subclass is expected to provide
  them:

    * `_protocol`: forwarded as `protocol=` to `rpc_op.rpc`/`rpc_op.try_rpc`.
    * `_address`: address of a server implementing the test service.
    * `invalid_method_string`: substring expected in the error reported for
      an unknown service or method.
    * `connect_failed_string`: substring expected in the error reported for
      an unreachable address.
    * `get_method_name(suffix)`: maps a method suffix such as 'Increment' to
      the method name passed to the RPC ops.

  The tests also call `cached_session`, `assertAllEqual` and
  `assertRaisesOpError`; presumably the subclass mixes this class with a
  TensorFlow test case class that provides them.
  """

  # An address the tests expect to be unreachable (the ops report
  # UNAVAILABLE for it).
  _UNAVAILABLE_ADDRESS = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'

  @staticmethod
  def _flatten(nested):
    """Concatenates an iterable of iterables into a single list."""
    return list(itertools.chain.from_iterable(nested))

  def get_method_name(self, suffix):
    """Returns the RPC method name for `suffix`; subclasses must override."""
    raise NotImplementedError

  def rpc(self, *args, **kwargs):
    """Calls `rpc_op.rpc` using the subclass's `_protocol`."""
    return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)

  def try_rpc(self, *args, **kwargs):
    """Calls `rpc_op.try_rpc` using the subclass's `_protocol`."""
    return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)

  def testScalarHostPortRpc(self):
    """A scalar request yields a scalar response with incremented values."""
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(response_tensors.shape, ())
      response_values = sess.run(response_tensors)
      response_message = test_example_pb2.TestCase()
      self.assertTrue(response_message.ParseFromString(response_values))
      self.assertAllEqual([2, 3, 4], response_message.values)

  def testScalarHostPortTryRpc(self):
    """TryRpc on a scalar request returns the response and an OK status."""
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      self.assertEqual(response_tensors.shape, ())
      response_values, status_code_values, status_message_values = (
          sess.run((response_tensors, status_code, status_message)))
      response_message = test_example_pb2.TestCase()
      self.assertTrue(response_message.ParseFromString(response_values))
      self.assertAllEqual([2, 3, 4], response_message.values)
      # A successful call reports OK with an empty status message.
      self.assertEqual(errors.OK, status_code_values)
      self.assertEqual(b'', status_message_values)

  def testEmptyHostPortRpc(self):
    """An empty request vector yields an empty response vector."""
    with self.cached_session() as sess:
      request_tensors = []
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertAllEqual(response_tensors.shape, [0])
      response_values = sess.run(response_tensors)
      self.assertAllEqual(response_values.shape, [0])

  def testInvalidMethod(self):
    """Unknown services/methods raise in Rpc and report UNIMPLEMENTED."""
    for method in [
        '/InvalidService.Increment',
        self.get_method_name('InvalidMethodName')
    ]:
      with self.cached_session() as sess:
        with self.assertRaisesOpError(self.invalid_method_string):
          sess.run(self.rpc(method=method, address=self._address, request=''))

        _, status_code_value, status_message_value = sess.run(
            self.try_rpc(method=method, address=self._address, request=''))
        self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
        self.assertIn(self.invalid_method_string,
                      status_message_value.decode('ascii'))

  def testInvalidAddress(self):
    """An unreachable address raises in Rpc and reports UNAVAILABLE."""
    # This covers the case of address='' and address='localhost:293874293874'
    address = self._UNAVAILABLE_ADDRESS
    with self.cached_session() as sess:
      with self.assertRaises(errors.UnavailableError):
        sess.run(
            self.rpc(
                method=self.get_method_name('Increment'),
                address=address,
                request=''))
      _, status_code_value, status_message_value = sess.run(
          self.try_rpc(
              method=self.get_method_name('Increment'),
              address=address,
              request=''))
      self.assertEqual(errors.UNAVAILABLE, status_code_value)
      self.assertIn(self.connect_failed_string,
                    status_message_value.decode('ascii'))

  def testAlwaysFailingMethod(self):
    """A failing method raises in Rpc and reports INVALID_ARGUMENT in TryRpc."""
    with self.cached_session() as sess:
      response_tensors = self.rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)

      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      status_code_value, status_message_value = sess.run((status_code,
                                                          status_message))
      self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
      self.assertIn(I_WARNED_YOU, status_message_value.decode('ascii'))

  def testSometimesFailingMethodWithManyRequests(self):
    """Per-request failures fail Rpc but are reported per element by TryRpc."""
    with self.cached_session() as sess:
      # Fail hard by default.
      response_tensors = self.rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)

      # Don't fail hard, use TryRpc - return the failing status instead.
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values, status_message_values = sess.run((status_code,
                                                            status_message))
      # Check every element; asserting on a non-empty list would always pass.
      for code in status_code_values:
        self.assertIn(code, (errors.OK, errors.INVALID_ARGUMENT))
      expected_message_values = np.where(
          status_code_values == errors.INVALID_ARGUMENT,
          I_WARNED_YOU.encode('ascii'), b'')
      for msg, expected in zip(status_message_values, expected_message_values):
        self.assertIn(expected, msg)

  def testVecHostPortRpc(self):
    """A vector of requests yields a matching vector of responses."""
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(response_tensors.shape, (20,))
      response_values = sess.run(response_tensors)
      self.assertEqual(response_values.shape, (20,))
      for i, response in enumerate(response_values):
        response_message = test_example_pb2.TestCase()
        self.assertTrue(response_message.ParseFromString(response))
        self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortManyParallelRpcs(self):
    """Ten parallel Rpc ops of twenty requests each all succeed."""
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      many_response_tensors = [
          self.rpc(
              method=self.get_method_name('Increment'),
              address=self._address,
              request=request_tensors) for _ in range(10)
      ]
      # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
      many_response_values = sess.run(many_response_tensors)
      self.assertEqual(10, len(many_response_values))
      for response_values in many_response_values:
        self.assertEqual(response_values.shape, (20,))
        for i, response in enumerate(response_values):
          response_message = test_example_pb2.TestCase()
          self.assertTrue(response_message.ParseFromString(response))
          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
    """Requests built by encode_proto round-trip through decode_proto."""
    with self.cached_session() as sess:
      request_tensors = encode_proto_op.encode_proto(
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          sizes=[[3]] * 20,
          values=[
              [[i, i + 1, i + 2] for i in range(20)],
          ])
      response_tensor_strings = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      # decode_proto returns (sizes, [one tensor per requested field]).
      _, (response_field_values,) = decode_proto_op.decode_proto(
          bytes=response_tensor_strings,
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          output_types=[dtypes.int32])
      response_values = sess.run(response_field_values)
      self.assertAllEqual([[i + 1, i + 2, i + 3]
                           for i in range(20)], response_values)

  def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
    """A session run timeout aborts RPCs to a method that never returns."""
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          request=request_tensors)
      for timeout_ms in [1, 500, 1000]:
        options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
        with self.assertRaises((errors.UnavailableError,
                                errors.DeadlineExceededError)):
          sess.run(response_tensors, options=options)

  def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
    """The op's own timeout_in_ms raises DeadlineExceededError."""
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          timeout_in_ms=1000,
          request=request_tensors)
      with self.assertRaises(errors.DeadlineExceededError):
        sess.run(response_tensors)

  def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
    """TryRpc reports each element as either OK or DEADLINE_EXCEEDED."""
    with self.cached_session() as sess:
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesSleepForever'),
          timeout_in_ms=1000,
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values = sess.run(status_code)
      # Check every element; asserting on a non-empty list would always pass.
      for code in status_code_values:
        self.assertIn(code, (errors.OK, errors.DEADLINE_EXCEEDED))

  def testTryRpcWithMultipleAddressesSingleRequest(self):
    """One request broadcast to alternating good/unreachable addresses."""
    with self.cached_session() as sess:
      addresses = self._flatten(
          [self._address, self._UNAVAILABLE_ADDRESS] for _ in range(10))
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
          status_code_values)
      for i in range(10):
        # Even slots hit the real server; odd slots are unreachable.
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleMethodsSingleRequest(self):
    """One request broadcast to alternating valid/invalid methods."""
    with self.cached_session() as sess:
      methods = self._flatten(
          [self.get_method_name('Increment'), 'InvalidMethodName']
          for _ in range(10))
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=methods, address=self._address, request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
          status_code_values)
      for i in range(10):
        # Even slots call a valid method; odd slots an unknown one.
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleAddressesAndRequests(self):
    """Per-element requests paired with alternating good/bad addresses."""
    with self.cached_session() as sess:
      addresses = self._flatten(
          [self._address, self._UNAVAILABLE_ADDRESS] for _ in range(10))
      requests = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=requests)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
          status_code_values)
      for i, response in enumerate(response_tensors_values):
        if i % 2 == 1:
          # Odd slots target the unreachable address and get no response.
          self.assertFalse(response)
        else:
          response_message = test_example_pb2.TestCase()
          self.assertTrue(response_message.ParseFromString(response))
          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)