• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Tests for RpcOp."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ctypes as ct
22import os
23
24import grpc
25from grpc.framework.foundation import logging_pool
26import portpicker
27
28from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
29from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
30from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
31from tensorflow.python.platform import test
32
33
class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
  """RpcOp tests run against an in-process gRPC server.

  This class defines no test methods of its own; presumably the test cases are
  inherited from `rpc_op_test_base.RpcOpTestBase`, and this class supplies the
  protocol name, expected error substrings, fully qualified method names, and
  a per-test server hosting `rpc_op_test_servicer.RpcOpTestServicer`.
  """
  # RPC protocol name; presumably read by the base class when building ops.
  _protocol = 'grpc'

  # Error-message substrings presumably checked by the base-class tests for an
  # unknown method and for an unreachable server, respectively.
  invalid_method_string = 'Method not found'
  connect_failed_string = 'Connect Failed'

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    """Creates the test case and loads libtestexample.so if it is present.

    Args:
      methodName: Name of the test method to run; forwarded to the base class.
    """
    super(RpcOpTest, self).__init__(methodName)
    # The library is optional: it is loaded only when it exists next to this
    # file, and its absence is silently tolerated.
    # NOTE(review): presumably it registers test proto/kernel definitions the
    # RPC op needs — confirm.
    lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
    if os.path.isfile(lib):
      ct.cdll.LoadLibrary(lib)

  def get_method_name(self, suffix):
    """Returns the fully qualified gRPC method path for `suffix`.

    Args:
      suffix: Method name within the `tensorflow.contrib.rpc.TestCaseService`.

    Returns:
      The string '/tensorflow.contrib.rpc.TestCaseService/<suffix>'.
    """
    return '/tensorflow.contrib.rpc.TestCaseService/%s' % suffix

  def setUp(self):
    """Starts a gRPC server hosting RpcOpTestServicer on a free local port.

    Sets `self._address` ('localhost:<port>') and `self._server`.

    Raises:
      RuntimeError: If the server could not bind to the chosen port.
    """
    super(RpcOpTest, self).setUp()

    service_port = portpicker.pick_unused_port()
    self._address = 'localhost:%d' % service_port

    # The gRPC server does not own the executor it is handed: stopping the
    # server leaves these worker threads alive, so release them explicitly.
    # Registered as a cleanup so it runs after tearDown() has stopped the
    # server, and also runs if anything later in setUp() raises.
    server_pool = logging_pool.pool(max_workers=25)
    self.addCleanup(server_pool.shutdown, wait=False)

    server = grpc.server(server_pool)
    servicer = rpc_op_test_servicer.RpcOpTestServicer()
    test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
        servicer, server)
    # add_insecure_port() returns the bound port, and 0 on failure (e.g. the
    # port was taken after portpicker chose it). Fail fast here rather than
    # letting every RPC in the test fail with a connect error.
    if not server.add_insecure_port(self._address):
      raise RuntimeError(
          'Failed to bind gRPC test server to %s' % self._address)
    server.start()
    self._server = server

  def tearDown(self):
    """Stops the gRPC server, then always runs the base-class teardown."""
    try:
      # grace=None: abort in-flight RPCs immediately instead of draining them.
      self._server.stop(grace=None)
    finally:
      super(RpcOpTest, self).tearDown()
66
67
if __name__ == '__main__':
  # Hand control to TensorFlow's test entry point
  # (tensorflow.python.platform.test.main) when run as a script.
  test.main()
70