syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/tpu/dynamic_padding.proto";

option cc_enable_arenas = true;

// This is an experimental proto used in the TF/XLA bridge to attach metadata
// to a compile op (e.g. _TPUCompileMlir).
// TODO(lyandy): Deprecate proto once generic metadata proto is created.
message TPUCompileMetadataProto {
  // Description of the types and shapes of the arguments to a computation.
  message Arg {
    enum Kind {
      INVALID = 0;
      PARAMETER = 1;
      VARIABLE = 2;
      // These are args which have been guaranteed to be constants during the
      // session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
      GUARANTEED_CONSTANT = 3;
    }
    DataType dtype = 1;
    TensorShapeProto shape = 2;
    Kind kind = 3;

    // The cross-core sharding of this input within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
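    // For example (illustrative, not part of the schema): a sharding of
    // { type: MAXIMAL tile_assignment_devices: 0 } assigns this input to a
    // single core, while { type: REPLICATED } replicates it across all cores.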
    xla.OpSharding sharding = 4;

    // Whether this argument will receive the same data across all replicas.
    bool is_same_data_across_replicas = 5;

    enum EnableXlaSharding {
      DISALLOWED = 0;
      // Sharding is allowed if a host training loop exists.
      TENTATIVE = 1;
      ALLOWED = 2;
    }
    // Whether to allow XLA to produce separate programs to shard/unshard this
    // argument. Requires this arg to be an on-device Kind::VARIABLE, or a
    // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
    // a variable, and retval_index_for_sharding must be specified for the
    // corresponding updated value.
    EnableXlaSharding enable_xla_sharding = 6;

    // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to
    // specify the corresponding updated value in the return values. Use -1 for
    // variables that are not updated.
    int32 retval_index_for_sharding = 8;

    // Whether this argument is placed in fast memory.
    bool fast_mem = 7;

    // Whether to let XLA decide the layout during compilation, as opposed to
    // using a fixed layout determined by the shape.
    bool unrestricted_layout = 9;
  }
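  // Illustrative example (not part of the schema): one element of args
  // describing the initial value of a variable that XLA is allowed to shard,
  // whose updated value is returned as retval 0, could be written in text
  // format as:
  //   args {
  //     dtype: DT_FLOAT
  //     shape { dim { size: 1024 } dim { size: 1024 } }
  //     kind: PARAMETER
  //     enable_xla_sharding: ALLOWED
  //     retval_index_for_sharding: 0
  //   }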
  repeated Arg args = 1;

  // Description of the return values from a computation.
  message Retval {
    // The cross-core sharding of this return value within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
    xla.OpSharding sharding = 1;
  }
  repeated Retval retvals = 2;

  // Number of replicas of the computation and number of cores in each replica.
  // TODO(b/140721404): it may not be necessary to state the number of cores per
  // replica here. Reconsider when replicated model-parallelism is implemented
  // in XLA.
  int32 num_replicas = 3;
  int32 num_cores_per_replica = 4;

  reserved 5;  // was device_names
  reserved 7;  // was replica_device_assignment

  xla.DeviceAssignmentProto device_assignment = 8;

  // A fingerprint of the function library. Ensures that any functions called
  // by the computation have matching definitions.
  uint64 function_library_fingerprint = 6;

  // Unique session identifier. Can be empty.
  string session_handle = 9;

  // Fingerprint of the guaranteed_const values. The fingerprint computation
  // inside tpu_compile_op may be slow; it can be avoided by setting the
  // fingerprint value here.
  string guaranteed_const_fingerprint = 10;

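  // Padding maps for dynamically shaped arguments (see PaddingMap in
  // tensorflow/core/protobuf/tpu/dynamic_padding.proto).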
  repeated tpu.PaddingMap padding_maps = 11;

  // The location of step markers that XLA compile will instrument.
  xla.DebugOptions.StepMarkerLocation step_marker_location = 12;

  // Minimum number of batches run through the XLA graph before the XLA fusion
  // autotuner is enabled. The default value of zero disables the autotuner.
  // The XLA fusion autotuner can improve performance by running a heuristic
  // search over the compiler parameters.
  int64 xla_fusion_autotuner_thresh = 13;

  // Enables the TPU compiler to add sharding policies for inputs/outputs to
  // the XLA computation for model parallelism.
  bool enable_automatic_model_parallelism = 14;
}
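
// Illustrative example (not part of the schema; all field values are
// assumptions): a minimal TPUCompileMetadataProto in text format for a
// single-replica, single-core computation with one float32 parameter and one
// return value might look like:
//
//   args {
//     dtype: DT_FLOAT
//     shape { dim { size: 128 } dim { size: 784 } }
//     kind: PARAMETER
//     sharding { type: MAXIMAL tile_assignment_devices: 0 }
//   }
//   retvals {
//     sharding { type: MAXIMAL tile_assignment_devices: 0 }
//   }
//   num_replicas: 1
//   num_cores_per_replica: 1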