1/*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17syntax = "proto3";
18
19package com.android.federatedcompute.proto;
20
21import "google/protobuf/any.proto";
22import "tensorflow/core/framework/tensor.proto";
23import "tensorflow/core/framework/tensor_shape.proto";
24import "tensorflow/core/framework/types.proto";
25import "tensorflow/core/protobuf/saver.proto";
26import "tensorflow/core/protobuf/struct.proto";
27
28option java_package = "com.android.federatedcompute.proto";
29option java_multiple_files = true;
30option java_outer_classname = "PlanProto";
31
32// Primitives
33// ===========
34
35// Represents an operation to save or restore from a checkpoint.  Some
36// instances of this message may only be used either for restore or for
37// save, others for both directions. This is documented together with
38// their usage.
39//
40// This op has four essential uses:
41//   1. read and apply a checkpoint.
42//   2. write a checkpoint.
43//   3. read and apply from an aggregated side channel.
44//   4. write to a side channel (grouped with write a checkpoint).
45// We should consider splitting this into four separate messages.
46message CheckpointOp {
47  // An optional standard saver def. If not provided, only the
48  // op(s) below will be executed. This must be a version 1 SaverDef.
49  tensorflow.SaverDef saver_def = 1;
50
51  // An optional operation to run before the saver_def is executed for
52  // restore.
53  string before_restore_op = 2;
54
55  // An optional operation to run after the saver_def has been
56  // executed for restore. If side_channel_tensors are provided, then
57  // they should be provided in a feed_dict to this op.
58  string after_restore_op = 3;
59
60  // An optional operation to run before the saver_def is executed for
61  // save.
62  string before_save_op = 4;
63
64  // An optional operation to run after the saver_def has been
65  // executed for save. If there are side_channel_tensors, this op
66  // should be run after the side_channel_tensors have been fetched.
67  string after_save_op = 5;
68
69  // In addition to being saved and restored from a checkpoint, one can
70  // also save and restore via a side channel. The keys in this map are
71  // the names of the tensors transmitted by the side channel. These (key)
72  // tensors should be read off just before saving via the SaverDef and used
73  // by the code that handles the side channel. Any variables provided this
74  // way should NOT be saved in the SaverDef.
75  //
76  // For restoring, the variables that are provided by the side channel
77  // are restored differently than those for a checkpoint. For those from
78  // the side channel, these should be restored by calling the before_restore_op
79  // with a feed dict whose keys are the restore_names in the SideChannel and
80  // whose values are the values to be restored.
81  map<string, SideChannel> side_channel_tensors = 6;
82
83  // An optional name of a tensor into which a unique token for the current
84  // session should be written.
85  //
86  // This session identifier allows TensorFlow ops such as `ServeSlices` or
87  // `ExternalDataset` to refer to callbacks and other session-global objects
88  // registered before running the session.
89  string session_token_tensor_name = 7;
90}
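
// As a purely illustrative sketch (all op and tensor names below are
// hypothetical), a CheckpointOp used in the restore direction with one side
// channel might look like this in textproto form:
//
//   saver_def {
//     filename_tensor_name: "save/Const:0"
//     restore_op_name: "save/restore_all"
//     version: V1
//   }
//   before_restore_op: "restore/prepare"
//   after_restore_op: "restore/finalize"
//   side_channel_tensors {
//     key: "update/weight_delta"
//     value {
//       secure_aggregand { dtype: DT_INT32 }
//       restore_name: "secagg/weight_delta"
//     }
//   }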
91
92message SideChannel {
93  // A side channel whose variables are processed via SecureAggregation.
94  // This side channel implements aggregation via sum over a set of
95  // clients, so the restored tensor will be a sum of multiple clients'
96  // inputs into the side channel. Hence this will restore during the
97  // read_aggregate_update restore, not the per-client read_update restore.
98  message SecureAggregand {
99    message Dimension {
100      int64 size = 1;
101    }
102
103    // Dimensions of the aggregand. This is used by the secure aggregation
104    // protocol in its early rounds, not as redundant info which could be
105    // obtained by reading the dimensions of the tensor itself.
106    repeated Dimension dimension = 3;
107
108    // The data type anticipated by the server-side graph.
109    tensorflow.DataType dtype = 4;
110
111    // SecureAggregation will compute sum modulo this modulus.
112    message FixedModulus {
113      uint64 modulus = 1;
114    }
115
116    // SecureAggregation will, for each shard, compute the sum modulo m, with m
117    // at least (1 + shard_size * (base_modulus - 1)), and then aggregate
118    // shard results with non-modular addition. Here, shard_size is the number
119    // of clients in the shard.
120    //
121    // Note that the modulus for each shard will be greater than the largest
122    // possible (non-modular) sum of the inputs to that shard.  That is,
123    // assuming each client has input in the range [0, base_modulus), the result
124    // will be identical to non-modular addition (i.e. federated_sum).
125    //
126    // While any m >= (1 + shard_size * (base_modulus - 1)) would suffice, the
127    // current implementation takes
128    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
129    // smallest possible value of m that is also a power of 2.  This choice is
130    // made because (a) it uses the same number of bits per vector entry as any
131    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
132    // it enables the underlying mask-generation PRNG to run in its most
133    // computationally efficient mode, which can be up to 2x faster.
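    //
    // As a worked example of the formula above (numbers chosen purely for
    // illustration): with base_modulus = 256 and shard_size = 100 clients, the
    // smallest valid modulus is 1 + 100 * 255 = 25501, and the implementation
    // rounds this up to m = 2**ceil(log_2(25501)) = 2**15 = 32768.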
134    message ModulusTimesShardSize {
135      uint64 base_modulus = 1;
136    }
137
138    oneof modulus_scheme {
139      // Bitwidth of the aggregand.
140      //
141      // This is the bitwidth of an input value (i.e. the bitwidth that
142      // quantization should target).  The Secure Aggregation bitwidth (i.e.,
143      // the bitwidth of the *sum* of the input values) will be a function of
144      // this bitwidth and the number of participating clients, as negotiated
145      // with the server when the protocol is initiated.
146      //
147      // Deprecated; prefer fixed_modulus instead.
148      int32 quantized_input_bitwidth = 2 [deprecated = true];
149
150      FixedModulus fixed_modulus = 5;
151      ModulusTimesShardSize modulus_times_shard_size = 6;
152    }
153
154    reserved 1;
155  }
156
157  // What type of side channel is used.
158  oneof type {
159    SecureAggregand secure_aggregand = 1;
160  }
161
162  // When restoring, the name of the tensor to restore to.  This is the name
163  // (key) supplied in the feed_dict in the before_restore_op in order to
164  // restore the tensor provided by the side channel (which will be the
165  // value in the feed_dict).
166  string restore_name = 2;
167}
168
169// Container for a metric used by the internal toolkit.
170message Metric {
171  // Name of an Op to run to read the value.
172  string variable_name = 1;
173
174  // A human-readable name for the statistic. Metric names are usually
175  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
176  // Must be 7-bit ASCII and under 122 characters.
177  string stat_name = 2;
178
179  // The human-readable name of another metric by which this metric should be
180  // normalized, if any. If empty, this Metric should be aggregated with simple
181  // summation. If not empty, the Metric is aggregated according to
182  // weighted_metric_sum = sum_i (metric_i * weight_i)
183  // weight_sum = sum_i weight_i
184  // average_metric_value = weighted_metric_sum / weight_sum
185  string weight_name = 3;
186}
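
// As a small worked example of the weighted aggregation above (values are
// purely illustrative): if two clients report metric values 0.5 and 1.0 with
// weights 10 and 30, then weighted_metric_sum = 0.5 * 10 + 1.0 * 30 = 35,
// weight_sum = 40, and average_metric_value = 35 / 40 = 0.875.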
187
188// Represents instructions for how metrics are to be output to users,
189// controlling the end format of the metrics users receive (e.g., whether they
190// are summed, averaged, or only sampled across clients).
191message OutputMetric {
192  // Metric name.
193  string name = 1;
194
195  oneof value_source {
196    // A metric representing one stat with aggregation type sum.
197    SumOptions sum = 2;
198
199    // A metric representing a ratio between metrics with aggregation
200    // type average.
201    AverageOptions average = 3;
202
203    // A metric that is not aggregated by the MetricReportAggregator or
204    // metrics_loader. This includes metrics like 'num_server_updates' that are
205    // aggregated in TensorFlow.
206    NoneOptions none = 4;
207
208    // A metric representing one stat with aggregation type only sample.
209    // Samples at most 101 clients' values.
210    OnlySampleOptions only_sample = 5;
211  }
212  // Iff True, the metric will be plotted in the default view of the
213  // task level Colab automatically.
214  oneof visualization_info {
215    bool auto_plot = 6 [deprecated = true];
216    VisualizationSpec plot_spec = 7;
217  }
218}
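
// An illustrative OutputMetric in textproto form (the metric and stat names
// are hypothetical):
//
//   name: "Loss"
//   average {
//     numerator_stat_name: "LossSum"
//     denominator_stat_name: "LossWeight"
//     average_stat_name: "Loss"
//     include_client_samples: true
//   }
//   plot_spec {
//     plot_type: LINE_PLOT
//     x_axis: "round"
//   }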
219
220message VisualizationSpec {
221  // Different allowable plot types.
222  enum VisualizationType {
223    NONE = 0;
224    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
225    LINE_PLOT = 2;
226    LINE_PLOT_WITH_PERCENTILES = 3;
227    HISTOGRAM = 4;
228  }
229
230  // Defines the plot type to provide downstream.
231  VisualizationType plot_type = 1;
232
233  // The x-axis to provide for the given metric. Must be the name of a
234  // metric or counter. Recommended x_axis options are source_round, round,
235  // or time.
236  string x_axis = 2;
237
238  // Iff True, metric will be displayed on a population level dashboard.
239  bool plot_on_population_dashboard = 3;
240}
241
242// A metric representing one stat with aggregation type sum.
243message SumOptions {
244  // Name for corresponding Metric stat_name field.
245  string stat_name = 1;
246
247  // Iff True, a cumulative sum over rounds will be provided in addition to a
248  // sum per round for the value metric.
249  bool include_cumulative_sum = 2;
250
251  // Iff True, a sample of at most 101 clients' values is included.
252  // Used to calculate quantiles in downstream visualization pipeline.
253  bool include_client_samples = 3;
254}
255
256// A metric representing a ratio between metrics with aggregation type average.
257// Represents: numerator stat / denominator stat.
258message AverageOptions {
259  // Numerator stat name pointing to corresponding Metric stat_name.
260  string numerator_stat_name = 1;
261
262  // Denominator stat name pointing to corresponding Metric stat_name.
263  string denominator_stat_name = 2;
264
265  // Name for corresponding Metric stat_name that is the ratio of the
266  // numerator stat / denominator stat.
267  string average_stat_name = 3;
268
269  // Iff True, a sample of at most 101 clients' values is included.
270  // Used to calculate quantiles in downstream visualization pipeline.
271  bool include_client_samples = 4;
272}
273
274// A metric representing one stat with aggregation type none.
275message NoneOptions {
276  // Name for corresponding Metric stat_name field.
277  string stat_name = 1;
278}
279
280// A metric representing one stat with aggregation type only sample.
281message OnlySampleOptions {
282  // Name for corresponding Metric stat_name field.
283  string stat_name = 1;
284}
285
286// Represents a data set. This is used for testing.
287message Dataset {
288  // Represents the data set for one client.
289  message ClientDataset {
290    // A string identifying the client.
291    string client_id = 1;
292
293    // A list of serialized tf.Example protos.
294    repeated bytes example = 2;
295
296    // Represents a dataset whose examples are selected by an ExampleSelector.
297    message SelectedExample {
298      ExampleSelector selector = 1;
299      repeated bytes example = 2;
300    }
301
302    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
303    // tasks* that require multiple datasets as client input, e.g., a TFF-based
304    // personalization eval task requires each client to provide at least two
305    // datasets: one for train, and the other for test.
306    repeated SelectedExample selected_example = 3;
307  }
308
309  // A list of client data.
310  repeated ClientDataset client_data = 1;
311}
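
// An illustrative test Dataset in textproto form (client ids and payloads are
// hypothetical; each `example` entry would hold a serialized tf.Example):
//
//   client_data {
//     client_id: "client_0"
//     example: "<serialized tf.Example bytes>"
//   }
//   client_data {
//     client_id: "client_1"
//     example: "<serialized tf.Example bytes>"
//   }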
312
313// Represents predicates over metrics - i.e., expectations. This is used in
314// training/eval tests to encode metric names and values expected to be reported
315// by a client execution.
316message MetricTestPredicates {
317  // The value must lie in [lower_bound; upper_bound]. Can also be used for
318  // approximate matching (lower == value - epsilon; upper = value + epsilon).
319  message Interval {
320    double lower_bound = 1;
321    double upper_bound = 2;
322  }
323
324  // The value must be a real value as long as the value of the weight_name
325  // metric is non-zero. If the weight metric is zero, then it is acceptable for
326  // the value to be non-real.
327  message RealIfNonzeroWeight {
328    string weight_name = 1;
329  }
330
331  message MetricCriterion {
332    // Name of the metric.
333    string name = 1;
334
335    // FL training round this metric is expected to appear in.
336    int32 training_round_index = 2;
337
338    // If none of the following is set, no matching is performed; but the
339    // metric is still expected to be present (with whatever value).
340    oneof Criterion {
341      // The reported metric must be < lt.
342      float lt = 3;
343      // The reported metric must be > gt.
344      float gt = 4;
345      // The reported metric must be <= le.
346      float le = 5;
347      // The reported metric must be >= ge.
348      float ge = 6;
349      // The reported metric must be == eq.
350      float eq = 7;
351      // The reported metric must lie in the interval.
352      Interval interval = 8;
353      // The reported metric is not NaN or +/- infinity.
354      bool real = 9;
355      // The reported metric is real (i.e., not NaN or +/- infinity) if the
356      // value of an associated weight is not 0.
357      RealIfNonzeroWeight real_if_nonzero_weight = 10;
358    }
359  }
360
361  repeated MetricCriterion metric_criterion = 1;
362
363  reserved 2;
364}
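
// An illustrative MetricTestPredicates in textproto form (metric names, the
// round index, and bounds are hypothetical):
//
//   metric_criterion {
//     name: "Accuracy"
//     training_round_index: 1
//     interval { lower_bound: 0.5 upper_bound: 1.0 }
//   }
//   metric_criterion {
//     name: "Loss"
//     training_round_index: 1
//     real: true
//   }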
365
366// Client Phase
367// ============
368
369// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
370// In federated optimization, this will correspond to one `ServerPhase`.
371message ClientPhase {
372  // A short CamelCase name for the ClientPhase.
373  string name = 2;
374
375  // Minimum number of clients in aggregation.
376  // In secure aggregation mode this is used to configure the protocol instance
377  // such that the server cannot learn aggregated values when the number of
378  // participants is lower than this number.
379  // Without secure aggregation, the server still respects this parameter,
380  // ensuring that aggregated values never leave server RAM unless they include
381  // data from at least the specified number of participants.
382  int32 minimum_number_of_participants = 3;
383
384  // If populated, `io_router` must be specified.
385  oneof spec {
386    // A functional interface for the TensorFlow logic the client should
387    // perform.
388    TensorflowSpec tensorflow_spec = 4 [lazy = true];
389    // Spec for client plans that issue example queries and send the query
390    // results directly to an aggregator with little or no additional
391    // processing.
392    ExampleQuerySpec example_query_spec = 9 [lazy = true];
393  }
394
395  // The specification of the inputs coming either from customer apps
396  // (Local Compute) or the federated protocol (Federated Compute).
397  oneof io_router {
398    FederatedComputeIORouter federated_compute = 5 [lazy = true];
399    LocalComputeIORouter local_compute = 6 [lazy = true];
400    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
401        [lazy = true];
402    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
403  }
404
405  reserved 1;
406}
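
// An illustrative ClientPhase that pairs a `TensorflowSpec` with the
// `FederatedComputeIORouter`, in textproto form (all names are hypothetical):
//
//   name: "TrainingPhase"
//   minimum_number_of_participants: 2
//   tensorflow_spec {
//     input_tensor_specs { name: "input_filepath:0" dtype: DT_STRING }
//     input_tensor_specs { name: "output_filepath:0" dtype: DT_STRING }
//     target_node_names: "run_training_and_write_update"
//   }
//   federated_compute {
//     input_filepath_tensor_name: "input_filepath:0"
//     output_filepath_tensor_name: "output_filepath:0"
//   }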
407
408// TensorflowSpec message describes a single call into TensorFlow, including the
409// expected input tensors that must be fed when making that call, which
410// output tensors to fetch, and any operations that have no output but must
411// be run. The TensorFlow session will then use the input tensors to do some
412// computation, generally reading from one or more datasets, and provide some
413// outputs.
414//
415// Conceptually, client or server code uses this proto along with an IORouter
416// to build maps of names to input tensors, vectors of output tensor names,
417// and vectors of target nodes:
418//
419//    CreateTensorflowArguments(
420//        TensorflowSpec& spec,
421//        IORouter& io_router,
422//        const vector<pair<string, Tensor>>* input_tensors,
423//        const vector<string>* output_tensor_names,
424//        const vector<string>* target_node_names);
425//
426// Where `input_tensors`, `output_tensor_names` and `target_node_names`
427// correspond to the arguments of the TensorFlow C++ API for
428// `tensorflow::Session::Run()`, and the client executes only a single
429// invocation.
430//
431// Note: the execution engine never sees any concepts related to the federated
432// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
433// in, tensors out" interface. New aggregation methods can be added without
434// having to modify the execution engine / TensorflowSpec message, instead they
435// should modify the IORouter messages.
436//
437// Note: both `input_tensor_specs` and `output_tensor_specs` are full
438// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
439// only requires the names to feed the values into the session. The additional
440// dtypes/shape information must always be included in case the runtime
441// executing this TensorflowSpec wants to perform additional, optional static
442// assertions. The runtimes however are free to ignore the dtype/shapes and only
443// rely on the names if so desired.
444//
445// Assertions:
446//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
447//     `target_node_names` must appear in the serialized GraphDef where
448//     the TF execution will be invoked.
449//   - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
450//     there is nothing to execute in the graph.
451message TensorflowSpec {
452  // The name of a tensor into which a unique token for the current session
453  // should be written. The corresponding tensor is a scalar string tensor and
454  // is separate from `input_tensors` as there is only one.
455  //
456  // A session token allows TensorFlow ops such as `ServeSlices` or
457  // `ExternalDataset` to refer to callbacks and other session-global objects
458  // registered before running the session. In the `ExternalDataset` case, a
459  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
460  // the token can be thought of as a handle to a dataset factory.
461  string dataset_token_tensor_name = 1;
462
463  // TensorSpecs of inputs which will be passed to TF.
464  //
465  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
466  // TensorFlow's Python API, excluding the dataset_token listed above.
467  //
468  // Assertions:
469  //   - All the tensor names designated as inputs in the corresponding IORouter
470  //     must be listed (otherwise the IORouter input work is unused).
471  //   - All placeholders in the TF graph must be listed here, with the
472  //     exception of the dataset_token which is explicitly set above (otherwise
473  //     TensorFlow will fail to execute).
474  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;
475
476  // TensorSpecs that should be fetched from TF after execution.
477  //
478  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
479  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
480  // API.
481  //
482  // Assertions:
483  //   - The set of tensor names here must strictly match the tensor names
484  //     designated as outputs in the corresponding IORouter (if any exist).
485  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;
486
487  // Node names in the graph that should be executed, but whose output is not
488  // returned.
489  //
490  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
491  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
492  // API.
493  //
494  // This is intended for use with operations that do not produce tensors, but
495  // nonetheless are required to run (e.g. serializing checkpoints).
496  repeated string target_node_names = 4;
497
498  // Map of Tensor names to constant inputs.
499  // Note: tensors specified via this message should not be included in
500  // input_tensor_specs.
501  map<string, tensorflow.TensorProto> constant_inputs = 5;
502}
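
// An illustrative TensorflowSpec for a session that is fed one scalar input,
// fetches one output tensor, and runs one op with no output, in textproto
// form (all names are hypothetical):
//
//   dataset_token_tensor_name: "dataset_token:0"
//   input_tensor_specs { name: "num_epochs:0" dtype: DT_INT32 }
//   output_tensor_specs { name: "loss:0" dtype: DT_FLOAT }
//   target_node_names: "write_update_checkpoint"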
503
504// ExampleQuerySpec message describes client execution that issues example
505// queries and sends the query results directly to an aggregator with little
506// or no additional processing.
507// This message describes one or more example store queries that perform the
508// client side analytics computation in C++. The corresponding output vectors
509// will be converted into the expected federated protocol output format.
510// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
511message ExampleQuerySpec {
512  message OutputVectorSpec {
513    // The output vector name.
514    string vector_name = 1;
515
516    // Supported data types for the vector of information.
517    enum DataType {
518      UNSPECIFIED = 0;
519      INT32 = 1;
520      INT64 = 2;
521      BOOL = 3;
522      FLOAT = 4;
523      DOUBLE = 5;
524      BYTES = 6;
525      STRING = 7;
526    }
527
528    // The data type for each entry in the vector.
529    DataType data_type = 2;
530  }
531
532  message ExampleQuery {
533    // The `ExampleSelector` to issue the query with.
534    ExampleSelector example_selector = 1;
535
536    // Indicates that the query returns vector data and must return a single
537    // ExampleQueryResult result containing a VectorData entry matching each
538    // OutputVectorSpec.vector_name.
539    //
540    // If the query instead returns no result, then it will be treated as if
541    // an error was returned. In that case, or if the query explicitly returns
542    // an error, then the client will abort its session.
543    //
544    // The keys in the map are the names the vectors should be aggregated under,
545    // and must match the keys in FederatedExampleQueryIORouter.aggregations.
546    map<string, OutputVectorSpec> output_vector_specs = 2;
547  }
548
549  // The queries to run.
550  repeated ExampleQuery example_queries = 1;
551}
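
// An illustrative ExampleQuerySpec in textproto form (the collection URI and
// vector names are hypothetical); the keys of `output_vector_specs` must match
// the keys in `FederatedExampleQueryIORouter.aggregations`:
//
//   example_queries {
//     example_selector {
//       collection_uri: "app:/someCollection"
//     }
//     output_vector_specs {
//       key: "event_counts"
//       value { vector_name: "event_counts" data_type: INT64 }
//     }
//   }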
552
553// The input and output router for Federated Compute plans.
554//
555// This proto is the glue between the federated protocol and the TensorFlow
556// execution engine. This message describes how to prepare data coming from the
557// incoming `CheckinResponse` (defined in
558// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
559// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
560// the server).
561//
562// TODO(team) we could replace `input_checkpoint_file_tensor_name` with
563// an `input_tensors` field, which would then be a tensor that contains the
564// input TensorProtos directly, skipping disk I/O, rather than referring to a
565// checkpoint file path.
566message FederatedComputeIORouter {
567  // ===========================================================================
568  //   Inputs
569  // ===========================================================================
570  // The name of the scalar string tensor that is fed the file path to the
571  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
572  //
573  // The federated protocol code would copy the `CheckinResponse`'s initial
574  // checkpoint to a temporary file and then pass that file path through this
575  // tensor.
576  //
577  // Ops may be added to the client graph that take this tensor as input and
578  // read the path.
579  //
580  // This field is optional. It may be omitted if the client graph does not use
581  // an initial checkpoint.
582  string input_filepath_tensor_name = 1;
583
584  // The name of the scalar string tensor that is fed the file path to which
585  // client work should serialize the bytes to send back to the server.
586  //
587  // The federated protocol code generates a temporary file and passes the file
588  // path through this tensor.
589  //
590  // Ops may be added to the client graph that use this tensor as an argument
591  // to write files (e.g. writing checkpoints to disk).
592  //
593  // This field is optional. It must be omitted if the client graph does not
594  // generate any output files (e.g. when all output tensors of `TensorflowSpec`
595  // use Secure Aggregation). If this field is not set, then the `ReportRequest`
596  // message in the federated protocol will not have the
597  // `Report.update_checkpoint` field set. This absence of a value here can be
598  // used to validate that the plan only uses Secure Aggregation.
599  //
600  // Conversely, if this field is set but executing the associated
601  // TensorflowSpec does not write to the path, that is an indication of an
602  // internal framework error. The runtime should notify the caller that the
603  // computation was set up incorrectly.
604  string output_filepath_tensor_name = 2;
605
606  // ===========================================================================
607  //   Outputs
608  // ===========================================================================
609  // Describes which output tensors should be aggregated using an aggregation
610  // protocol, and the configuration for those protocols.
611  //
612  // Assertions:
613  //   - All keys must exist in the associated `TensorflowSpec` as
614  //     `output_tensor_specs.name` values.
615  map<string, AggregationConfig> aggregations = 3;
616}
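
// An illustrative FederatedComputeIORouter in textproto form (tensor names are
// hypothetical). Here the single output tensor is summed with Secure
// Aggregation, so `output_filepath_tensor_name` is deliberately omitted:
//
//   input_filepath_tensor_name: "input_filepath:0"
//   aggregations {
//     key: "secagg_update:0"
//     value { secure_aggregation {} }
//   }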
617
618// The input and output router for client plans that do not use TensorFlow.
619//
620// This proto is the glue between the federated protocol and the example query
621// execution engine, describing how the query results should ultimately be
622// aggregated.
623message FederatedExampleQueryIORouter {
624  // Describes how each output vector should be aggregated using an aggregation
625  // protocol, and the configuration for those protocols.
626  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
627  // Note that currently only the TFV1CheckpointAggregation config is supported.
628  map<string, AggregationConfig> aggregations = 1;
629}
630
631// The specification for how to aggregate the associated tensor across clients
632// on the server.
633message AggregationConfig {
634  oneof protocol_config {
635    // Indicates that the given output tensor should be processed using Secure
636    // Aggregation, using the specified config options.
637    SecureAggregationConfig secure_aggregation = 2;
638
639    // Note: in the future we could add a `SimpleAggregationConfig` to add
640    // support for simple aggregation without writing to an intermediate
641    // checkpoint file first.
642
643    // Indicates that the given output tensor or vector (e.g. as produced by an
644    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
645    //
646    // Currently only ExampleQuerySpec output vectors are supported by this
647    // aggregation type (i.e. it cannot be used with TensorflowSpec output
648    // tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
649    // its corresponding data type.
650    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
651  }
652}
653
654// Parameters for the SecAgg protocol (go/secagg).
655//
656// Currently only the server uses the SecAgg parameters, so we only use this
657// message to signify usage of SecAgg.
658message SecureAggregationConfig {}
659
660// Parameters for the TFV1 Checkpoint Aggregation protocol.
661//
662// Currently only ExampleQuerySpec output vectors are supported by this
663// aggregation type (i.e. it cannot be used with TensorflowSpec output
664// tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
665// its corresponding data type.
666message TFV1CheckpointAggregation {}
667
668// The input and output router for eligibility-computing plans. These plans
669// compute which other plans a client is eligible to run, and are returned to
670// clients via an `EligibilityEvalCheckinResponse` (defined in
671// fcp/protos/federated_api.proto).
672message FederatedComputeEligibilityIORouter {
673  // The name of the scalar string tensor that is fed the file path to the
674  // initial checkpoint (e.g. as provided via
675  // `EligibilityEvalPayload.init_checkpoint`).
676  //
677  // For more detail see the
678  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
679  // semantics.
680  //
681  // This field is optional. It may be omitted if the client graph does not use
682  // an initial checkpoint.
683  //
684  // This tensor name must exist in the associated
685  // `TensorflowSpec.input_tensor_specs` list.
686  string input_filepath_tensor_name = 1;
687
688  // Name of the output tensor (a string scalar) containing the serialized
689  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
690  // client code will parse this proto and place it in the
691  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
692  //
693  // This tensor name must exist in the associated
694  // `TensorflowSpec.output_tensor_specs` list.
695  string task_eligibility_info_tensor_name = 2;
696}
697
698// The input and output router for Local Compute plans.
699//
700// This proto is the glue between the customer's app and the TensorFlow
701// execution engine. This message describes how to prepare data coming from the
702// customer app (e.g. the input directory the app set up), and the temporary,
703// scratch output directory that will be reported back to the customer app upon
704// completion of `TensorflowSpec`.
705message LocalComputeIORouter {
706  // ===========================================================================
707  //   Inputs
708  // ===========================================================================
709  // The name of the placeholder tensor representing the input resource path(s).
710  // It can be a single input directory or file path (in this case the
711  // `input_dir_tensor_name` is populated) or multiple input resources
712  // represented as a map from names to input directories or file paths (in this
713  // case the `multiple_input_resources` is populated).
714  //
715  // In the multiple input resources case, the placeholder tensors are
716  // represented as a map: the keys are the input resource names defined by the
717  // users when constructing the `LocalComputation` Python object, and the
718  // values are the corresponding placeholder tensor names created by the local
719  // computation plan builder.
720  //
721  // Apps will have the ability to create contracts between their Android code
722  // and `LocalComputation` toolkit code to place files inside the input
723  // resource paths with known names (Android code) and create graphs with ops
724  // to read from these paths (file names can be specified in toolkit code).
725  oneof input_resource {
726    string input_dir_tensor_name = 1;
727    // Directly using the `map` field is not allowed in `oneof`, so we have to
728    // wrap it in a new message.
729    MultipleInputResources multiple_input_resources = 3;
730  }
731
732  // Scalar string tensor name that will contain the output directory path.
733  //
734  // The provided directory should be considered temporary scratch that will be
735  // deleted, not persisted. It is the responsibility of the calling app to
736  // move the desired files to a permanent location once the client returns this
737  // directory back to the calling app.
738  string output_dir_tensor_name = 2;
739
740  // ===========================================================================
741  //   Outputs
742  // ===========================================================================
743  // NOTE: LocalCompute has no outputs other than what the client graph writes
744  // to `output_dir` specified above.
745}
746
747// Describes the multiple input resources in `LocalComputeIORouter`.
748message MultipleInputResources {
749  // The keys are the input resource names (defined by the users when
750  // constructing the `LocalComputation` Python object), and the values are the
751  // corresponding placeholder tensor names created by the local computation
752  // plan builder.
753  map<string, string> input_resource_tensor_name_map = 1;
754}
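
// An illustrative LocalComputeIORouter that uses multiple input resources, in
// textproto form (resource and tensor names are hypothetical):
//
//   multiple_input_resources {
//     input_resource_tensor_name_map {
//       key: "train_data"
//       value: "train_data_path:0"
//     }
//     input_resource_tensor_name_map {
//       key: "test_data"
//       value: "test_data_path:0"
//     }
//   }
//   output_dir_tensor_name: "output_dir_path:0"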
755
756// Describes a queue to which input is fed.
757message AsyncInputFeed {
758  // The op for enqueuing an example input.
759  string enqueue_op = 1;
760
761  // The input placeholders for the enqueue op.
762  repeated string enqueue_params = 2;
763
764  // The op for closing the input queue.
765  string close_op = 3;
766
767  // Whether the work that should be fed asynchronously is the data itself
768  // or a description of where that data lives.
769  bool feed_values_are_data = 4;
770}
771
772message DatasetInput {
773  // Initializer of iterator corresponding to tf.data.Dataset object which
774  // handles the input data. Stores name of an op in the graph.
775  string initializer = 1;
776
777  // Placeholders necessary to initialize the dataset.
778  DatasetInputPlaceholders placeholders = 2;
779
780  // Batch size to be used in tf.data.Dataset.
781  int32 batch_size = 3;
782}
783
784message DatasetInputPlaceholders {
785  // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
786  // from.
787  string filename = 1;
788
789  // Name of placeholder corresponding to key_prefix initializing the
790  // SSTableDataset. Note the value fed should be a unique user id, not a prefix.
791  string key_prefix = 2;
792
793  // Name of placeholder corresponding to number of epochs the local training
794  // should be run for.
795  string num_epochs = 3;
796
797  // Name of placeholder corresponding to batch size.
798  string batch_size = 4;
799}
800
801// Specifies an example selection procedure.
802message ExampleSelector {
803  // Selection criteria following a contract agreed upon between client and
804  // model designers.
805  google.protobuf.Any criteria = 1;
806
807  // A URI identifying the example collection to read from. Format should adhere
808  // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments
809  // should adhere to the following rules:
810  // - The scheme ${COLLECTION} should be one of:
811  //   - "app" for app-hosted example collections
812  //   - "simulation" for collections not connected to an app (e.g., if used
813  //     purely for simulation)
814  // - The authority ${APP_NAME} identifies the owner of the example
815  //   collection and should be either the app's package name, or be left empty
816  //   (which means "the current app package name").
817  // - The path ${COLLECTION_NAME} can be any valid URI path. NB It starts with
818  //   a forward slash ("/").
819  // - The query and fragment are currently not used, but they may be used
820  //   for something in the future. To keep that possibility open, they must
821  //   currently be left empty.
822  //
823  // Example: "app://com.google.some.app/someCollection/name"
824  // identifies the collection "/someCollection/name" owned and hosted by the
825  // app with package name "com.google.some.app".
826  //
827  // Example: "app:/someCollection/name" or "app:///someCollection/name"
828  // both identify the collection "/someCollection/name" owned and hosted by the
829  // app associated with the training job in which this URI appears.
830  //
831  // The path will not be interpreted by the runtime, and will be passed to the
832  // example collection implementation for interpretation. Thus, in the case of
833  // app-hosted example stores, the path segment's interpretation is a contract
834  // between the app's example store developers, and the app's model designers.
835  //
836  // If an `app://` URI is set, then the `TrainerOptions` collection name must
837  // not be set.
838  string collection_uri = 2;
839
840  // Resumption token following a contract agreed upon between client and
841  // model designers.
842  google.protobuf.Any resumption_token = 3;
843}
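
// An illustrative ExampleSelector in textproto form (the package name and
// collection path are hypothetical):
//
//   collection_uri: "app://com.example.app/someCollection"
//
// where `criteria` and `resumption_token` would carry whatever serialized
// protos the app's example store and the model designers have agreed upon.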
844
845// Selector for slices to fetch as part of a `federated_select` operation.
846message SlicesSelector {
847  // The string ID under which the slices are served.
848  //
849  // This value must have been returned by a previous call to the `serve_slices`
850  // op run during the `write_client_init` operation.
851  string served_at_id = 1;
852
853  // The indices of slices to fetch.
854  repeated int32 keys = 2;
855}
856
857// Represents slice data to be served as part of a `federated_select` operation.
858// This is used for testing.
859message SlicesTestDataset {
860  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
861  // field.  E.g. test slice data for a slice with `served_at_id`="foo" and
862  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
863  map<string, SlicesTestData> dataset = 1;
864}
865message SlicesTestData {
866  // The test slice data to serve. Each entry's index corresponds to the slice
867  // key it is the test data for.
868  repeated bytes slice_data = 2;
869}
870
871// Server Phase V2
872// ===============
873
874// Represents a server phase with three distinct components: pre-broadcast,
875// aggregation, and post-aggregation.
876//
877// The pre-broadcast and post-aggregation components are described with
878// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
879// messages, respectively. These messages in combination with the server
880// IORouter messages specify how to set up a single TF sess.run call for each
881// component.
882//
883// The pre-broadcast logic is obtained by transforming the server_prepare TFF
884// computation in the DistributeAggregateForm. It takes the server state as
885// input, and it generates the checkpoint to broadcast to the clients and
886// potentially an intermediate server state. The intermediate server state may
887// be used by the aggregation and post-aggregation logic.
888//
889// The aggregation logic represents the aggregation of client results at the
890// server and is described using a list of ServerAggregationConfig messages.
891// Each ServerAggregationConfig message describes a single aggregation operation
892// on a set of input/output tensors. The input tensors may represent parts of
893// either the client results or the intermediate server state. These messages
894// are obtained by transforming the client_to_server_aggregation TFF computation
895// in the DistributeAggregateForm.
896//
897// The post-aggregation logic is obtained by transforming the server_result TFF
898// computation in the DistributeAggregateForm. It takes the intermediate server
899// state and the aggregated client results as input, and it generates the new
900// server state and potentially other server-side output.
901//
902// Note that while a ServerPhaseV2 message can be generated for all types of
903// intrinsics, it is currently only compatible with the ClientPhase message if
904// the aggregations being used are exclusively federated_sum (not SecAgg). If
905// this compatibility requirement is satisfied, it is also valid to run the
906// aggregation portion of this ServerPhaseV2 message alongside the pre- and
907// post-aggregation logic from the original ServerPhase message. Ultimately,
908// we expect the full ServerPhaseV2 message to be run and the ServerPhase
909// message to be deprecated.
910message ServerPhaseV2 {
911  // A short CamelCase name for the ServerPhaseV2.
912  string name = 1;
913
914  // A functional interface for the TensorFlow logic the server should perform
915  // prior to the server-to-client broadcast. This should be used with the
916  // TensorFlow graph defined in server_graph_prepare_bytes.
917  TensorflowSpec tensorflow_spec_prepare = 3;
918
919  // The specification of inputs needed by the server_prepare TF logic.
920  oneof server_prepare_io_router {
921    ServerPrepareIORouter prepare_router = 4;
922  }
923
924  // A list of client-to-server aggregations to perform.
925  repeated ServerAggregationConfig aggregations = 2;
926
927  // A functional interface for the TensorFlow logic the server should perform
928  // post-aggregation. This should be used with the TensorFlow graph defined
929  // in server_graph_result_bytes.
930  TensorflowSpec tensorflow_spec_result = 5;
931
932  // The specification of inputs and outputs needed by the server_result TF
933  // logic.
934  oneof server_result_io_router {
935    ServerResultIORouter result_router = 6;
936  }
937}
938
939// Routing for server_prepare graph
940message ServerPrepareIORouter {
941  // The name of the scalar string tensor in the server_prepare TF graph that
942  // is fed the filepath to the initial server state checkpoint. The
943  // server_prepare logic reads from this filepath.
944  string prepare_server_state_input_filepath_tensor_name = 1;
945
946  // The name of the scalar string tensor in the server_prepare TF graph that
947  // is fed the filepath where the client checkpoint should be stored. The
948  // server_prepare logic writes to this filepath.
949  string prepare_output_filepath_tensor_name = 2;
950
951  // The name of the scalar string tensor in the server_prepare TF graph that
952  // is fed the filepath where the intermediate state checkpoint should be
953  // stored. The server_prepare logic writes to this filepath. The intermediate
954  // state checkpoint will be consumed by both the logic used to set parameters
955  // for aggregation and the post-aggregation logic.
956  string prepare_intermediate_state_output_filepath_tensor_name = 3;
957}
958
959// Routing for server_result graph
960message ServerResultIORouter {
961  // The name of the scalar string tensor in the server_result TF graph that is
962  // fed the filepath to the intermediate state checkpoint. The server_result
963  // logic reads from this filepath.
964  string result_intermediate_state_input_filepath_tensor_name = 1;
965
966  // The name of the scalar string tensor in the server_result TF graph that is
967  // fed the filepath to the aggregated client result checkpoint. The
968  // server_result logic reads from this filepath.
969  string result_aggregate_result_input_filepath_tensor_name = 2;
970
971  // The name of the scalar string tensor in the server_result TF graph that is
972  // fed the filepath where the updated server state should be stored. The
973  // server_result logic writes to this filepath.
974  string result_server_state_output_filepath_tensor_name = 3;
975}
976
977// Represents a single aggregation operation, combining one or more input
978// tensors from a collection of clients into one or more output tensors on the
979// server.
980message ServerAggregationConfig {
981  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
982  string intrinsic_uri = 1;
983
984  // Describes an argument to the aggregation operation.
985  message IntrinsicArg {
986    oneof arg {
987      // Refers to a tensor within the checkpoint provided by each client.
988      tensorflow.TensorSpecProto input_tensor = 2;
989
990      // Refers to a tensor within the intermediate server state checkpoint.
991      tensorflow.TensorSpecProto state_tensor = 3;
992    }
993  }
994
995  // List of arguments for the aggregation operation. The arguments can be
996  // dependent on client data (in which case they must be retrieved from
997  // clients) or they can be independent of client data (in which case they
998  // can be configured server-side). For now we assume all client-independent
999  // arguments are constants. The arguments must be in the order expected by
1000  // the server.
1001  repeated IntrinsicArg intrinsic_args = 4;
1002
1003  // List of server-side outputs produced by the aggregation operation.
1004  repeated tensorflow.TensorSpecProto output_tensors = 5;
1005}
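
// An illustrative ServerAggregationConfig that sums a single client-supplied
// tensor, in textproto form (tensor names are hypothetical; 'federated_sum'
// follows the example URI mentioned above):
//
//   intrinsic_uri: "federated_sum"
//   intrinsic_args {
//     input_tensor { name: "update/weights" dtype: DT_FLOAT }
//   }
//   output_tensors { name: "aggregated/weights" dtype: DT_FLOAT }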
1006
1007// Server Phase
1008// ============
1009
1010// Represents a server phase which implements TF-based aggregation of multiple
1011// client updates.
1012//
1013// There are two different modes of aggregation that are described
1014// by the values in this message. The first is aggregation that is
1015// coming from coordinated sets of clients. This includes aggregation
1016// done via checkpoints from clients or aggregation done over a set
1017// of clients by a process like secure aggregation. The results of
1018// this first aggregation are saved to intermediate aggregation
1019// checkpoints. The second aggregation then comes from taking
1020// these intermediate checkpoints and aggregating over them.
1021//
1022// These two different modes of aggregation are done on different
1023// servers, the first in the 'L1' servers and the second in the
1024// 'L2' servers, so we use this nomenclature to describe these
1025// phases below.
1026//
1027// The ServerPhase message is currently in the process of being replaced by the
1028// ServerPhaseV2 message as we switch the plan building pipeline to use
1029// DistributeAggregateForm instead of MapReduceForm. During the migration
1030// process, we may generate both messages and use components from either
1031// message during execution.
1032//
1033message ServerPhase {
1034  // A short CamelCase name for the ServerPhase.
1035  string name = 8;
1036
1037  // ===========================================================================
1038  // L1 "Intermediate" Aggregation.
1039  //
1040  // This is the initial aggregation that creates partial aggregates from client
1041  // results. L1 Aggregation may be run on many different instances.
1042  //
1043  // Pre-condition:
1044  //   The execution environment has loaded the graph from `server_graph_bytes`.
1045
1046  // 1. Initialize the phase.
1047  //
1048  // Operation to run before the first aggregation happens.
1049  // For instance, clears the accumulators so that a new aggregation can begin.
1050  string phase_init_op = 1;
1051
1052  // 2. For each client in set of clients:
1053  //   a. Restore variables from the client checkpoint.
1054  //
1055  // Loads a checkpoint from a single client written via
1056  // `FederatedComputeIORouter.output_filepath_tensor_name`.  This is done once
1057  // for every client checkpoint in a round.
1058  CheckpointOp read_update = 3;
1059  //   b. Aggregate the data coming from the client checkpoint.
1060  //
1061  // An operation that aggregates the data from read_update.
1062  // Generally this will add to accumulators and it may leverage internal data
1063  // inside the graph to adjust the weights of the Tensors.
1064  //
1065  // Executed once for each `read_update`, to (for example) update accumulator
1066  // variables using the values loaded during `read_update`.
1067  string aggregate_into_accumulators_op = 4;
1068
1069  // 3. After all clients have been aggregated, possibly restore
1070  //    variables that have been aggregated via a separate process.
1071  //
1072  // Optionally restores variables where aggregation is done across
1073  // an entire round of client data updates. In contrast to `read_update`,
1074  // which restores once per client, this occurs after all clients
1075  // in a round have been processed. This allows, for example, side
1076  // channels where aggregation is done by a separate process (such
1077  // as in secure aggregation), in which the side channel aggregated
1078  // tensor is passed to the `before_restore_op`, which ensures the
1079  // variables are restored properly. The `after_restore_op` will then
1080  // be responsible for performing the accumulation.
1081  //
1082  // Note that in current use this should not have a SaverDef, but
1083  // should only be used for side channels.
1084  CheckpointOp read_aggregated_update = 10;
1085
1086  // 4. Write the aggregated variables to an intermediate checkpoint.
1087  //
1088  // We require that `aggregate_into_accumulators_op` is associative and
1089  // commutative, so that the aggregates can be computed across
1090  // multiple TensorFlow sessions.
1091  // As an example, say we are computing the sum of 5 client updates:
1092  //    A = X1 + X2 + X3 + X4 + X5
1093  // We can always do this in one session by calling `read_update` and
1094  // `aggregate_into_accumulators_op` once for each client checkpoint.
1095  //
1096  // Alternatively, we could compute:
1097  //    A1 = X1 + X2 in one TensorFlow session, and
1098  //    A2 = X3 + X4 + X5 in a different session.
1099  // Each of these sessions can then write their accumulator state
1100  // with the `write_intermediate_update` CheckpointOp, and yet another (third)
1101  // session can then call `read_intermediate_update` and
1102  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
1103  //    A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
1104  CheckpointOp write_intermediate_update = 7;
1105  // End L1 "Intermediate" Aggregation.
1106  // ===========================================================================
1107
1108  // ===========================================================================
1109  // L2 Aggregation and Coordinator.
1110  //
1111  // This aggregates intermediate checkpoints from L1 Aggregation and performs
1112  // the finalizing of the update. Unlike L1 there will only be one instance
1113  // that does this aggregation.
1114
1115  // Pre-condition:
1116  //   The execution environment has loaded the graph from `server_graph_bytes`
1117  //   and restored the global model using `server_savepoint` from the parent
1118  //   `Plan` message.
1119
1120  // 1. Initialize the phase.
1121  //
1122  // This currently re-uses the `phase_init_op` from L1 aggregation above.
1123
1124  // 2. Write a checkpoint that can be sent to the client.
1125  //
1126  // Generates a checkpoint to be sent to the client, to be read by
1127  // `FederatedComputeIORouter.input_filepath_tensor_name`.
1128
1129  CheckpointOp write_client_init = 2;
1130
1131  // 3. For each intermediate checkpoint:
1132  //   a. Restore variables from the intermediate checkpoint.
1133  //
1134  // The read CheckpointOp corresponding to `write_intermediate_update`.
1135  // This is used instead of read_update for intermediate checkpoints because
1136  // the format of these updates may differ from that of the updates received
1137  // from clients (which may, for example, be compressed).
1138  CheckpointOp read_intermediate_update = 9;
1139  //   b. Aggregate the data coming from the intermediate checkpoint.
1140  //
1141  // An operation that aggregates the data from `read_intermediate_update`.
1142  // Generally this will add to accumulators and it may leverage internal data
1143  // inside the graph to adjust the weights of the Tensors.
1144  string intermediate_aggregate_into_accumulators_op = 11;
1145
1146  // 4. Write the aggregated intermediate variables to a checkpoint.
1147  //
1148  // This is used for downstream, cross-round aggregation of metrics.
1149  // These variables will be read back into a session with
1150  // read_intermediate_update.
1151  //
1152  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
1153  // to disable writing accumulator checkpoints.
1154  CheckpointOp write_accumulators = 12;
1155
1156  // 5. Finalize the round.
1157  //
1158  // This can include:
1159  // - Applying the update aggregated from the intermediate checkpoints to the
1160  //   global model and other updates to cross-round state variables.
1161  // - Computing final round metric values (e.g. the `report` of a
1162  //   `tff.federated_aggregate`).
1163  string apply_aggregrated_updates_op = 5;
1164
1165  // 6. Fetch the server aggregated metrics.
1166  //
1167  // A list of names of metric variables to fetch from the TensorFlow session.
1168  repeated Metric metrics = 6;
1169
1170  // 7. Serialize the updated server state (e.g. the coefficients of the global
1171  //    model in FL) using `server_savepoint` in the parent `Plan` message.
1172
1173  // End L2 Aggregation.
1174  // ===========================================================================
1175}
1176
1177// Represents the server phase in an eligibility computation.
1178//
1179// This phase produces a checkpoint to be sent to clients. This checkpoint is
1180// then used as an input to the clients' task eligibility computations.
1181// This phase *does not include any aggregation.*
1182message ServerEligibilityComputationPhase {
1183  // A short CamelCase name for the ServerEligibilityComputationPhase.
1184  string name = 1;
1185
1186  // The names of the TensorFlow nodes to run in order to produce output.
1187  repeated string target_node_names = 2;
1188
1189  // The specification of inputs and outputs to the TensorFlow graph.
1190  oneof server_eligibility_io_router {
1191    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
1192  }
1193}
1194
1195// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
1196// which takes a single `TaskEligibilityContext` as input.
1197message TEContextServerEligibilityIORouter {
1198  // The name of the scalar string tensor that must be fed a serialized
1199  // `TaskEligibilityContext`.
1200  string context_proto_input_tensor_name = 1;
1201
1202  // The name of the scalar string tensor that must be fed the path to which
1203  // the server graph should write the checkpoint file to be sent to the client.
1204  string output_filepath_tensor_name = 2;
1205}
1206
1207// Plan
1208// =====
1209
1210// Represents the overall plan for performing federated optimization or
1211// personalization, as handed over to the production system. This will
1212// typically be split into individual pieces for different production
1213// parts, e.g. server and client side.
1214// NEXT_TAG: 15
1215message Plan {
1216  reserved 1, 3, 5;
1217
1218  // The actual type of the server_*_graph_bytes fields below is expected to be
1219  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
1220  // for two reasons.
1221  // 1) We may use execution engines other than TensorFlow.
1222// 2) We wish to avoid the cost of deserializing and re-serializing large
1223// graphs in the Federated Learning service.
1224
1225  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
1226  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
1227  // If we're using a MapReduceForm-based server implementation, only
1228  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
1229  // based server implementation, only server_graph_prepare_bytes and
1230  // server_graph_result_bytes will be used.
1231
1232  // Optional. The TensorFlow graph used for all server processing described by
1233  // ServerPhase. For personalization, this will not be set.
1234  google.protobuf.Any server_graph_bytes = 7;
1235
1236  // Optional. The TensorFlow graph used for all server processing described by
1237  // ServerPhaseV2.tensorflow_spec_prepare.
1238  google.protobuf.Any server_graph_prepare_bytes = 13;
1239
1240  // Optional. The TensorFlow graph used for all server processing described by
1241  // ServerPhaseV2.tensorflow_spec_result.
1242  google.protobuf.Any server_graph_result_bytes = 14;
1243
1244  // A savepoint to sync the server checkpoint with a persistent
1245  // storage system.  The storage initially holds a seeded checkpoint
1246  // which can subsequently be read and updated by this savepoint.
1247  // Optional -- not present in eligibility computation plans (those with a
1248  // ServerEligibilityComputationPhase). This is used in conjunction with
1249  // ServerPhase only.
1250  CheckpointOp server_savepoint = 2;
1251
1252  // Required. The TensorFlow graph that describes the TensorFlow logic a client
1253  // should perform. It should be consistent with the `TensorflowSpec` field in
1254  // the `client_phase`. The actual type is expected to be tensorflow.GraphDef.
1255  // The TensorFlow graph is stored in serialized form for two reasons.
1256  // 1) We may use execution engines other than TensorFlow.
1257// 2) We wish to avoid the cost of deserializing and re-serializing large
1258// graphs in the Federated Learning service.
1259  google.protobuf.Any client_graph_bytes = 8;
1260
1261  // Optional. The FlatBuffer used for TFLite training.
1262  // It contains the same model information as the client_graph_bytes, but with
1263  // a different format.
1264  bytes client_tflite_graph_bytes = 12;
1265
1266  // A pair of client phase and server phase which are processed in
1267  // sync. The server execution defines how the results of a client
1268  // phase are aggregated, and how the checkpoints for clients are
1269  // generated.
1270  message Phase {
1271    // Required. The client phase.
1272    ClientPhase client_phase = 1;
1273
1274    // Optional. Server phase for TF-based aggregation; not provided for
1275    // personalization or eligibility tasks.
1276    ServerPhase server_phase = 2;
1277
1278    // Optional. Server phase for native aggregation; only provided for tasks
1279    // that have enabled the corresponding flag.
1280    ServerPhaseV2 server_phase_v2 = 4;
1281
1282    // Optional. Only provided for eligibility tasks.
1283    ServerEligibilityComputationPhase server_eligibility_phase = 3;
1284  }
1285
1286  // A pair of client and server computations to run.
1287  repeated Phase phase = 4;
1288
1289  // Metrics that are persistent across different phases. This
1290  // includes, for example, counters that track how much work of
1291  // different kinds has been done.
1292  repeated Metric metrics = 6;
1293
1294  // Describes how metrics in both the client and server phases should be
1295  // aggregated.
1296  repeated OutputMetric output_metrics = 10;
1297
1298  // Version of the plan:
1299  // version == 0 - Old plan without version field, containing b/65131070
1300  // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
1301  int32 version = 9;
1302
1303  // A TensorFlow ConfigProto packed in an Any.
1304  //
1305  // If this field is unset, if the Any proto is set but empty, or if the Any
1306  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
1307  // set, but the `value` field is empty) then the client implementation may
1308  // choose a set of configuration parameters to provide to TensorFlow by
1309  // default.
1310  //
1311  // In all other cases this field must contain a valid packed ConfigProto
1312  // (invalid values will result in an error at execution time), and in this
1313  // case the client will not provide any other configuration parameters by
1314  // default.
1315  google.protobuf.Any tensorflow_config_proto = 11;
1316}
1317
1318// Represents a client part of the plan of federated optimization.
1319// This is also used to describe a client-only plan for standalone on-device
1320// training, known as personalization.
1321// NEXT_TAG: 6
1322message ClientOnlyPlan {
1323  reserved 3;
1324
1325  // The graph to use for training, in binary form.
1326  bytes graph = 1;
1327
1328  // Optional. The flatbuffer used for TFLite training.
1329  // Whether "graph" or "tflite_graph" is used for training is up to the client
1330  // code to allow for a flag-controlled a/b rollout.
1331  bytes tflite_graph = 5;
1332
1333  // The client phase to execute.
1334  ClientPhase phase = 2;
1335
1336  // A TensorFlow ConfigProto.
1337  google.protobuf.Any tensorflow_config_proto = 4;
1338}
1339
1340// Represents the cross-round aggregation portion for user-defined measurements.
1341// This is used by tools that process / analyze accumulator checkpoints
1342// after a round of computation, to achieve aggregation beyond a round.
1343message CrossRoundAggregationExecution {
1344  // Operation to run before reading accumulator checkpoint.
1345  string init_op = 1;
1346
1347  // Reads accumulator checkpoint.
1348  CheckpointOp read_aggregated_update = 2;
1349
1350  // Operation to merge loaded checkpoint into accumulator.
1351  string merge_op = 3;
1352
1353  // Reads and writes the final aggregated accumulator vars.
1354  CheckpointOp read_write_final_accumulators = 6;
1355
1356  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
1357  // to the user defined name of the signal.
1358  repeated Measurement measurements = 4;
1359
1360  // The `tf.Graph` used for aggregating accumulator checkpoints when
1361  // loading metrics.
1362  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
1363}
1364
1365message Measurement {
1366  // Name of a TensorFlow op to run to read/fetch the value of this measurement.
1367  string read_op_name = 1;
1368
1369  // A human-readable name for the measurement. Names are usually
1370  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
1371  string name = 2;
1372
1373  reserved 3;
1374
1375  // A serialized `tff.Type` for the measurement.
1376  bytes tff_type = 4;
1377}
1378