syntax = "proto3";

package xrt;

import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";

message DeviceAssignment {
  message ComputationDevice {
    message DeviceMeshCoordinates {
      // The mesh coordinates for the device. Usually (X, Y, Z, Core), in the
      // order in which they are returned in the TopologyProto.
      //   X    = value(0)
      //   Y    = value(1)
      //   Z    = value(2)
      //   Core = value(3)
      repeated int32 value = 1;
    }
    // As many replicas as there are in the replicated computation.
    repeated DeviceMeshCoordinates replica_devices = 1;
  }
  // As many ComputationDevice entries as there are computations (number
  // of cores per replica).
  repeated ComputationDevice computation_devices = 1;
}

// Options for an XLA compilation.
message XLAComputationConfig {
  // The number of replicas the computation will be run on. If this is
  // default (0) it is interpreted as 1.
  int32 num_replicas = 1;
  // The number of "model-parallel" cores per replica. If this is
  // default (0) it is interpreted as 1.
  int32 num_cores_per_replica = 2;
  // Optional metadata about host sends and recvs.
  tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;

  // The arg/result shapes for the whole computation.
  xla.ProgramShapeProto program_shape = 4;
  // The arg/result shapes for each core of a model-parallel
  // computation. per_core_program_shape is optional for a
  // single-core computation.
  repeated xla.ProgramShapeProto per_core_program_shape = 5;
  // Describes how replicated computation instances should be assigned to
  // devices. There are num_cores_per_replica computations, and each one will
  // be sent to, and executed on, the set of replica device numbers described
  // in the DeviceAssignment proto.
  DeviceAssignment device_assignment = 6;
  // The debugging options to be passed to the XLA compilation process.
  xla.DebugOptions debug_options = 7;

  // Everything inside Experimental is subject to change and is not subject
  // to API stability guarantees in
  // https://www.tensorflow.org/guide/version_compat.
  message Experimental {
    message UpdateIndexPair {
      int32 index = 1;
      bool updated = 2;
    }

    // stateful_input_indices is only useful when using XRT-compiled
    // programs together with standard TensorFlow TPU execution ops, so should
    // be ignored by most clients.
    //
    // Optionally the client can pass information about which inputs
    // to the computation are updates to "stateful" quantities. Each
    // element of stateful_input_indices includes an index indicating
    // which input argument it corresponds to, and a bool indicating
    // whether the value is updated or not. If the XRT computation is
    // going to be used with a TensorFlow TPU execution op then an
    // input index must be present for each input that will correspond
    // to a resource variable in the execution op, and may not be
    // present for any other input.
    repeated UpdateIndexPair stateful_input_indices = 1;
  }

  Experimental experimental = 8;
}

// Options and XLA computation for a compilation.
message XLAComputation {
  XLAComputationConfig config = 1;
  xla.HloSnapshot hlo_snapshot = 2;
}
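// As an illustrative sketch only (the mesh coordinates below are made up;
// real values come from the platform's TopologyProto), an
// XLAComputationConfig for 2 replicas with 2 model-parallel cores per
// replica could look like this in text format:
//
//   num_replicas: 2
//   num_cores_per_replica: 2
//   device_assignment {
//     computation_devices {  # core 0, one entry per replica
//       replica_devices { value: [0, 0, 0, 0] }
//       replica_devices { value: [1, 0, 0, 0] }
//     }
//     computation_devices {  # core 1, one entry per replica
//       replica_devices { value: [0, 0, 0, 1] }
//       replica_devices { value: [1, 0, 0, 1] }
//     }
//   }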
// Literal to allocate space for, and transfer to, device memory.
message XLAAllocation {
  reserved 1;
  xla.LiteralProto value = 2;
}

// Node in a tree describing a tuple constructed from input handles. A
// node is an internal node if tuples is non-empty, in which case
// input_index and release_input_handle are ignored. Otherwise a node
// is a leaf node. Each leaf XLATupleNode is the index of an input
// which corresponds to a handle that will be grafted onto the output
// tuple at that location. If release_input_handle is true that input
// handle will be released and become invalid. Inputs may be repeated,
// in which case leaves of the output tuple will alias. If an input is
// repeated, release_input_handle must be false for every leaf where
// that input appears.
//
// For example, if input 0 has shape {} and input 1 has shape {2,3}
// then the XLATupleNode with structure {1,{0,1}} corresponds to a
// tuple with shape {{2,3},{{},{2,3}}}. A text-format sketch of that
// node follows the message definition below.
message XLATupleNode {
  int32 input_index = 1;
  bool release_input_handle = 2;
  repeated XLATupleNode tuples = 3;
}
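// The {1,{0,1}} example above, written out as an illustrative text-format
// sketch. Note that input 1 appears twice, so release_input_handle must be
// left false (the default) on every leaf:
//
//   tuples {
//     input_index: 1            # element 0: input 1, shape {2,3}
//   }
//   tuples {                    # element 1: the nested tuple {{},{2,3}}
//     tuples { input_index: 0 }
//     tuples { input_index: 1 }
//   }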
message CommonExecutionConfig {
  // The replica index this execute is driving.
  int32 replica_id = 1;
  // Mapping local device ordinals to global replica IDs.
  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
  repeated int32 local_replica_mapping = 2;
  // The execution run ID used to correlate different XRT execute operations
  // happening in parallel from different threads.
  int64 run_id = 3;
}

// Options for an XLA execution.
message XRTExecutionConfig {
  // Local device to run on. This is present because the execute Op
  // may be placed on a device such as CPU or TPU_SYSTEM that
  // logically manages multiple cores.
  int32 device_ordinal = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only
  // needed if multiple host send/recvs may be outstanding
  // concurrently with executions.
  string execution_instance_key = 3;
  // If non-zero, rng_seed to reset the core with.
  uint32 rng_seed = 4;
  // If true, release allocation handles on the inputs after running.
  bool release_input_handles = 5;
  // If true, release the handle to the computation after running.
  bool release_compilation_handle = 6;
  // If set to true, and the result shape is a tuple, then instead of
  // returning a single tuple allocation the execution will return a vector
  // of allocations, one for each of the first-level elements of the result
  // tuple.
  bool return_exploded_tuple = 7;
  reserved 8;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 9;
}

message XRTChainedExecuteConfig {
  // If non-zero, rng_seed to reset the core with.
  uint32 rng_seed = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only needed if
  // multiple host send/recvs may be outstanding concurrently with
  // executions.
  string execution_instance_key = 3;
  reserved 4;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 5;
}

// A single chained execute operation. An operation can either be a device
// data load, or the execution of an existing (as in, previously compiled and
// accessible via its int64 handle) XLA computation.
message XRTChainedExecuteOp {
  // Represents an input for this operation.
  message Input {
    // The index within the XRTChainedExecutePlan.ops post-order of the
    // source operation for this input.
    int64 op_index = 1;
    // The output index of the value generated by the operation at op_index.
    // Zero (the default value) means no index ({}), while if an indexing is
    // required, output_index needs to be set to index + 1.
    // Thanks proto3!
    int64 output_index = 2;
  }
  // Represents an output of the XRTChainedExecute operation, which
  // originates from the output of this operation.
  message Output {
    // The index in the value generated by this operation, which should be
    // forwarded as XRTChainedExecute output. If output_index is zero (the
    // default value) the whole output will be used as result. This means
    // that if the output shape is a tuple, the result will be the full
    // tuple. Otherwise the real sub-tuple index will be output_index - 1.
    int64 output_index = 1;
    // The index in the vector of the results returned by the
    // XRTChainedExecute operation, where this output should be forwarded.
    int64 result_index = 2;
  }

  oneof op_oneof {
    // The handle to an existing XRT device data.
    int64 data_handle = 1;
    // The handle to an existing XRT compiled computation.
    int64 computation_handle = 2;
  }
  // The outputs of this XRTChainedExecuteOp operation.
  repeated Output outputs = 3;
  // The inputs of this XRTChainedExecuteOp operation. If data_handle is set,
  // there are no inputs.
  repeated Input inputs = 4;
}

// Execution plan for the XRTChainedExecute operation.
message XRTChainedExecutePlan {
  // The post order with the XRT computations to be executed.
  repeated XRTChainedExecuteOp ops = 1;
}
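// An illustrative sketch (the handle values are made up) of a two-op plan
// in text format: op 0 loads existing device data, op 1 runs a compiled
// computation on it and forwards the computation's whole output as
// XRTChainedExecute result 0:
//
//   ops {                          # post-order index 0: device data load
//     data_handle: 11111
//   }
//   ops {                          # post-order index 1: computation
//     computation_handle: 22222
//     inputs { op_index: 0 }       # output_index 0 = take the whole value
//     outputs { result_index: 0 }  # output_index 0 = forward whole output
//   }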
// The message used to encode the options for the XRTMetricsCollect
// operation.
message XRTMetricsCollect {
  // A list of regular expressions to match the metric names. Empty means to
  // return all the metrics reported by the collection registry.
  repeated string metrics_regex = 1;
}

message Percentiles {
  message Point {
    // In the [0, 100] range.
    double percentile = 1;
    double value = 2;
  }

  // The time (in nanoseconds) of the first sample within the samples buffer.
  uint64 start_nstime = 1;
  // The time (in nanoseconds) of the last sample within the samples buffer.
  uint64 end_nstime = 2;
  // The minimum value of the samples within the samples buffer.
  double min_value = 3;
  // The maximum value of the samples within the samples buffer.
  double max_value = 4;
  // The mean value of the samples within the samples buffer.
  double mean = 5;
  // The standard deviation of the samples within the samples buffer.
  double stddev = 6;
  // The number of samples within the samples buffer.
  uint64 num_samples = 7;
  // The total number of times a value has been posted to this metric.
  uint64 total_samples = 8;
  // The sum of all the posted values.
  double accumulator = 9;
  // The percentile points reported by the metric.
  repeated Point points = 10;
}

message MetricValues {
  enum UnitOfMeasure {
    INVALID = 0;
    NUMBER = 1;
    TIME = 2;
    BYTES = 3;
  }

  // The metric name.
  string name = 1;

  oneof values_oneof {
    Percentiles percentiles_value = 2;
    int64 int64_value = 3;
  }

  UnitOfMeasure unit_of_measure = 4;
}

message MetricsReport {
  repeated MetricValues metrics = 1;
}

message MemoryInfo {
  // The total memory on a device, in KB.
  int64 kb_total = 1;
  // The free memory on a device, in KB.
  int64 kb_free = 2;
}
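// An illustrative sketch of a collect request in text format. The metric
// name prefix below is an assumption for the example, not a guaranteed
// registry entry; an empty request would return every registered metric:
//
//   metrics_regex: "/tensorflow/xrt/.*"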