1syntax = "proto3"; 2 3package tensorflow.tpu; 4 5import "tensorflow/compiler/xla/xla.proto"; 6import "tensorflow/compiler/xla/xla_data.proto"; 7import "tensorflow/core/framework/tensor_shape.proto"; 8import "tensorflow/core/framework/types.proto"; 9import "tensorflow/core/protobuf/tpu/dynamic_padding.proto"; 10 11option cc_enable_arenas = true; 12 13// This is an experimental proto used in the TF/XLA bridge to store metadata to 14// a compile op (e.g. _TPUCompileMlir). 15// TODO(lyandy): Deprecate proto once generic metadata proto is created. 16message TPUCompileMetadataProto { 17 // Description of the types and shapes of the arguments to a computation. 18 message Arg { 19 enum Kind { 20 INVALID = 0; 21 PARAMETER = 1; 22 VARIABLE = 2; 23 // These are args which have been guaranteed to be constants during the 24 // session lifetime by the use of the GuaranteeConstOp (or ConstantOp). 25 GUARANTEED_CONSTANT = 3; 26 } 27 DataType dtype = 1; 28 TensorShapeProto shape = 2; 29 Kind kind = 3; 30 31 // The cross-core sharding of this input within each replica, e.g., 32 // assigning to one core, or replicate across all cores. 33 xla.OpSharding sharding = 4; 34 35 // Whether this argument will receive the same data across all replicas. 36 bool is_same_data_across_replicas = 5; 37 38 enum EnableXlaSharding { 39 DISALLOWED = 0; 40 // Sharding is allowed if host training loop exists. 41 TENTATIVE = 1; 42 ALLOWED = 2; 43 } 44 // Whether to allow XLA to produce separate programs to shard/unshard this 45 // argument. Requires this arg to be an on-device Kind::VARIABLE, or a 46 // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of 47 // a variable, and retval_index_for_sharding must be specified for the 48 // corresponding updated value. 49 EnableXlaSharding enable_xla_sharding = 6; 50 51 // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to 52 // specify the corresponding updated value in the return values. Use -1 for 53 // variables that are not updated. 54 int32 retval_index_for_sharding = 8; 55 56 // Whether this argument is placed on fast memory or not. 57 bool fast_mem = 7; 58 59 // Whether to let XLA to decide the layout during compilation, as opposed to 60 // using a fixed layout determined by the shape. 61 bool unrestricted_layout = 9; 62 63 // Name of the node that the arg comes from. 64 string name = 10; 65 } 66 repeated Arg args = 1; 67 68 // Description of the return values from a computation. 69 message Retval { 70 // The cross-core sharding of this return value within each replica, e.g., 71 // assigning to one core, or replicate across all cores. 72 xla.OpSharding sharding = 1; 73 } 74 repeated Retval retvals = 2; 75 76 // Number of replicas of the computation and number of cores in each replica. 77 // TODO(b/140721404): it may not be necessary to state the number of cores per 78 // replica here. Reconsider when replicated model-parallelism is implemented 79 // in XLA. 80 int32 num_replicas = 3; 81 int32 num_cores_per_replica = 4; 82 83 reserved 5; // was device_names 84 reserved 7; // was replica_device_assignment 85 86 xla.DeviceAssignmentProto device_assignment = 8; 87 88 // A fingerprint of the function library. Ensures that any functions called 89 // by the computation have matching definitions. 90 uint64 function_library_fingerprint = 6; 91 92 // Unique session identifier. Can be empty. 93 string session_handle = 9; 94 95 // Fingerprint of guaranteed_const value. The fingerprint computation inside 96 // tpu_compile_op may be slow. 
The computation can be avoided by setting the 97 // fingerprint value here. 98 string guaranteed_const_fingerprint = 10; 99 100 repeated tpu.PaddingMap padding_maps = 11; 101 102 // The location of step markers that XLA compile will instrument. 103 xla.DebugOptions.StepMarkerLocation step_marker_location = 12; 104 105 // Minimum number of batches run through the XLA graph before XLA fusion 106 // autotuner is enabled. Default value of zero disables the autotuner. 107 // The XLA fusion autotuner can improve performance by executing a heuristic 108 // search on the compiler parameters. 109 int64 xla_fusion_autotuner_thresh = 13; 110 111 // Enables TPU compiler to add partitioning policies for inputs/outputs to 112 // the XLA computation for model parallelism. 113 bool enable_automatic_model_parallelism = 14; 114 115 // Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is 116 // requested. 117 bool use_spmd_for_xla_partitioning = 15; 118 119 // Enables use of XLA collectives for broadcast of replicated parameters to 120 // all replicas, instead of using TensorFlow Send/Recv. 121 bool broadcast_replicated_parameters_via_collectives = 16; 122} 123
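
// Illustrative sketch (not part of the schema): a text-format instance of
// TPUCompileMetadataProto describing a single-replica, single-core computation
// with one replicated f32 input. The concrete values below (shape sizes, the
// node name "input_0") are assumptions chosen only to show how the fields
// above compose; they do not come from this file.
//
// args {
//   dtype: DT_FLOAT
//   shape { dim { size: 128 } dim { size: 256 } }
//   kind: PARAMETER
//   sharding { type: REPLICATED }
//   is_same_data_across_replicas: true
//   name: "input_0"
// }
// retvals {
//   sharding { type: REPLICATED }
// }
// num_replicas: 1
// num_cores_per_replica: 1
// use_spmd_for_xla_partitioning: true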