// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes:
// * barriers to wait for all clients to start up or shut down,
// * health checking to detect when clients vanish,
// * sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

syntax = "proto3";

package xla;

option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/compiler/"
                    "xla/pjrt/distributed/protocol_go_proto";

// Describes a device local to a host.
message DeviceProto {
  int32 local_device_ordinal = 1;
  string name = 2;
  string vendor = 3;

  // The following fields are present in the GlobalTopologyProto message
  // returned by EnumerateDevices() but not in the LocalTopologyProto messages
  // passed to EnumerateDevices(). In other words, the coordinator node
  // determines the global device IDs during EnumerateDevices(); see the
  // illustrative sketch after GlobalTopologyProto below.
  int32 global_device_id = 4;  // Globally unique ID number.
}

message LocalTopologyProto {
  int32 node_id = 1;
  repeated DeviceProto devices = 2;
}

message GlobalTopologyProto {
  repeated LocalTopologyProto nodes = 1;
}
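
// For illustration only: a sketch of how per-node topologies might combine
// into a global topology. Suppose two nodes each report two local devices via
// EnumerateDevices(); the coordinator merges the LocalTopologyProto messages
// and fills in global_device_id for every device. The concrete numbering shown
// here is an assumption made for the example; the guarantee stated above is
// only that the coordinator assigns globally unique IDs.
//
//   Node 0 sends: LocalTopologyProto { node_id: 0, devices: [ordinal 0, ordinal 1] }
//   Node 1 sends: LocalTopologyProto { node_id: 1, devices: [ordinal 0, ordinal 1] }
//
//   Coordinator returns: GlobalTopologyProto {
//     nodes: [
//       { node_id: 0, devices: [ {local_device_ordinal: 0, global_device_id: 0},
//                                {local_device_ordinal: 1, global_device_id: 1} ] },
//       { node_id: 1, devices: [ {local_device_ordinal: 0, global_device_id: 2},
//                                {local_device_ordinal: 1, global_device_id: 3} ] }
//     ]
//   }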

message ConnectRequest {
  int32 protocol_version = 1;
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  uint64 client_id = 4;
}

message ConnectResponse {
  uint64 session_id = 1;
}

message EnumerateDevicesRequest {
  uint64 session_id = 1;
  LocalTopologyProto local_topology = 3;
}

message EnumerateDevicesResponse {
  GlobalTopologyProto global_topology = 1;
}

message KeyValueGetRequest {
  uint64 session_id = 1;
  bytes key = 2;
  int32 timeout_milliseconds = 3;
}

message KeyValueGetResponse {
  bool found = 1;
  bytes value = 2;
}

message KeyValueSetRequest {
  uint64 session_id = 1;
  bytes key = 2;
  bytes value = 3;
}

message KeyValueSetResponse {}

message WaitAtBarrierRequest {
  uint64 session_id = 1;
  bytes barrier_id = 2;
  int32 node_id = 3;
  int32 timeout_milliseconds = 4;
}

message WaitAtBarrierResponse {}

message HeartbeatRequest {
  uint64 session_id = 1;
  int32 node_id = 2;
}
message HeartbeatResponse {}

message ShutdownRequest {
  uint64 session_id = 1;
  int32 node_id = 2;
}
message ShutdownResponse {}

service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all
  // nodes have connected. The service receives the number of nodes to expect
  // as an option passed to its constructor. (A sketch of the typical call
  // sequence appears at the end of this file.)
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // In parallel, all clients call EnumerateDevices() with their local device
  // topology, and receive back a global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If a worker does not hear from the coordinator, or the
  // coordinator does not hear from a worker, the affected tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all nodes have indicated they are ready to shut down,
  // or a timeout is reached.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // Simple key-value store used for sharing configuration data.
  // For example, when using NCCL to communicate between multiple GPUs,
  // the NCCL communicator IDs are stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until the request's timeout expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}

  // Blocks until all nodes are at the barrier or the barrier times out.
  rpc WaitAtBarrier(WaitAtBarrierRequest) returns (WaitAtBarrierResponse) {}
}
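
// A typical client lifecycle against this service, pieced together from the
// RPC descriptions above (a sketch, not a normative sequence):
//
//   1. Connect()            -- obtain a session_id; blocks until all expected
//                              nodes have connected.
//   2. EnumerateDevices()   -- GPU backend only: exchange the local device
//                              topology for the global topology.
//   3. KeyValueSet()/Get()  -- e.g. publish and look up NCCL communicator IDs.
//   4. Heartbeat()          -- sent periodically for the lifetime of the session.
//   5. WaitAtBarrier()      -- as needed, to synchronize all nodes.
//   6. Shutdown()           -- blocks until every node is ready to shut down.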