#!/usr/bin/env python3 # Owner(s): ["oncall: distributed"] import json import os import pickle import socket import tempfile from contextlib import contextmanager from typing import Dict from urllib3.connection import HTTPConnection from urllib3.connectionpool import HTTPConnectionPool from torch.distributed.elastic.control_plane import ( TORCH_WORKER_SERVER_SOCKET, worker_main, ) from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase class UnixHTTPConnection(HTTPConnection): def __init__(self, socket_path: str) -> None: super().__init__("localhost") self.socket_path = socket_path def connect(self) -> None: self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock.connect(self.socket_path) class UnixHTTPConnectionPool(HTTPConnectionPool): def __init__(self, socket_path: str) -> None: super().__init__("localhost") self.socket_path = socket_path def _new_conn(self): return UnixHTTPConnection(self.socket_path) @contextmanager def local_worker_server() -> None: with tempfile.TemporaryDirectory() as tmpdir: socket_path = os.path.join(tmpdir, "socket.sock") os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path with worker_main(): pool = UnixHTTPConnectionPool(socket_path) yield pool class WorkerServerTest(TestCase): def test_worker_server(self) -> None: with local_worker_server() as pool: resp = pool.request("GET", "/") self.assertEqual(resp.status, 200) self.assertEqual( resp.data, b"""

torch.distributed.WorkerServer

Handler names """, ) resp = pool.request("POST", "/handler/ping") self.assertEqual(resp.status, 200) self.assertEqual(resp.data, b"pong") resp = pool.request("GET", "/handler/") self.assertEqual(resp.status, 200) self.assertIn("ping", json.loads(resp.data)) resp = pool.request("POST", "/handler/nonexistant") self.assertEqual(resp.status, 404) self.assertIn(b"Handler nonexistant not found:", resp.data) @requires_cuda def test_dump_nccl_trace_pickle(self) -> None: with local_worker_server() as pool: resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) self.assertIsInstance(out, dict) self.assertIn("version", out) @requires_cuda def test_dump_nccl_trace_pickle_with_params(self) -> None: with local_worker_server() as pool: # bad key - not lower case resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" ) self.assertEqual(resp.status, 400) # unknown key resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" ) self.assertEqual(resp.status, 400) # bad value - not a bool resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" ) self.assertEqual(resp.status, 400) # bad value - value not lowercase resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" ) self.assertEqual(resp.status, 400) # good key and value resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" ) self.assertEqual(resp.status, 200) # multiple good keys and values resp = pool.request( "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", ) self.assertEqual(resp.status, 200) @requires_cuda def test_dump_nccl_trace_pickle_with_json(self) -> None: with local_worker_server() as pool: # bad key - not lower case resp = pool.request( "POST", "/handler/dump_nccl_trace_json?includeCollectives=true" ) self.assertEqual(resp.status, 400) # unknown key resp = pool.request("POST", "/handler/dump_nccl_trace_json?unknownkey=true") self.assertEqual(resp.status, 400) # bad value - not a bool resp = pool.request( "POST", "/handler/dump_nccl_trace_json?includecollectives=notabool" ) self.assertEqual(resp.status, 400) # bad value - value not lowercase resp = pool.request( "POST", "/handler/dump_nccl_trace_json?includecollectives=True" ) self.assertEqual(resp.status, 400) # good key and value resp = pool.request( "POST", "/handler/dump_nccl_trace_json?includecollectives=true" ) self.assertEqual(resp.status, 200) # multiple good keys and values resp = pool.request( "POST", "/handler/dump_nccl_trace_json?includecollectives=true&onlyactive=true", ) self.assertEqual(resp.status, 200) def test_tcp(self) -> None: import requests from torch._C._distributed_c10d import _WorkerServer server = _WorkerServer("", 1234) out = requests.get("http://localhost:1234/handler/") self.assertEqual(out.status_code, 200) server.shutdown() def test_dump_traceback(self) -> None: with local_worker_server() as pool: resp = pool.request("POST", "/handler/dump_traceback") self.assertEqual(resp.status, 200) self.assertIn(b"in test_dump_traceback\n", resp.data) def test_run_handler(self) -> None: from torch._C._distributed_c10d import _get_handler, _Request, _Response handler = _get_handler("ping") class Request(_Request): def __init__(self) -> None: _Request.__init__(self) def body(self) -> bytes: return b"dummy" def params(self) -> Dict[str, str]: return {} class Response(_Response): def __init__(self) -> None: _Response.__init__(self) def set_content(self, content: str, content_type: str) -> None: self.content = content self.content_type = content_type def set_status(self, status: int) -> None: self.status = status req = Request() resp = Response() handler(req, resp) self.assertEqual(resp.status, 200) self.assertEqual(resp.content, "pong") self.assertEqual(resp.content_type, "text/plain") def test_get_handler_nonexistant(self) -> None: from torch._C._distributed_c10d import _get_handler with self.assertRaisesRegex(ValueError, "Failed to find handler nonexistant"): _get_handler("nonexistant") def test_get_handler_names(self) -> None: from torch._C._distributed_c10d import _get_handler_names names = _get_handler_names() self.assertIn("ping", names) if __name__ == "__main__": run_tests()