• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15
16"""Base class for RpcOp tests."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23import numpy as np
24
25from tensorflow.contrib.proto.python.ops import decode_proto_op
26from tensorflow.contrib.proto.python.ops import encode_proto_op
27from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
28from tensorflow.contrib.rpc.python.ops import rpc_op
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import errors
32
# Error text the test server's "fail" methods are expected to report; the
# tests assert that it appears in the op error / status message they get back.
I_WARNED_YOU = 'I warned you!'
# Names exported by `from <this module> import *`.
__all__ = ['I_WARNED_YOU', 'RpcOpTestBase']
36
37
class RpcOpTestBase(object):
  # pylint: disable=missing-docstring,invalid-name
  """Base class for RpcOp tests.

  This is a mixin: it supplies the test methods but not the test machinery.
  A concrete subclass must also inherit from a test case class that provides
  `cached_session`, `assertRaisesOpError` and `assertAllEqual` (presumably
  `tf.test.TestCase` -- confirm in the subclasses), and must define:

    * `get_method_name(suffix)`: the backend-specific RPC method name.
    * `_protocol`: the protocol forwarded to `rpc_op.rpc` / `rpc_op.try_rpc`.
    * `_address`: the address of a running test server.
    * `invalid_method_string`: text expected in the error for an unknown
      method.
    * `connect_failed_string`: text expected in the error for an
      unreachable address.
  """

  # An address no server listens on; the tests expect it to yield
  # UNAVAILABLE.
  _NONEXISTENT_ADDRESS = (
      'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@')

  def get_method_name(self, suffix):
    """Returns the backend-specific RPC method name for `suffix`.

    Args:
      suffix: Short method name, e.g. 'Increment'.

    Raises:
      NotImplementedError: always; concrete subclasses must override this.
    """
    raise NotImplementedError

  def rpc(self, *args, **kwargs):
    """Calls `rpc_op.rpc`, injecting `protocol=self._protocol`."""
    return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)

  def try_rpc(self, *args, **kwargs):
    """Calls `rpc_op.try_rpc`, injecting `protocol=self._protocol`."""
    return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)

  def testScalarHostPortRpc(self):
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      # A scalar request produces a scalar response.
      self.assertEqual(response_tensors.shape, ())
      response_values = sess.run(response_tensors)
    response_message = test_example_pb2.TestCase()
    self.assertTrue(response_message.ParseFromString(response_values))
    self.assertAllEqual([2, 3, 4], response_message.values)

  def testScalarHostPortTryRpc(self):
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      self.assertEqual(response_tensors.shape, ())
      response_values, status_code_values, status_message_values = (
          sess.run((response_tensors, status_code, status_message)))
    response_message = test_example_pb2.TestCase()
    self.assertTrue(response_message.ParseFromString(response_values))
    self.assertAllEqual([2, 3, 4], response_message.values)
    # For the base Rpc op, don't expect to get error status back.
    self.assertEqual(errors.OK, status_code_values)
    self.assertEqual(b'', status_message_values)

  def testEmptyHostPortRpc(self):
    with self.cached_session() as sess:
      # An empty request vector yields an empty (shape [0]) response vector.
      request_tensors = []
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertAllEqual(response_tensors.shape, [0])
      response_values = sess.run(response_tensors)
    self.assertAllEqual(response_values.shape, [0])

  def testInvalidMethod(self):
    for method in [
        '/InvalidService.Increment',
        self.get_method_name('InvalidMethodName')
    ]:
      with self.cached_session() as sess:
        # rpc() raises on failure ...
        with self.assertRaisesOpError(self.invalid_method_string):
          sess.run(self.rpc(method=method, address=self._address, request=''))

        # ... while try_rpc() reports the failure as a status instead.
        _, status_code_value, status_message_value = sess.run(
            self.try_rpc(method=method, address=self._address, request=''))
        self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
        self.assertIn(self.invalid_method_string,
                      status_message_value.decode('ascii'))

  def testInvalidAddress(self):
    # This covers the case of address='' and address='localhost:293874293874'
    address = self._NONEXISTENT_ADDRESS
    with self.cached_session() as sess:
      with self.assertRaises(errors.UnavailableError):
        sess.run(
            self.rpc(
                method=self.get_method_name('Increment'),
                address=address,
                request=''))
      _, status_code_value, status_message_value = sess.run(
          self.try_rpc(
              method=self.get_method_name('Increment'),
              address=address,
              request=''))
      self.assertEqual(errors.UNAVAILABLE, status_code_value)
      self.assertIn(self.connect_failed_string,
                    status_message_value.decode('ascii'))

  def testAlwaysFailingMethod(self):
    with self.cached_session() as sess:
      response_tensors = self.rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)

      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      status_code_value, status_message_value = sess.run((status_code,
                                                          status_message))
      self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
      self.assertIn(I_WARNED_YOU, status_message_value.decode('ascii'))

  def testSometimesFailingMethodWithManyRequests(self):
    with self.cached_session() as sess:
      # Fail hard by default.
      response_tensors = self.rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)

      # Don't fail hard, use TryRpc - return the failing status instead.
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values, status_message_values = sess.run((status_code,
                                                            status_message))
      # Check every element; asserting the truthiness of a non-empty list of
      # booleans would always pass.
      for code in status_code_values:
        self.assertIn(code, (errors.OK, errors.INVALID_ARGUMENT))
      # Failed requests must carry the server's message; b'' is contained in
      # any message, so successful requests are accepted as-is.
      expected_message_values = np.where(
          status_code_values == errors.INVALID_ARGUMENT,
          I_WARNED_YOU.encode('ascii'), b'')
      for msg, expected in zip(status_message_values, expected_message_values):
        self.assertIn(expected, msg,
                      '"%s" did not contain "%s"' % (msg, expected))

  def testVecHostPortRpc(self):
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(response_tensors.shape, (20,))
      response_values = sess.run(response_tensors)
    self.assertEqual(response_values.shape, (20,))
    for i in range(20):
      response_message = test_example_pb2.TestCase()
      self.assertTrue(response_message.ParseFromString(response_values[i]))
      self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortManyParallelRpcs(self):
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      many_response_tensors = [
          self.rpc(
              method=self.get_method_name('Increment'),
              address=self._address,
              request=request_tensors) for _ in range(10)
      ]
      # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
      many_response_values = sess.run(many_response_tensors)
    self.assertEqual(10, len(many_response_values))
    for response_values in many_response_values:
      self.assertEqual(response_values.shape, (20,))
      for i in range(20):
        response_message = test_example_pb2.TestCase()
        self.assertTrue(response_message.ParseFromString(response_values[i]))
        self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
    with self.cached_session() as sess:
      request_tensors = encode_proto_op.encode_proto(
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          sizes=[[3]] * 20,
          values=[
              [[i, i + 1, i + 2] for i in range(20)],
          ])
      response_tensor_strings = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      # decode_proto returns (sizes, values); only the decoded 'values' field
      # is checked here.
      _, (decoded_values,) = decode_proto_op.decode_proto(
          bytes=response_tensor_strings,
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          output_types=[dtypes.int32])
      decoded_values_result = sess.run(decoded_values)
    self.assertAllEqual([[i + 1, i + 2, i + 3]
                         for i in range(20)], decoded_values_result)

  def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          request=request_tensors)
      # The session-level RunOptions timeout must abort the hanging calls;
      # either UNAVAILABLE or DEADLINE_EXCEEDED is accepted.
      for timeout_ms in [1, 500, 1000]:
        options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
        with self.assertRaises((errors.UnavailableError,
                                errors.DeadlineExceededError)):
          sess.run(response_tensors, options=options)

  def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      # Here the timeout is configured on the op itself, not on the session.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          timeout_in_ms=1000,
          request=request_tensors)
      with self.assertRaises(errors.DeadlineExceededError):
        sess.run(response_tensors)

  def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
    with self.cached_session() as sess:
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesSleepForever'),
          timeout_in_ms=1000,
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values = sess.run(status_code)
      # Check every element; asserting the truthiness of a non-empty list of
      # booleans would always pass.
      for code in status_code_values:
        self.assertIn(code, (errors.OK, errors.DEADLINE_EXCEEDED))

  def testTryRpcWithMultipleAddressesSingleRequest(self):
    with self.cached_session() as sess:
      # Alternate a live and a dead address; the scalar request is reused.
      addresses = [self._address, self._NONEXISTENT_ADDRESS] * 10
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual([errors.OK, errors.UNAVAILABLE] * 10,
                          status_code_values)
      # Successful calls return a non-empty payload; failed ones return b''.
      for i in range(10):
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleMethodsSingleRequest(self):
    with self.cached_session() as sess:
      # Alternate a valid and an invalid method; the scalar request is reused.
      methods = [self.get_method_name('Increment'), 'InvalidMethodName'] * 10
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=methods, address=self._address, request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual([errors.OK, errors.UNIMPLEMENTED] * 10,
                          status_code_values)
      # Successful calls return a non-empty payload; failed ones return b''.
      for i in range(10):
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleAddressesAndRequests(self):
    with self.cached_session() as sess:
      # Alternate a live and a dead address, one distinct request per address.
      addresses = [self._address, self._NONEXISTENT_ADDRESS] * 10
      requests = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=requests)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual([errors.OK, errors.UNAVAILABLE] * 10,
                          status_code_values)
      for i in range(20):
        if i % 2 == 1:
          # Odd indices went to the dead address.
          self.assertFalse(response_tensors_values[i])
        else:
          response_message = test_example_pb2.TestCase()
          self.assertTrue(
              response_message.ParseFromString(response_tensors_values[i]))
          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
349