/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // List of HLO passes to disable/enable. These names must exactly match the
  // pass names as specified by the HloPassInterface::name() method.
  //
  // At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
  // be empty.
  repeated string xla_disable_hlo_passes = 30;
  repeated string xla_enable_hlo_passes_only = 124;

  // Disables all HLO passes. Note that some passes are necessary for
  // correctness, and that the invariants that must be satisfied by "fully
  // optimized" HLO are different for different devices and may change over
  // time. The only "guarantee", such as it is, is that if you compile XLA and
  // dump the optimized HLO for some graph, you should be able to run it again
  // on the same device with the same build of XLA.
  bool xla_disable_all_hlo_passes = 104;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded
  // Eigen mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Has no effect. Multi-streaming used to be supported in the GPU backend but
  // no longer is, so multi-streaming is always effectively disabled.
  bool xla_gpu_disable_multi_streaming = 63;

  // Debugging feature: if enabled, the GPU backend will assign HLO operators
  // to randomly chosen streams. This is intended to trigger concurrency bugs.
  bool xla_gpu_use_random_streams = 134;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  reserved 80;  // Was hlo_reduce_precision_options

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  reserved 93;  // Was xla_hlo_tfgraph_device_scopes
  reserved 94;  // Was xla_gpu_use_cudnn_batchnorm

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Enable JitRt in the CPU backend.
  bool xla_cpu_use_jitrt = 177;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf (this
  //    behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}).
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_cpu_enable_fast_math = 99;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_nans = 120;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_infs = 121;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // using the reciprocal of an argument instead of division. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_division = 126;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // approximating calculations for functions. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_functions = 129;

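  // Illustrative sketch only (not part of the original flag documentation;
  // values are examples, not recommendations): a conservative setup that
  // enables fast math but keeps NaN/Inf semantics, exact division, and exact
  // function implementations corresponds to a DebugOptions textproto fragment
  // such as:
  //
  //   xla_cpu_enable_fast_math: true
  //   xla_cpu_fast_math_honor_nans: true
  //   xla_cpu_fast_math_honor_infs: true
  //   xla_cpu_fast_math_honor_division: true
  //   xla_cpu_fast_math_honor_functions: true
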
  // When false we lower the Minimum and Maximum hlos in the CPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this
  // flag is false we always propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the gpu flag
  // below!
  bool xla_cpu_enable_fast_min_max = 140;

  // When true we lower the Minimum and Maximum hlos in the GPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
  // flag is true we don't propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the cpu flag
  // above!
  bool xla_gpu_enable_fast_min_max = 100;

  // Allows xla to increase the output precision of floating point operations.
  bool xla_allow_excess_precision = 122;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross-checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // 0: Disable gemm and convolution autotuning.
  // 1: Enable autotuning, but disable correctness checking.
  // 2: Also set output buffers to random numbers during autotuning.
  // 3: Also reset output buffers to random numbers after autotuning each
  //    algorithm.
  // 4+: Also check for correct outputs and for out-of-bounds reads/writes.
  //
  // Default: 4.
  int32 xla_gpu_autotune_level = 123;

  // Force the host platform to pretend that there are this many host
  // "devices". All these devices are backed by the same threadpool. Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching, but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

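  // Illustrative example (assumes debug options are supplied through the
  // XLA_FLAGS environment variable and that the host platform is used for
  // execution; the script name is hypothetical):
  //
  //   XLA_FLAGS=--xla_force_host_platform_device_count=8 python my_test.py
  //
  // This makes the host platform report 8 "devices", all backed by a single
  // threadpool, which is useful for exercising multi-device code paths without
  // real accelerators.
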
  // If set to true, XLA:GPU invokes `ptxas` with -O0 (default is -O3).
  bool xla_gpu_disable_gpuasm_optimizations = 103;

  enum ShapeChecks {
    // Do not insert any shape checks for dynamically shaped operations; output
    // buffers might contain garbage data if shapes don't match.
    IGNORE = 0;

    // Check shapes at runtime; will insert an extra synchronization if shapes
    // cannot be proven correct at compile time.
    RUNTIME = 1;

    // Will refuse to compile any program where shape correctness cannot be
    // established at compile time.
    COMPILE_TIME = 2;
  }

  ShapeChecks xla_gpu_shape_checks = 170;

  // Enable MLIR-based lowering in XLA:CPU instead of LLVM emitters.
  bool xla_cpu_enable_mlir_lowering = 171;

  // If true, use MLIR instead of IR emitter to generate device code for
  // supported lmhlo.fusion ops. See xla::gpu::RewriteFusionOps() for details.
  bool xla_gpu_enable_mlir_lowering = 173;

  // Enable fast math with eigen in the HLO evaluator.
  bool xla_hlo_evaluator_use_fast_path = 106;

  // Temporary option to allow support for both the R1 and the scalar index
  // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
  bool xla_allow_scalar_index_dynamic_ops = 107;

  enum StepMarkerLocation {
    // Generate a step marker at the program entry. This handles the case where
    // each step is done by one or multiple program execution(s). Only the
    // first program will be tagged for generating a step marker at the program
    // entry. This is the default.
    STEP_MARK_AT_ENTRY = 0;
    // Generate a step marker at each iteration of the top-level while loop,
    // which is assumed to be a training loop.
    STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
    // Generate a step marker at each iteration of the second-level while loop,
    // which is assumed to be a training or eval loop.
    STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3;
    // No step marker generated.
    STEP_MARK_NONE = 2;
  }
  // Option to emit a target-specific marker to indicate the start of a
  // training step. The location of the marker (if any) is determined by the
  // option value.
  StepMarkerLocation xla_step_marker_location = 108;

  //
  // BEGIN flags controlling dumping HLO modules for debugging.
  //
  // When dumping is enabled, HLO modules are dumped at the very beginning and
  // end of compilation, and optionally also during the pass pipeline.
  //
  // In general, if you set one of these flags, we will try to infer reasonable
  // defaults for the others. For example:
  //
  // * Setting --xla_dump_to=/tmp/foo without specifying a format
  //   with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
  //
  // * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
  //   dump to stdout.
  //

  // Directory to dump into.
  string xla_dump_to = 109;

  // If specified, will only dump modules which match this regexp.
  string xla_dump_hlo_module_re = 110;

  // If this flag is specified, will also dump HLO before and after passes that
  // match this regular expression. Set to .* to dump before/after all passes.
  string xla_dump_hlo_pass_re = 111;

  // Specifies the format that HLO is dumped in. Multiple of these may be
  // specified.
  bool xla_dump_hlo_as_text = 112;
  bool xla_dump_hlo_as_proto = 113;
  bool xla_dump_hlo_as_dot = 114;
  bool xla_dump_hlo_as_url = 115;

  // Dump HLO graphs as HTML (DOT -> SVG inlined in HTML).
  bool xla_dump_hlo_as_html = 116;

  // Dump the visualization of the fusion progress.
  bool xla_dump_fusion_visualization = 149;

  // If true, every time an HLO module is run, we will dump an HloSnapshot
  // (essentially, a serialized module plus its inputs) to the --xla_dump_to
  // directory.
  bool xla_dump_hlo_snapshots = 118;

  // Include a timestamp in the dumped filenames.
  bool xla_dump_include_timestamp = 131;

  // Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
  int32 xla_dump_max_hlo_modules = 132;

  // Dump HloModuleMetadata as a text proto for each HLO module.
  bool xla_dump_module_metadata = 144;

  // GZip-compress protos dumped via --xla_dump_hlo_as_proto.
  bool xla_dump_compress_protos = 151;

  // Dump HLO in long text format. Ignored unless xla_dump_hlo_as_text is true.
  bool xla_dump_hlo_as_long_text = 164;

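  // Illustrative example of combining the dump flags above (assumes debug
  // options are supplied through the XLA_FLAGS environment variable; the path
  // and entry point are hypothetical):
  //
  //   XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text \
  //              --xla_dump_hlo_pass_re=.*" python train.py
  //
  // This dumps every compiled module as text into /tmp/xla_dump, before and
  // after each HLO pass.
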
  //
  // END flags controlling dumping HLO modules.
  //

  // Overrides for XLA GPU's convolution layout heuristic.
  bool xla_gpu_force_conv_nchw = 125;
  bool xla_gpu_force_conv_nhwc = 146;

  // Paths to files with ptx code.
  repeated string xla_gpu_ptx_file = 127;

  // Whether to dump llvm ir when compiling to ptx.
  bool xla_gpu_dump_llvmir = 155;

  // Denylist for cuDNN convolutions.
  string xla_gpu_algorithm_denylist_path = 128;

  reserved 130;  // Was xla_gpu_deterministic_reductions

  // Debug options that trigger execution errors when NaN or Inf are detected.
  bool xla_tpu_detect_nan = 135;
  bool xla_tpu_detect_inf = 136;

  // True if TraceMe annotations are enabled for XLA:CPU.
  bool xla_cpu_enable_xprof_traceme = 137;

  // It is usually preferable not to fall back to the driver; it can consume
  // more memory, or have bugs.
  bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;

  // Extra parameters to pass to the GPU assembler.
  string xla_gpu_asm_extra_flags = 141;

  // Per-heap size constraint. New heaps will be created if per-heap max size
  // is reached.
  int32 xla_multiheap_size_constraint_per_heap = 142;

  // Enable detailed logging into vlog and xla dumping. If this is disabled, no
  // compilation summary will be printed at the end of the computation and no
  // hlo modules will be dumped.
  bool xla_detailed_logging_and_dumping = 143;

  // Overrides the normal multi-threaded compilation setting to use this many
  // threads. Setting to 0 (the default value) means no enforcement.
  int32 xla_gpu_force_compilation_parallelism = 147;

  // Guarantees run-to-run determinism. At present, the HLO ops Scatter and
  // SelectAndScatter do not have deterministic XLA:GPU implementations.
  // Compilation errors out if these ops are encountered.
  bool xla_gpu_deterministic_ops = 148;

  // Paths to files with LLVM code.
  repeated string xla_gpu_llvm_ir_file = 150;

  // Convert synchronous all-reduce ops into asynchronous ones.
  bool xla_gpu_enable_async_all_reduce = 152;

  // Size threshold (in bytes) for the GPU all-reduce combiner.
  int64 xla_gpu_all_reduce_combine_threshold_bytes = 157;

  // Combine GPU all-reduces into a single operation over a contiguous buffer.
  bool xla_gpu_all_reduce_contiguous = 158;

  // Number of devices per host for the first stage of the BlueConnect
  // decomposition pass. The pass will attempt to decompose all-reduce ops into
  // a ReduceScatter-AllReduce-AllGather sequence, with the initial
  // ReduceScatter being performed over all of the devices in the same host.
  // Set to < 1 to disable all-reduce decomposition.
  int32 xla_gpu_all_reduce_blueconnect_num_devices_per_host = 159;

  // Whether to use the cuDNN frontend API for convolutions when possible.
  bool xla_gpu_enable_cudnn_frontend = 160;

  // Disable dumping metadata in HLO dumps.
  bool xla_dump_disable_metadata = 153;

  // If this flag is specified, will only dump HLO before and after passes in
  // the pass pipelines that match this regular expression. The default empty
  // value enables dumping in all pipelines.
  string xla_dump_hlo_pipeline_re = 154;

  // If true, abort immediately when the conv algorithm picker fails, rather
  // than logging a warning and proceeding with a fallback.
  bool xla_gpu_strict_conv_algorithm_picker = 156;

  reserved 161;  // Was xla_gpu_bef_executable
  reserved 162;  // Was xla_gpu_bef_thunk

  // If true, enable XLIR to compile gpu programs to JitRt.
  bool xla_gpu_jitrt_executable = 169;

  // Timeout in seconds before terminating jobs that are stuck in a NCCL
  // Rendezvous. A negative value disables the timeout; stuck jobs will not be
  // terminated.
  int64 xla_gpu_nccl_termination_timeout_seconds = 163;

  // Enables shared constants for XLA/GPU. This allows large constants to be
  // shared among multiple GPU executables.
  bool xla_gpu_enable_shared_constants = 165;

  // Whether to use cuBLASLt for GEMMs on GPUs.
  bool xla_gpu_enable_cublaslt = 166;

  // Size threshold (in megabytes) for the GPU redzone scratch allocator.
  int64 xla_gpu_redzone_scratch_max_megabytes = 167;

  // Allows all floating-point conversions to be simplified, including those
  // that affect the numerics. The `BFloat16Normalization` pass inserts many
  // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the
  // `AlgebraicSimplifier`, as that will only simplify conversions that are
  // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy.
  bool xla_gpu_simplify_all_fp_conversions = 168;

  // An experimental option to force all layouts present in the
  // after-optimizations HLO to be descending, e.g.
  // ShapeUtil::MakeShapeWithDescendingLayout is an identity on all
  // instructions.
  bool xla_gpu_normalize_layouts = 172;

  // Generate calls to Arm Compute Library in the CPU backend.
  bool xla_cpu_use_acl = 174;

  // By default, XLA:CPU will run fp16 dot/conv as fp32, as this is generally
  // (much) faster on our hardware. Set this flag to disable this behavior.
  bool xla_cpu_strict_dot_conv_math = 175;

  // Next id: 179

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;

  // Reserved tags were xla_hlo_dump_as_graphdef, xla_dump_to,
  // xla_gpu_use_horizontal_fusion,
  // xla_gpu_unsafe_fallback_to_driver_on_ptxas_error,
  // xla_gpu_simplify_scatters, xla_gpu_simplify_gathers
  reserved 5, 117, 133, 139, 176, 178;
}

// These settings control how XLA compiles and/or runs code. Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation. Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  ShapeProto shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation. If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;

  // Number of replicas of the computation to run. If zero, uses the default
  // number of replicas for the XLA service.
  int32 num_replicas = 6;

  // This optional field specifies the device assignment if known at compile
  // time.
  DeviceAssignmentProto device_assignment = 7;

  // Alias input and output buffers for parameters that are passed through XLA
  // modules without being changed.
  bool alias_passthrough_params = 8;

  // Number of partitions of the computation to run (model parallelism).
  // If zero, uses the default number of partitions for the XLA service.
  int32 num_partitions = 9;

  // Used to identify a set of programs that should be launched together.
  int32 launch_id = 10;

  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
  // num_partitions > 1 and XLA is requested to partition the input program.
  bool use_spmd_partitioning = 11;

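  // Illustrative sketch (hypothetical values, not defaults): a 2-way
  // SPMD-partitioned, non-replicated execution corresponds to an
  // ExecutionOptions textproto fragment such as:
  //
  //   num_replicas: 1
  //   num_partitions: 2
  //   use_spmd_partitioning: true
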
  // Whether to automatically generate XLA shardings for the SPMD partitioner.
  bool use_auto_spmd_partitioning = 15;

  // Device mesh shape used to create the sharding search space when
  // use_auto_spmd_partitioning=true.
  repeated int64 auto_spmd_partitioning_mesh_shape = 16;

  // Device mesh ids compatible with the above mesh_shape used when
  // use_auto_spmd_partitioning=true.
  repeated int64 auto_spmd_partitioning_mesh_ids = 17;

  // If set, deduplicate hlo into function calls to reduce binary size. Only
  // works on TPU.
  bool deduplicate_hlo = 12;

  reserved 13;  // Was broadcast_replicated_parameters_via_collectives

  // Allows sharding propagation to propagate to the outputs. This changes the
  // output shape of the computation (which is undesirable), but it can be used
  // to run a partial compilation and determine what the output sharding of a
  // computation would be if XLA were allowed to propagate it. Higher-level
  // frameworks can use this to query the intermediate sharding of operations
  // when multiple computations are chained and merged together.
  bool allow_spmd_sharding_propagation_to_output = 14;
}

message GetDeviceHandlesRequest {
  int64 device_count = 1;
}

message GetDeviceHandlesResponse {
  repeated DeviceHandle device_handles = 1;
}

message TransferToClientRequest {
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 2;
}

message TransferToClientResponse {
  LiteralProto literal = 1;
}

message TransferToServerRequest {
  LiteralProto literal = 1;
  DeviceHandle device_handle = 2;
}

message TransferToServerResponse {
  GlobalDataHandle data = 1;
}

message TransferToInfeedRequest {
  LiteralProto literal = 1;
  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferToInfeedResponse {}

message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 1;

  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferFromOutfeedResponse {
  LiteralProto literal = 1;
}

message ResetDeviceRequest {
  DeviceHandle device_handle = 1;
}

message ResetDeviceResponse {}

message ComputationGraphStatsRequest {
  HloModuleProto computation = 1;
  DebugOptions debug_options = 2;
}

message ComputationStatsResponse {
  ComputationStats stats = 1;
}

message CreateChannelHandleRequest {
  ChannelHandle.ChannelType channel_type = 1;
}

message CreateChannelHandleResponse {
  ChannelHandle channel = 1;
}

message UnregisterRequest {
  repeated GlobalDataHandle data = 1;
}

message UnregisterResponse {}

message CompileRequest {
  // The graph to be compiled.
  HloModuleProto computation = 1;

  // Options that affect how XLA compiles code to service this request.
  ExecutionOptions execution_options = 2;

  // The layouts of the input arguments. If not set, the default layout will be
  // used. Although the real arguments are not needed in compilation, the
  // layouts of the arguments can affect the compilation.
  repeated ShapeProto input_shape_with_layout = 3;
}

message CompileResponse {
  // The handle to the executable.
  ExecutionHandle handle = 1;
}

message ExecuteRequest {
  ExecutionHandle handle = 1;

  // The shape and layout of the arguments must be the same as those of the
  // executable's parameters.
  repeated GlobalDataHandle arguments = 2;
}

// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
  HloModuleProto computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this
  // request.
  ExecutionOptions execution_options = 3;
}

message ExecuteGraphParallelRequest {
  repeated ExecuteGraphRequest requests = 1;
}

message ExecuteResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ExecuteParallelResponse {
  repeated ExecuteResponse responses = 1;
}

message WaitForExecutionRequest {
  ExecutionHandle execution = 1;
}

message WaitForExecutionResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ComputeConstantGraphRequest {
  HloModuleProto computation = 1;
  LayoutProto output_layout = 2;
}

message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request.
  LiteralProto literal = 1;
}

message DeconstructTupleRequest {
  GlobalDataHandle tuple_handle = 2;
}

message DeconstructTupleResponse {
  repeated GlobalDataHandle element_handles = 1;
}

message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  ShapeProto element_shape = 3;

  // Warning: ColumnIO does not support random access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}

message LoadDataResponse {
  GlobalDataHandle data = 1;
  ShapeProto data_shape = 2;
  int64 available_rows = 3;
  int64 rows_loaded = 4;
  int64 nanoseconds = 5;
}

message GetShapeRequest {
  GlobalDataHandle data = 1;
}

message GetShapeResponse {
  ShapeProto shape = 1;
}

message UnpackRequest {
  GlobalDataHandle data = 1;
}

message UnpackResponse {
  repeated GlobalDataHandle tied_data = 1;
}