// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes:
// * barriers to wait for all clients to start up or shut down
// * health checking to detect when clients vanish
// * sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

syntax = "proto3";

package xla;

// Go import path for the code generated from this file. The value is written
// as two adjacent string literals, which the proto parser concatenates into a
// single string.
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/compiler/"
                    "xla/pjrt/distributed/protocol_go_proto";

// Describes a device local to a host.
message DeviceProto {
  // Ordinal of the device on its host.
  // NOTE(review): presumably an index into the host's own device list --
  // confirm against the code that fills this in.
  int32 local_device_ordinal = 1;
  // Name of the device.
  string name = 2;
  // Vendor of the device.
  string vendor = 3;

  // The following fields are present in the GlobalTopologyProto message
  // returned by EnumerateDevices() but not in the LocalTopologyProto messages
  // passed to EnumerateDevices(). In other words, the coordinator node
  // determines the global device IDs during EnumerateDevices().
  int32 global_device_id = 4;  // Globally unique ID number.
}

// The devices of a single node. Nodes pass this to EnumerateDevices() and
// receive a GlobalTopologyProto made of these messages in return.
message LocalTopologyProto {
  // ID of the node.
  // NOTE(review): presumably the same value as ConnectRequest.node_id --
  // confirm.
  int32 node_id = 1;
  // Devices local to the node.
  repeated DeviceProto devices = 2;
}

// The combined device topology returned by EnumerateDevices(), expressed as a
// list of per-node local topologies.
message GlobalTopologyProto {
  // Local topologies of the nodes.
  // NOTE(review): presumably one entry per connected node -- confirm.
  repeated LocalTopologyProto nodes = 1;
}

// Request message for DistributedRuntimeService.Connect.
message ConnectRequest {
  // Version of this protocol spoken by the caller.
  // NOTE(review): presumably compared against the service's version --
  // confirm.
  int32 protocol_version = 1;
  // Timeout for the request, in milliseconds.
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  uint64 client_id = 4;
}

// Response message for DistributedRuntimeService.Connect.
message ConnectResponse {
  // ID of the session.
  // NOTE(review): presumably the value callers echo back in the `session_id`
  // field of later requests -- confirm.
  uint64 session_id = 1;
}

// Request message for DistributedRuntimeService.EnumerateDevices.
message EnumerateDevicesRequest {
  // Field number 2 is not assigned in this message. Reserving it stops the
  // number from being handed to a new field later.
  // NOTE(review): presumably an earlier revision used 2 -- confirm against
  // the file history.
  reserved 2;

  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Local device topology of the calling client.
  LocalTopologyProto local_topology = 3;
}

// Response message for DistributedRuntimeService.EnumerateDevices.
message EnumerateDevicesResponse {
  // Global topology assembled from the local topologies the clients sent.
  GlobalTopologyProto global_topology = 1;
}

// Request message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Key to look up.
  bytes key = 2;
  // How long the lookup may block waiting for the key, in milliseconds.
  int32 timeout_milliseconds = 3;
}

// Response message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetResponse {
  // Whether the key was found.
  bool found = 1;
  // Value stored under the key.
  // NOTE(review): presumably only meaningful when `found` is true -- confirm.
  bytes value = 2;
}

// Request message for DistributedRuntimeService.KeyValueSet.
message KeyValueSetRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Key whose value is updated.
  bytes key = 2;
  // Value to associate with the key.
  bytes value = 3;
}

// Response message for DistributedRuntimeService.KeyValueSet. Carries no
// fields.
message KeyValueSetResponse {}

// Request message for DistributedRuntimeService.WaitAtBarrier.
message WaitAtBarrierRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Identifies the barrier to wait at.
  bytes barrier_id = 2;
  // ID of the calling node.
  int32 node_id = 3;
  // Timeout for the barrier, in milliseconds.
  int32 timeout_milliseconds = 4;
}

// Response message for DistributedRuntimeService.WaitAtBarrier. Carries no
// fields.
message WaitAtBarrierResponse {}

// Request message for DistributedRuntimeService.Heartbeat.
message HeartbeatRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // ID of the node sending the heartbeat.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Heartbeat. Carries no fields.
message HeartbeatResponse {}

// Request message for DistributedRuntimeService.Shutdown.
message ShutdownRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // ID of the node that is ready to shut down.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Shutdown. Carries no fields.
message ShutdownResponse {}

// RPC interface of the distributed XLA runtime service: connection setup,
// device enumeration, health checking, shutdown, a key-value store, and
// barriers.
service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all tasks
  // have connected. The service receives the number of nodes to expect as an
  // option passed to its constructor.
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // In parallel, all clients call EnumerateDevices() with their local device
  // topology, and receive back a global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If the worker does not hear from the coordinator or the
  // coordinator does not hear from the tasks, the tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all tasks have indicated they are ready to shut down,
  // or a timeout is reached.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // Simple key-value store used for sharing configuration data.
  // For example, when using NCCL to communicate between multiple GPUs,
  // the NCCL communicator IDs are stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until the request's `timeout_milliseconds` expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}

  // Blocks until all nodes are at the barrier or the barrier times out.
  rpc WaitAtBarrier(WaitAtBarrierRequest) returns (WaitAtBarrierResponse) {}
}
