syntax = "proto3";

package tensorflow.tpu;

import "google/protobuf/wrappers.proto";
import "tensorflow/compiler/xla/service/hlo.proto";

message ClippingLimits {
  google.protobuf.FloatValue lower = 1;  // -inf if not set
  google.protobuf.FloatValue upper = 2;  // +inf if not set
}
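
// Illustrative text-format snippet for ClippingLimits (the values are made up,
// not recommendations); leaving either wrapper field unset keeps the
// corresponding infinite default bound:
//
//   clipping_limits { lower { value: -1.0 } upper { value: 1.0 } }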

// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
// actual learning rates are provided as a scalar input list to the
// SendTPUEmbeddingGradients Op indexed by their tag specified through the
// following proto.
message DynamicLearningRate {
  // For tables where learning rates are dynamically computed and communicated
  // to the TPU embedding program, a tag must be specified for the learning
  // rate.
  //
  // The tag must be a non-negative integer. The total number of unique tags
  // must be less than or equal to the number of tables in the TPU embedding
  // configuration (a table does not specify any tag if it uses a constant
  // learning rate, and specifies exactly one tag if it uses dynamic learning
  // rates).
  //
  // All tags in the range [0, number_of_unique_tags) must be present in the TPU
  // embedding configuration, i.e. a tag cannot be skipped if a different tag
  // numerically greater than it is used in the configuration.
  //
  // If multiple tables specify the same tag, they *MUST* have
  // the same dynamic learning rate, for example, their dynamic learning rate
  // could be computed by the same TensorFlow sub-graph. The partitioning of the
  // embedding layer is more efficient when number_of_unique_tags is as
  // *LOW* as possible, i.e., when many tables share the same tag.
  //
  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
  // communicate dynamic learning rates to the TPU embedding program.
  // The learning_rate input is a list of scalars where the size of the list is
  // equal to the number of unique tags. The learning rate associated with a
  // particular tag is specified by populating its corresponding index in the
  // list of learning_rate scalars.
  int32 tag = 1;
}
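
// Example of the tag-to-scalar mapping described above (the table names are
// illustrative only): if tables A and B share one dynamically computed
// learning rate and table C uses another, A and B would specify tag 0, C would
// specify tag 1, and the learning_rate input of SendTPUEmbeddingGradients
// would be a list of two scalars, with index 0 carrying the A/B rate and
// index 1 carrying the C rate.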

// Source of learning rate to use.
message LearningRate {
  oneof learning_rate {
    float constant = 1;
    DynamicLearningRate dynamic = 2;
  }
}
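
// Illustrative text-format snippets for the two learning rate sources (the
// values are made up):
//
//   learning_rate { constant: 0.1 }        # fixed learning rate
//   learning_rate { dynamic { tag: 0 } }   # rate fed at runtime under tag 0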

// Each optimizer's parameter proto has a link to its documentation and CPU
// implementation (if available) for user reference.

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634
message AdagradParameters {
  // Old initial accumulator parameter.
  reserved "initial_accumulator";
  reserved 1;
}

// This optimizer combines the Adagrad and Momentum update rules.
// accum(new) = accum(old) + grad^2
// mom_accum(new) = momentum * mom_accum(old) + accum(new)^(-1.0 / exponent)
// update = use_nesterov ?
//          momentum * mom_accum(new) + accum(new)^(-1.0 / exponent) :
//          mom_accum(new)
// var(new) = var(old) - lr * grad * update
// Algorithm described in https://arxiv.org/abs/2002.11803.
message AdagradMomentumParameters {
  float momentum = 1;
  bool use_nesterov = 2;
  float exponent = 3;
}

// Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
message BoundedAdagradParameters {
  // Whether to use the updated or the old value of the accumulator when
  // computing the effective learning rate. When update_accumulator_first is set
  // to True, the updated value of the accumulator is used.
  bool update_accumulator_first = 1;
  // The max_var_update value to use. Set max_var_update to 0 (default) to
  // disable using max_var_update to clip the gradient.
  float max_var_update = 2;
  // The maximum value of the accumulator. Set max_accumulator to 0 (default)
  // to disable using max_accumulator to clip the accumulator.
  float max_accumulator = 3;
}
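
// Rough pseudocode sketch of how these fields interact, reconstructed from the
// field descriptions above (not taken from the TPU implementation, so treat it
// as an approximation only):
//
//   accum(new) = accum(old) + grad^2
//   if max_accumulator > 0: accum(new) = min(accum(new), max_accumulator)
//   a = update_accumulator_first ? accum(new) : accum(old)
//   update = lr * grad / sqrt(a)
//   if max_var_update > 0: update = clip(update, -max_var_update, max_var_update)
//   var(new) = var(old) - update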

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
message StochasticGradientDescentParameters {}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646
//
// The hyperparameters for FTRL are the same as for the Keras implementation,
// with some additions. The "beta" parameter matches the behavior described in
// the second link above; "beta" / (2 * learning rate) should be added to "l2"
// to get equivalent behavior in the other TensorFlow implementations of this
// optimizer. When the multiply_linear_by_lr field is set to true, a modified
// formula is used for FTRL that treats the "linear" accumulator as being
// pre-multiplied by the learning rate (i.e., the accumulator named "linear"
// actually stores "linear * learning_rate"). Other than checkpoint
// compatibility, this is mathematically equivalent for a static learning rate;
// for a dynamic learning rate, it is nearly the same as long as the learning
// rate does not change quickly. The benefit of setting multiply_linear_by_lr to
// true is that the modified formula handles zero and near-zero learning rates
// without producing NaNs, improving flexibility for learning rate ramp-up. The
// allow_zero_accumulator parameter changes some internal formulas to allow zero
// and near-zero accumulator values at the cost of some performance; this only
// needs to be set if you are using an initial accumulator value of zero, which
// is uncommon.
message FtrlParameters {
  float l1 = 1;
  float l2 = 2;
  float lr_power = 3;
  float beta = 7;
  bool multiply_linear_by_lr = 6;
  bool allow_zero_accumulator = 8;

  // Old initial accumulator parameters.
  reserved "initial_accum", "initial_linear";
  reserved 4, 5;
}
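
// Worked example of the "beta" note above (the values are illustrative): with
// a learning rate of 0.1 and beta = 0.2, matching behavior in the other
// TensorFlow FTRL implementations requires using
// l2 + beta / (2 * learning rate) = l2 + 0.2 / 0.2 = l2 + 1.0 as their "l2".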

// The Adam optimizer does not implement hyper-parameter update due to hardware
// limitations; use the dynamic learning rate feature instead, setting the
// learning rate to: user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
  float beta1 = 3;
  float beta2 = 4;
  float epsilon = 5;
  bool use_non_lazy_adam = 8;
  bool use_sum_inside_sqrt = 10;

  // Old initial accumulator parameters.
  reserved "initial_m", "initial_v";
  reserved 6, 7;
}
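
// Worked example of the dynamic learning rate formula above (illustrative
// numbers): with beta1 = 0.9, beta2 = 0.999, and t = 10,
//   sqrt(1 - 0.999^10) / (1 - 0.9^10) ~= sqrt(0.00996) / 0.6513 ~= 0.153,
// so the scalar fed through the dynamic learning rate mechanism at step 10
// would be roughly 0.153 * user learning_rate.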

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068
message MomentumParameters {
  float momentum = 1;
  bool use_nesterov = 2;

  // Old initial accumulator parameter.
  reserved "initial_accum";
  reserved 3;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229
message RmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters.
  reserved "initial_ms", "initial_mom";
  reserved 4, 5;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358
message CenteredRmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters.
  reserved "initial_ms", "initial_mom", "initial_mg";
  reserved 4, 5, 6;
}

// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;

  // Old initial accumulator parameters.
  reserved "initial_accumulator", "initial_weight", "initial_benefit";
  reserved 13, 14, 15;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933
message AdadeltaParameters {
  float rho = 1;
  float epsilon = 2;

  // Old initial accumulator parameters.
  reserved "initial_accumulator", "initial_update";
  reserved 3, 4;
}

// https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961
message ProximalAdagradParameters {
  float l1 = 1;
  float l2 = 2;

  // Old initial accumulator parameter.
  reserved "initial_accumulator";
  reserved 3;
}

// The online Yogi optimizer does not implement hyper-parameter update; use the
// dynamic learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of online Yogi.
message OnlineYogiParameters {
  // The L1 regularization parameter (used analogously to the one in FTRL).
  float l1 = 1;

  // The L2 regularization parameter (used analogously to the one in FTRL).
  float l2 = 2;

  // \beta_2 from Algorithm 2 in the paper.
  float beta2 = 3;

  // Reserved ids corresponding to removed tanh activation.
  reserved 6;  // sign
  reserved 7;  // tanh
}

// The proximal Yogi optimizer does not implement hyper-parameter update; use
// the dynamic learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of proximal Yogi.
message ProximalYogiParameters {
  // The L1 regularization parameter.
  float l1 = 1;

  // The L2 regularization parameter.
  float l2 = 2;

  // The exponential decay rate for the 1st moment estimates.
  float beta1 = 3;

  // The exponential decay rate for the 2nd moment estimates.
  float beta2 = 4;

  // A constant trading off adaptivity and noise.
  float epsilon = 5;

  // Reserved ids corresponding to removed tanh activation.
  reserved 8;  // sign
  reserved 9;  // tanh
}

// Estimator for the frequency of updates to a lookup table. It maintains an
// array (tf.Variable) D, where each element records the average number of
// global steps between two consecutive batches that hit the corresponding
// bucket. Once an item with bucket id i is sampled, D[i] is updated by:
//   D[i] <- D[i] * (1 - tau) + delta[i] * tau,
//
// where tau is a learning rate between 0 and 1 (exclusive), and
//   delta[i] = current global step - last step at which i was sampled.
//
// The estimated frequency (sampling rate in a batch) is thus 1 / D[i].
//
// Elements in D are initialized with a large value max_delta. delta[i] will
// also be capped by this value.
//
// The exact sequence of operations used in the optimizer is shown below.
// last_hit_step[i] is a tf.Variable that holds the last global step at which i
// was sampled.
//
//   delta = global_step - last_hit_step[i]
//   clipped_delta = min(delta, params.max_delta)
//   is_outlier = (delta >= params.outlier_threshold * D[i])
//   D[i] <- is_outlier ? clipped_delta
//                      : D[i] * (1 - params.tau) + clipped_delta * params.tau
//   last_hit_step[i] <- global_step
message FrequencyEstimatorParameters {
  // Learning rate in (0, 1) that is used to update the array D.
  float tau = 1;

  // Maximum value of delta: difference between the current global step and the
  // last global step at which the row was sampled.
  float max_delta = 2;

  // Threshold used to determine whether the current update is an outlier.
  float outlier_threshold = 3;

  // The weight exponent used to transform the estimated delta into weights.
  // The transformation function is: (delta / max_delta) ^ (weight_exponent)
  float weight_exponent = 4;
}
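
// Worked example of the update sequence above (illustrative numbers): with
// tau = 0.1, max_delta = 1000, outlier_threshold = 10, D[i] = 50,
// last_hit_step[i] = 900, and global_step = 1000:
//   delta = 100, clipped_delta = 100, is_outlier = (100 >= 10 * 50) = false,
//   D[i] <- 50 * 0.9 + 100 * 0.1 = 55,
// so the estimated frequency for bucket i becomes 1 / 55.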

// A user-defined optimizer.
// The contained HLO program must take the following arguments in the following
// order:
// 1.  gradients
// 2.  table weights
// 3.  slot variables
// 4.  an optional scalar input that is passed in via the dynamic learning
//     rate mechanism.
//
// It must return/end in a tuple op that contains the following values in the
// following order:
// 1.  new table values
// 2.  new slot variable values
//
// The program must have shape (1,1) with dtype float32 throughout and only use
// HLO ops that operate elementwise (e.g., no reduce, no variables, no control
// flow and no broadcasting outside of the single scalar input).
// The HLO program should be written as if it were a dense update. It will be
// called on each row that needs an update and will be applied elementwise.
message UserDefinedProgramParameters {
  xla.HloModuleProto program = 1;
  reserved 2;  // Was padding_values
}
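
// Rough sketch of what such a program might look like for a plain SGD-style
// update with no slot variables, written in HLO text form. This is an
// illustration only (the module name and instruction names are made up, and
// the exact text syntax accepted may vary between XLA versions):
//
//   HloModule UserDefinedSgd
//
//   ENTRY %main (grad: f32[1,1], weight: f32[1,1], lr: f32[1,1]) -> (f32[1,1]) {
//     %grad = f32[1,1] parameter(0)
//     %weight = f32[1,1] parameter(1)
//     %lr = f32[1,1] parameter(2)
//     %delta = f32[1,1] multiply(%grad, %lr)
//     %new_weight = f32[1,1] subtract(%weight, %delta)
//     ROOT %out = (f32[1,1]) tuple(%new_weight)
//   }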

// Optimizer that just sets the variable to the value of the gradient. To be
// correct, this requires either gradient accumulation (to sum the values of a
// computed expression across the samples) or deduplication of IDs within a
// single host (to assign the value from an arbitrary sample).
message AssignParameters {}

// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to apply
// them using the actual optimization algorithm). The extra message is to wrap
// the enum for scoping.
message GradientAccumulationStatus {
  // If UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
}

// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled by default.
message HotIdReplicationConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If UNSPECIFIED (default), hot ID optimization is DISABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
  Status status = 1;
}

message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. Almost all optimizers are supported with this option (MDL Adagrad
  // Light does not work, and SGD does not behave as expected if it is enabled).
  // Although there is no check, users who want weight decay will probably also
  // want to enable gradient accumulation so that the decay happens once per
  // minibatch.
  float weight_decay_factor = 16;

  // If true, the weight decay factor is multiplied by the current learning rate
  // before use; this is to match the note in DecoupledWeightDecayExtension in
  // weight_decay_optimizers.py.
  bool multiply_weight_decay_factor_by_learning_rate = 22;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm).
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Configuration proto for hot ID replication. This is an experimental
  // feature that is currently disabled by default.
  HotIdReplicationConfiguration hot_id_replication_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines which
  // algorithm to use.
  oneof parameters {
    AdagradParameters adagrad = 3;
    AdagradMomentumParameters adagrad_momentum = 26;
    BoundedAdagradParameters bounded_adagrad = 19;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
    OnlineYogiParameters online_yogi = 20;
    ProximalYogiParameters proximal_yogi = 21;
    FrequencyEstimatorParameters frequency_estimator = 23;
    UserDefinedProgramParameters user_defined_program = 24;
    AssignParameters assign = 25;
  }

  reserved 15;  // Old use_gradient_accumulation.
}
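
// Illustrative OptimizationParameters in text format (the numeric values are
// made up, not recommendations): Adagrad with a constant learning rate,
// gradient clipping, and gradient accumulation explicitly enabled.
//
//   learning_rate { constant: 0.05 }
//   gradient_clipping_limits { lower { value: -10.0 } upper { value: 10.0 } }
//   gradient_accumulation_status: ENABLED
//   adagrad {}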

// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops.
  message UserDefined {
    reserved 1;  // Was padding_initial_value.
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    double initial_value = 1;
  }

  // Usage type of this state variable.
  oneof usage {
    UserDefined user_defined = 2;
    FillWithConstant fill_with_constant = 3;
  }
}