// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes
// * barriers to wait for all clients to start up or shut down
// * health checking to detect when clients vanish
// * for sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

syntax = "proto3";

package xla;

// Describes a device local to a host.
message DeviceProto {
  // Ordinal of the device on its local host.
  int32 local_device_ordinal = 1;

  // Human-readable device name.
  string name = 2;

  // Device vendor string.
  string vendor = 3;

  // The following fields are present in the GlobalTopologyProto message
  // returned by Connect() but not in the LocalTopologyProto messages passed to
  // Connect(). In other words, the coordinator node determines the global
  // device IDs during Connect().
  int32 global_device_id = 4;  // Globally unique ID number.
}

// The set of devices local to a single node.
message LocalTopologyProto {
  // ID of the node that owns these devices; see ConnectRequest.node_id.
  int32 node_id = 1;
  repeated DeviceProto devices = 2;
}

// The union of all nodes' local topologies, assembled by the coordinator.
message GlobalTopologyProto {
  repeated LocalTopologyProto nodes = 1;
}

message ConnectRequest {
  // Version of this protocol spoken by the client; lets the service reject
  // incompatible clients.
  int32 protocol_version = 1;

  // How long the client is willing to block waiting for all nodes to connect.
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  uint64 client_id = 4;
}

message ConnectResponse {
  // Session handle to present in all subsequent RPCs for this cluster
  // lifetime.
  uint64 session_id = 1;
}

message EnumerateDevicesRequest {
  // Field number 2 is skipped in this message; reserve it so it cannot be
  // accidentally reused later with different semantics.
  reserved 2;

  // Session ID obtained from ConnectResponse.
  uint64 session_id = 1;

  // The caller's local device topology.
  LocalTopologyProto local_topology = 3;
}

message EnumerateDevicesResponse {
  // Topology aggregated across all nodes by the coordinator.
  GlobalTopologyProto global_topology = 1;
}

message KeyValueGetRequest {
  // Session ID obtained from ConnectResponse.
  uint64 session_id = 1;

  // Key to look up.
  bytes key = 2;

  // How long to block waiting for the key to appear before giving up.
  int32 timeout_milliseconds = 3;
}

message KeyValueGetResponse {
  // True iff `key` was present before the timeout expired.
  bool found = 1;
  bytes value = 2;
}

message KeyValueSetRequest {
  // Session ID obtained from ConnectResponse.
  uint64 session_id = 1;
  bytes key = 2;
  bytes value = 3;
}

message KeyValueSetResponse {}

message HeartbeatRequest {
  // Session ID obtained from ConnectResponse.
  uint64 session_id = 1;

  // ID of the node sending the heartbeat; see ConnectRequest.node_id.
  int32 node_id = 2;
}

message HeartbeatResponse {}

message ShutdownRequest {
  // Session ID obtained from ConnectResponse.
  uint64 session_id = 1;

  // ID of the node requesting shutdown; see ConnectRequest.node_id.
  int32 node_id = 2;
}

message ShutdownResponse {}

service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all
  // tasks have connected. The service receives the number of nodes to expect
  // as an option passed to its constructor.
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // In parallel, all clients call EnumerateDevices() with their local device
  // topology, and receive back a global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If the worker does not hear from the coordinator or the
  // coordinator does not hear from the tasks, the tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all tasks have indicated they are ready to shut down,
  // or a timeout is reached.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // Simple key-value store used for sharing configuration data.
  // For example, when using NCCL to communicate between multiple GPUs,
  // the NCCL communicator IDs are stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until `timeout` expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}
}