1syntax = "proto3"; 2 3package tensorflow.tpu; 4 5import "tensorflow/compiler/xla/xla.proto"; 6import "tensorflow/compiler/xla/xla_data.proto"; 7import "tensorflow/core/framework/tensor_shape.proto"; 8import "tensorflow/core/framework/types.proto"; 9import "tensorflow/core/protobuf/tpu/dynamic_padding.proto"; 10 11option cc_enable_arenas = true; 12 13// This is an experimental proto used in the TF/XLA bridge to store metadata to 14// a compile op (e.g. _TPUCompileMlir). 15// TODO(lyandy): Deprecate proto once generic metadata proto is created. 16message TPUCompileMetadataProto { 17 // Description of the types and shapes of the arguments to a computation. 18 message Arg { 19 enum Kind { 20 INVALID = 0; 21 PARAMETER = 1; 22 VARIABLE = 2; 23 // These are args which have been guaranteed to be constants during the 24 // session lifetime by the use of the GuaranteeConstOp (or ConstantOp). 25 GUARANTEED_CONSTANT = 3; 26 } 27 DataType dtype = 1; 28 TensorShapeProto shape = 2; 29 Kind kind = 3; 30 31 // The cross-core sharding of this input within each replica, e.g., 32 // assigning to one core, or replicate across all cores. 33 xla.OpSharding sharding = 4; 34 35 // Whether this argument will receive the same data across all replicas. 36 bool is_same_data_across_replicas = 5; 37 38 enum EnableXlaSharding { 39 DISALLOWED = 0; 40 // Sharding is allowed if host training loop exists. 41 TENTATIVE = 1; 42 ALLOWED = 2; 43 } 44 // Whether to allow XLA to produce separate programs to shard/unshard this 45 // argument. Requires this arg to be an on-device Kind::VARIABLE, or a 46 // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of 47 // a variable, and retval_index_for_sharding must be specified for the 48 // corresponding updated value. 49 EnableXlaSharding enable_xla_sharding = 6; 50 51 // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to 52 // specify the corresponding updated value in the return values. Use -1 for 53 // variables that are not updated. 54 int32 retval_index_for_sharding = 8; 55 56 // Whether this argument is placed on fast memory or not. 57 bool fast_mem = 7; 58 59 // Whether to let XLA to decide the layout during compilation, as opposed to 60 // using a fixed layout determined by the shape. 61 bool unrestricted_layout = 9; 62 } 63 repeated Arg args = 1; 64 65 // Description of the return values from a computation. 66 message Retval { 67 // The cross-core sharding of this return value within each replica, e.g., 68 // assigning to one core, or replicate across all cores. 69 xla.OpSharding sharding = 1; 70 } 71 repeated Retval retvals = 2; 72 73 // Number of replicas of the computation and number of cores in each replica. 74 // TODO(b/140721404): it may not be necessary to state the number of cores per 75 // replica here. Reconsider when replicated model-parallelism is implemented 76 // in XLA. 77 int32 num_replicas = 3; 78 int32 num_cores_per_replica = 4; 79 80 reserved 5; // was device_names 81 reserved 7; // was replica_device_assignment 82 83 xla.DeviceAssignmentProto device_assignment = 8; 84 85 // A fingerprint of the function library. Ensures that any functions called 86 // by the computation have matching definitions. 87 uint64 function_library_fingerprint = 6; 88 89 // Unique session identifier. Can be empty. 90 string session_handle = 9; 91 92 // Fingerprint of guaranteed_const value. The fingerprint computation inside 93 // tpu_compile_op may be slow. The compuation can be avoided by setting the 94 // fingerprint value here. 
95 string guaranteed_const_fingerprint = 10; 96 97 repeated tpu.PaddingMap padding_maps = 11; 98 99 // The location of step markers that XLA compile will instrument. 100 xla.DebugOptions.StepMarkerLocation step_marker_location = 12; 101 102 // Minimum number of batches run through the XLA graph before XLA fusion 103 // autotuner is enabled. Default value of zero disables the autotuner. 104 // The XLA fusion autotuner can improve performance by executing a heuristic 105 // search on the compiler parameters. 106 int64 xla_fusion_autotuner_thresh = 13; 107 108 // Enables TPU compiler to add sharding policies for inputs/outputs to 109 // the XLA computation for model parallelism. 110 bool enable_automatic_model_parallelism = 14; 111} 112
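
For orientation, below is a minimal sketch of how a client might populate this proto in Python, assuming the bindings generated from this file are importable as compile_metadata_pb2 (that module path is inferred from the proto's package and location, not guaranteed by this file); the argument shape, dtype, and replica counts are purely illustrative.

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf.tpu import compile_metadata_pb2

# Metadata for a computation with 2 replicas, each on a single core.
metadata = compile_metadata_pb2.TPUCompileMetadataProto()
metadata.num_replicas = 2
metadata.num_cores_per_replica = 1

# Describe one float32 parameter of shape [128, 784] that is fed the same
# data on every replica.
arg = metadata.args.add()
arg.kind = compile_metadata_pb2.TPUCompileMetadataProto.Arg.PARAMETER
arg.dtype = types_pb2.DT_FLOAT
arg.shape.dim.add(size=128)
arg.shape.dim.add(size=784)
arg.is_same_data_across_replicas = True

# One return value; its cross-core sharding is left unset here.
metadata.retvals.add()

serialized = metadata.SerializeToString()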