/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

syntax = "proto3";

package com.android.federatedcompute.proto;

import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";

option java_package = "com.android.federatedcompute.proto";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";

// Primitives
// ===========

// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used either for restore or for
// save, others for both directions. This is documented together with
// their usage.
//
// This op has four essential uses:
//   1. read and apply a checkpoint.
//   2. write a checkpoint.
//   3. read and apply from an aggregated side channel.
//   4. write to a side channel (grouped with write a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def is executed for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved and restored from a checkpoint, one can
  // also save and restore via a side channel. The keys in this map are
  // the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving with the SaverDef and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved by the SaverDef.
  //
  // For restoring, variables provided by the side channel are restored
  // differently than those from a checkpoint: they should be restored by
  // calling the before_restore_op with a feed dict whose keys are the
  // restore_names in the SideChannel and whose values are the values to be
  // restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}

message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregate_update restore, not the per-client read_update restore.
  message SecureAggregand {
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will for each shard compute sum modulo m with m at
    // least (1 + shard_size * (base_modulus - 1)), then aggregate
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client has input in range [0, base_modulus), the result
    // will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would work, the
    // current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2. This choice is
    // made because (a) it uses the same number of bits per vector entry as
    // any valid smaller m, using the current on-the-wire encoding scheme, and
    // (b) it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      FixedModulus fixed_modulus = 5;
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict passed to the before_restore_op in order
  // to restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  string restore_name = 2;
}
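
// A minimal, illustrative CheckpointOp (textproto) for a client update that
// sends one tensor through a SecureAggregation side channel. The op and
// tensor names are hypothetical; real plans are emitted by the plan-building
// pipeline.
//
//   saver_def {
//     filename_tensor_name: "save/filename:0"
//     save_tensor_name: "save/control_dependency:0"
//     restore_op_name: "save/restore_all"
//     version: V1
//   }
//   before_save_op: "prepare_update"               // hypothetical op name
//   side_channel_tensors {
//     key: "secagg_delta:0"                        // tensor sent via SecAgg
//     value {
//       secure_aggregand {
//         dimension { size: 10 }
//         dtype: DT_INT64
//         fixed_modulus { modulus: 4294967296 }
//       }
//       restore_name: "secagg_delta_restore"
//     }
//   }
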
// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to
  //   weighted_metric_sum = sum_i (metric_i * weight_i)
  //   weight_sum = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}

// Represents instructions for how metrics are to be output to users,
// controlling the end format of the metrics users receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that
    // are aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  // Iff True, the metric will be plotted automatically in the default view of
  // the task-level Colab.
  oneof visualization_info {
    bool auto_plot = 6 [deprecated = true];
    VisualizationSpec plot_spec = 7;
  }
}

message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to provide for the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, the metric will be displayed on a population-level dashboard.
  bool plot_on_population_dashboard = 3;
}

// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for the corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, a sample of at most 101 clients' values will be included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 3;
}
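
// A hedged sketch (textproto, hypothetical names) of how a Metric and its
// OutputMetric instructions might be wired together inside a `Plan`: the
// 'Loss' stat is normalized by 'ExampleCount', so downstream it is reported
// as an average.
//
//   metrics {
//     variable_name: "read_loss:0"       // hypothetical read op
//     stat_name: "Loss"
//     weight_name: "ExampleCount"
//   }
//   output_metrics {
//     name: "Loss"
//     average {
//       numerator_stat_name: "Loss"
//       denominator_stat_name: "ExampleCount"
//       average_stat_name: "AverageLoss"
//     }
//   }
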
// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to the corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to the corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for the corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, a sample of at most 101 clients' values will be included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 4;
}

// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for the corresponding Metric stat_name field.
  string stat_name = 1;
}

// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for the corresponding Metric stat_name field.
  string stat_name = 1;
}

// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      ExampleSelector selector = 1;
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}
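
// An illustrative test Dataset (textproto) with a single client providing two
// selected datasets, e.g. for a personalization eval task. The client id,
// collection URIs, and example bytes are placeholders.
//
//   client_data {
//     client_id: "client_0"
//     selected_example {
//       selector { collection_uri: "app:/train" }
//       example: "<serialized tf.Example>"
//     }
//     selected_example {
//       selector { collection_uri: "app:/test" }
//       example: "<serialized tf.Example>"
//     }
//   }
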
// Represents predicates over metrics - i.e., expectations. This is used in
// training/eval tests to encode metric names and values expected to be
// reported by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound; upper_bound]. Can also be used for
  // approximate matching (lower = value - epsilon; upper = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable
  // for the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}

// Client Phase
// ============

// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // in a way that the server can't learn aggregated values with a number of
  // participants lower than this number.
  // Without secure aggregation, the server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from at least the specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with little or no additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}
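
// A sketch of a ClientPhase for a Federated Compute task (textproto). It
// pairs a TensorflowSpec with the FederatedComputeIORouter; all tensor and
// node names are hypothetical.
//
//   name: "TrainingClientPhase"
//   minimum_number_of_participants: 2
//   tensorflow_spec {
//     input_tensor_specs { name: "input_ckpt_path:0" dtype: DT_STRING }
//     input_tensor_specs { name: "output_ckpt_path:0" dtype: DT_STRING }
//     target_node_names: "run_training_and_write_update"
//   }
//   federated_compute {
//     input_filepath_tensor_name: "input_ckpt_path:0"
//     output_filepath_tensor_name: "output_ckpt_path:0"
//   }
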
// The TensorflowSpec message describes a single call into TensorFlow,
// including the expected input tensors that must be fed when making that
// call, which output tensors are to be fetched, and any operations that have
// no output but must be run. The TensorFlow session will then use the input
// tensors to do some computation, generally reading from one or more
// datasets, and provide some outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//   CreateTensorflowArguments(
//       TensorflowSpec& spec,
//       IORouter& io_router,
//       const vector<pair<string, Tensor>>* input_tensors,
//       const vector<string>* output_tensor_names,
//       const vector<string>* target_node_names);
//
// where `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of the TensorFlow C++ API for
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the federated
// protocol, e.g. input checkpoints or aggregation protocols. This is a
// "tensors in, tensors out" interface. New aggregation methods can be added
// without having to modify the execution engine / TensorflowSpec message;
// instead they should modify the IORouter messages.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
// only requires the names to feed the values into the session. The additional
// dtype/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. Runtimes, however, are free to ignore the dtypes/shapes and
// rely only on the names if so desired.
//
// Assertions:
//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where
//     the TF execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty,
//     otherwise there is nothing to execute in the graph.
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor and
  // is separate from `input_tensors` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All the tensor names designated as inputs in the corresponding
  //     IORouter must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, with the
  //     exception of the dataset_token which is explicitly set above
  //     (otherwise TensorFlow will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's
  // C++ API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but whose output is not
  // returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for use with operations that do not produce tensors, but
  // nonetheless are required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of tensor names to constant inputs.
  // Note: tensors specified via this message should not be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;
}
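
// To make the feed/fetch mapping concrete: given the illustrative
// TensorflowSpec below (textproto, hypothetical names), a runtime would call
// `Session::Run()` with `inputs` containing "num_epochs:0" plus any
// IORouter-provided tensors and the constant input, `output_tensor_names`
// set to {"loss:0"}, and `target_node_names` set to {"write_update"}.
//
//   input_tensor_specs { name: "num_epochs:0" dtype: DT_INT32 shape {} }
//   output_tensor_specs { name: "loss:0" dtype: DT_FLOAT shape {} }
//   target_node_names: "write_update"
//   constant_inputs {
//     key: "learning_rate:0"
//     value { dtype: DT_FLOAT tensor_shape {} float_val: 0.01 }
//   }
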
// The ExampleQuerySpec message describes client execution that issues example
// queries and sends the query results directly to an aggregator with little
// or no additional processing.
// This message describes one or more example store queries that perform the
// client-side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type for each entry in the vector.
    DataType data_type = 2;
  }

  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, then it will be treated as if
    // an error was returned. In that case, or if the query explicitly returns
    // an error, the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated
    // under, and must match the keys in
    // FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}
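
// An illustrative ExampleQuerySpec (textproto) with a single query whose
// result vector is aggregated under the key "popular_items". The selector
// URI and vector names are hypothetical.
//
//   example_queries {
//     example_selector { collection_uri: "app:/item_views" }
//     output_vector_specs {
//       key: "popular_items"
//       value { vector_name: "item_counts" data_type: INT64 }
//     }
//   }
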
// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. This message describes how to prepare data coming from
// the incoming `CheckinResponse` (defined in fcp/protos/federated_api.proto)
// for the `TensorflowSpec`, and what to do with outputs from `TensorflowSpec`
// (e.g. how to aggregate them back on the server).
//
// TODO(team): we could replace `input_checkpoint_file_tensor_name` with
// an `input_tensors` field, which would then be a tensor that contains the
// input TensorProtos directly, skipping disk I/O, rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code would copy the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then pass that file path through this
  // tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not
  // use an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // the client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the
  // file path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of
  // `TensorflowSpec` use Secure Aggregation). If this field is not set, then
  // the `ReportRequest` message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. The absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set but executing the associated
  // TensorflowSpec does not write to the path, that is an indication of an
  // internal framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}

// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example query
// execution engine, describing how the query results should ultimately be
// aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an
  // aggregation protocol, and the configuration for those protocols.
  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is
  // supported.
  map<string, AggregationConfig> aggregations = 1;
}

// The specification for how to aggregate the associated tensor across clients
// on the server.
message AggregationConfig {
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using Secure
    // Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to add
    // support for simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by an
    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). Each vector will be stored in the checkpoint as a 1-D Tensor
    // of its corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}

// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so we only use this
// message to signify usage of SecAgg.
message SecureAggregationConfig {}

// Parameters for the TF v1 Checkpoint Aggregation protocol.
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
// its corresponding data type.
message TFV1CheckpointAggregation {}
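
// A hedged example of a FederatedExampleQueryIORouter (textproto) that
// aggregates the "popular_items" vector from the ExampleQuerySpec example
// above via a TF v1 checkpoint. The key must match
// ExampleQuerySpec.output_vector_specs.
//
//   aggregations {
//     key: "popular_items"
//     value { tf_v1_checkpoint_aggregation {} }
//   }
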
// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are returned to
// clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not
  // use an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}
// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. This message describes how to prepare data coming from
// the customer app (e.g. the input directory the app set up), and the
// temporary scratch output directory that the customer app will be notified
// of upon completion of the `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the placeholder tensor representing the input resource
  // path(s). It can be a single input directory or file path (in this case
  // `input_dir_tensor_name` is populated) or multiple input resources
  // represented as a map from names to input directories or file paths (in
  // this case `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by
  // the users when constructing the `LocalComputation` Python object, and the
  // values are the corresponding placeholder tensor names created by the
  // local computation plan builder.
  //
  // Apps will have the ability to create contracts between their Android code
  // and `LocalComputation` toolkit code to place files inside the input
  // resource paths with known names (Android code) and create graphs with ops
  // to read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;
    // Directly using the `map` field is not allowed in `oneof`, so we have to
    // wrap it in a new message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch that will
  // be deleted, not persisted. It is the responsibility of the calling app to
  // move the desired files to a permanent location once the client returns
  // this directory back to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // NOTE: LocalCompute has no outputs other than what the client graph writes
  // to the `output_dir` specified above.
}

// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // The keys are the input resource names (defined by the users when
  // constructing the `LocalComputation` Python object), and the values are
  // the corresponding placeholder tensor names created by the local
  // computation plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}

// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op for enqueuing an example input.
  string enqueue_op = 1;

  // The input placeholders for the enqueue op.
  repeated string enqueue_params = 2;

  // The op for closing the input queue.
  string close_op = 3;

  // Whether the work that should be fed asynchronously is the data itself
  // or a description of where that data lives.
  bool feed_values_are_data = 4;
}

message DatasetInput {
  // Initializer of the iterator corresponding to the tf.data.Dataset object
  // which handles the input data. Stores the name of an op in the graph.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in the tf.data.Dataset.
  int32 batch_size = 3;
}

message DatasetInputPlaceholders {
  // Name of the placeholder corresponding to the filename(s) of the
  // SSTable(s) to read data from.
  string filename = 1;

  // Name of the placeholder corresponding to the key_prefix initializing the
  // SSTableDataset. Note the value fed should be a unique user id, not a
  // prefix.
  string key_prefix = 2;

  // Name of the placeholder corresponding to the number of epochs the local
  // training should be run for.
  string num_epochs = 3;

  // Name of the placeholder corresponding to the batch size.
  string batch_size = 4;
}
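
// An illustrative LocalComputeIORouter (textproto) using multiple input
// resources. The resource names would come from the `LocalComputation`
// definition; the tensor names are hypothetical placeholders created by the
// plan builder.
//
//   multiple_input_resources {
//     input_resource_tensor_name_map {
//       key: "train_data"
//       value: "train_data_path:0"
//     }
//     input_resource_tensor_name_map {
//       key: "test_data"
//       value: "test_data_path:0"
//     }
//   }
//   output_dir_tensor_name: "output_dir:0"
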
// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. The format should
  // adhere to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI
  // segments should adhere to the following rules:
  //   - The scheme ${COLLECTION} should be one of:
  //     - "app" for app-hosted example collections
  //     - "simulation" for collections not connected to an app (e.g., if used
  //       purely for simulation)
  //   - The authority ${APP_NAME} identifies the owner of the example
  //     collection and should be either the app's package name, or be left
  //     empty (which means "the current app package name").
  //   - The path ${COLLECTION_NAME} can be any valid URI path. Note that it
  //     starts with a forward slash ("/").
  //   - The query and fragment are currently not used, but they may be used
  //     for something in the future. To keep that possibility open, they must
  //     currently be left empty.
  //
  // Example: "app://com.google.some.app/someCollection/name"
  // identifies the collection "/someCollection/name" owned and hosted by the
  // app with package name "com.google.some.app".
  //
  // Example: "app:/someCollection/name" or "app:///someCollection/name"
  // both identify the collection "/someCollection/name" owned and hosted by
  // the app associated with the training job in which this URI appears.
  //
  // The path will not be interpreted by the runtime, and will be passed to
  // the example collection implementation for interpretation. Thus, in the
  // case of app-hosted example stores, the path segment's interpretation is a
  // contract between the app's example store developers and the app's model
  // designers.
  //
  // If an `app://` URI is set, then the `TrainerOptions` collection name must
  // not be set.
  string collection_uri = 2;

  // Resumption token following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any resumption_token = 3;
}

// Selector for slices to fetch as part of a `federated_select` operation.
message SlicesSelector {
  // The string ID under which the slices are served.
  //
  // This value must have been returned by a previous call to the
  // `serve_slices` op run during the `write_client_init` operation.
  string served_at_id = 1;

  // The indices of the slices to fetch.
  repeated int32 keys = 2;
}

// Represents slice data to be served as part of a `federated_select`
// operation. This is used for testing.
message SlicesTestDataset {
  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
  // field. E.g. test slice data for a slice with `served_at_id`="foo" and
  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
  map<string, SlicesTestData> dataset = 1;
}

message SlicesTestData {
  // The test slice data to serve. Each entry's index corresponds to the slice
  // key it is the test data for.
  repeated bytes slice_data = 2;
}
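
// A small illustrative SlicesTestDataset (textproto): the slice with
// `served_at_id`="foo" and key 1 would be served from
// `dataset["foo"].slice_data[1]`. The payload bytes are placeholders.
//
//   dataset {
//     key: "foo"
//     value {
//       slice_data: "<slice 0 bytes>"
//       slice_data: "<slice 1 bytes>"
//     }
//   }
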
The input tensors may represent parts of 893// either the client results or the intermediate server state. These messages 894// are obtained by transforming the client_to_server_aggregation TFF computation 895// in the DistributeAggregateForm. 896// 897// The post-aggregation logic is obtained by transforming the server_result TFF 898// computation in the DistributeAggregateForm. It takes the intermediate server 899// state and the aggregated client results as input, and it generates the new 900// server state and potentially other server-side output. 901// 902// Note that while a ServerPhaseV2 message can be generated for all types of 903// intrinsics, it is currently only compatible with the ClientPhase message if 904// the aggregations being used are exclusively federated_sum (not SecAgg). If 905// this compatibility requirement is satisfied, it is also valid to run the 906// aggregation portion of this ServerPhaseV2 message alongside the pre- and 907// post-aggregation logic from the original ServerPhase message. Ultimately, 908// we expect the full ServerPhaseV2 message to be run and the ServerPhase 909// message to be deprecated. 910message ServerPhaseV2 { 911 // A short CamelCase name for the ServerPhaseV2. 912 string name = 1; 913 914 // A functional interface for the TensorFlow logic the server should perform 915 // prior to the server-to-client broadcast. This should be used with the 916 // TensorFlow graph defined in server_graph_prepare_bytes. 917 TensorflowSpec tensorflow_spec_prepare = 3; 918 919 // The specification of inputs needed by the server_prepare TF logic. 920 oneof server_prepare_io_router { 921 ServerPrepareIORouter prepare_router = 4; 922 } 923 924 // A list of client-to-server aggregations to perform. 925 repeated ServerAggregationConfig aggregations = 2; 926 927 // A functional interface for the TensorFlow logic the server should perform 928 // post-aggregation. This should be used with the TensorFlow graph defined 929 // in server_graph_result_bytes. 930 TensorflowSpec tensorflow_spec_result = 5; 931 932 // The specification of inputs and outputs needed by the server_result TF 933 // logic. 934 oneof server_result_io_router { 935 ServerResultIORouter result_router = 6; 936 } 937} 938 939// Routing for server_prepare graph 940message ServerPrepareIORouter { 941 // The name of the scalar string tensor in the server_prepare TF graph that 942 // is fed the filepath to the initial server state checkpoint. The 943 // server_prepare logic reads from this filepath. 944 string prepare_server_state_input_filepath_tensor_name = 1; 945 946 // The name of the scalar string tensor in the server_prepare TF graph that 947 // is fed the filepath where the client checkpoint should be stored. The 948 // server_prepare logic writes to this filepath. 949 string prepare_output_filepath_tensor_name = 2; 950 951 // The name of the scalar string tensor in the server_prepare TF graph that 952 // is fed the filepath where the intermediate state checkpoint should be 953 // stored. The server_prepare logic writes to this filepath. The intermediate 954 // state checkpoint will be consumed by both the logic used to set parameters 955 // for aggregation and the post-aggregation logic. 956 string prepare_intermediate_state_output_filepath_tensor_name = 3; 957} 958 959// Routing for server_result graph 960message ServerResultIORouter { 961 // The name of the scalar string tensor in the server_result TF graph that is 962 // fed the filepath to the intermediate state checkpoint. 
// Routing for the server_result graph.
message ServerResultIORouter {
  // The name of the scalar string tensor in the server_result TF graph that
  // is fed the filepath to the intermediate state checkpoint. The
  // server_result logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_result TF graph that
  // is fed the filepath to the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_result TF graph that
  // is fed the filepath where the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}

// Represents a single aggregation operation, combining one or more input
// tensors from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // Describes an argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // Refers to a tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // Refers to a tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // List of arguments for the aggregation operation. The arguments can be
  // dependent on client data (in which case they must be retrieved from
  // clients) or they can be independent of client data (in which case they
  // can be configured server-side). For now we assume all client-independent
  // arguments are constants. The arguments must be in the order expected by
  // the server.
  repeated IntrinsicArg intrinsic_args = 4;

  // List of server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;
}
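
// An illustrative ServerAggregationConfig (textproto) describing a
// federated_sum aggregation of one client tensor into one server tensor. The
// tensor names are hypothetical.
//
//   intrinsic_uri: "federated_sum"
//   intrinsic_args {
//     input_tensor { name: "update/num_examples:0" dtype: DT_INT64 shape {} }
//   }
//   output_tensors {
//     name: "aggregated_num_examples:0" dtype: DT_INT64 shape {}
//   }
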
// Server Phase
// ============

// Represents a server phase which implements TF-based aggregation of multiple
// client updates.
//
// There are two different modes of aggregation that are described
// by the values in this message. The first is aggregation of updates
// coming from coordinated sets of clients. This includes aggregation
// done via checkpoints from clients or aggregation done over a set
// of clients by a process like secure aggregation. The results of
// this first aggregation are saved to intermediate aggregation
// checkpoints. The second aggregation then comes from taking
// these intermediate checkpoints and aggregating over them.
//
// These two different modes of aggregation are done on different
// servers, the first in the 'L1' servers and the second in the
// 'L2' servers, so we use this nomenclature to describe these
// phases below.
//
// The ServerPhase message is currently in the process of being replaced by
// the ServerPhaseV2 message as we switch the plan building pipeline to use
// DistributeAggregateForm instead of MapReduceForm. During the migration
// process, we may generate both messages and use components from either
// message during execution.
//
message ServerPhase {
  // A short CamelCase name for the ServerPhase.
  string name = 8;

  // ===========================================================================
  // L1 "Intermediate" Aggregation.
  //
  // This is the initial aggregation that creates partial aggregates from
  // client results. L1 Aggregation may be run on many different instances.
  //
  // Pre-condition:
  // The execution environment has loaded the graph from `server_graph_bytes`.

  // 1. Initialize the phase.
  //
  // Operation to run before the first aggregation happens.
  // For instance, clears the accumulators so that a new aggregation can begin.
  string phase_init_op = 1;

  // 2. For each client in the set of clients:
  //    a. Restore variables from the client checkpoint.
  //
  // Loads a checkpoint from a single client written via
  // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
  // for every client checkpoint in a round.
  CheckpointOp read_update = 3;
  //    b. Aggregate the data coming from the client checkpoint.
  //
  // An operation that aggregates the data from read_update.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  //
  // Executed once for each `read_update`, to (for example) update accumulator
  // variables using the values loaded during `read_update`.
  string aggregate_into_accumulators_op = 4;

  // 3. After all clients have been aggregated, possibly restore
  //    variables that have been aggregated via a separate process.
  //
  // Optionally restores variables where aggregation is done across
  // an entire round of client data updates. In contrast to `read_update`,
  // which restores once per client, this occurs after all clients
  // in a round have been processed. This allows, for example, side
  // channels where aggregation is done by a separate process (such
  // as in secure aggregation), in which the side-channel aggregated
  // tensor is passed to the `before_restore_op`, which ensures the
  // variables are restored properly. The `after_restore_op` will then
  // be responsible for performing the accumulation.
  //
  // Note that in current use this should not have a SaverDef, but
  // should only be used for side channels.
  CheckpointOp read_aggregated_update = 10;

  // 4. Write the aggregated variables to an intermediate checkpoint.
  //
  // We require that `aggregate_into_accumulators_op` is associative and
  // commutative, so that the aggregates can be computed across
  // multiple TensorFlow sessions.
  // As an example, say we are computing the sum of 5 client updates:
  //   A = X1 + X2 + X3 + X4 + X5
  // We can always do this in one session by calling `read_update` and
  // `aggregate_into_accumulators_op` once for each client checkpoint.
  //
  // Alternatively, we could compute:
  //   A1 = X1 + X2 in one TensorFlow session, and
  //   A2 = X3 + X4 + X5 in a different session.
  // Each of these sessions can then write their accumulator state
  // with the `write_intermediate_update` CheckpointOp, and yet another, third
  // session can then call `read_intermediate_update` and
  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
  //   A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
  CheckpointOp write_intermediate_update = 7;
  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates intermediate checkpoints from L1 Aggregation and performs
  // the finalizing of the update. Unlike L1, there will be only one instance
  // that does this aggregation.
  //
  // Pre-condition:
  // The execution environment has loaded the graph from `server_graph_bytes`
  // and restored the global model using `server_savepoint` from the parent
  // `Plan` message.

  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.

  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.
  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  //    a. Restore variables from the intermediate checkpoint.
  //
  // The read-checkpoint op corresponding to write_intermediate_update.
  // This is used instead of read_update for intermediate checkpoints because
  // the format of these updates may be different than that of updates from
  // clients (which may, for example, be compressed).
  CheckpointOp read_intermediate_update = 9;
  //    b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics.
  // These variables will be read back into a session with
  // read_intermediate_update.
  //
  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
  // to disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  //   - Applying the update aggregated from the intermediate checkpoints to
  //     the global model and other updates to cross-round state variables.
  //   - Computing final round metric values (e.g. the `report` of a
  //     `tff.federated_aggregate`).
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // A list of names of metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the
  //    global model in FL) using `server_savepoint` in the parent `Plan`
  //    message.

  // End L2 Aggregation.
  // ===========================================================================
}
// Represents the server phase in an eligibility computation.
//
// This phase produces a checkpoint to be sent to clients. This checkpoint is
// then used as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // A short CamelCase name for the ServerEligibilityComputationPhase.
  string name = 1;

  // The names of the TensorFlow nodes to run in order to produce output.
  repeated string target_node_names = 2;

  // The specification of inputs and outputs to the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
// which takes a single `TaskEligibilityContext` as input.
message TEContextServerEligibilityIORouter {
  // The name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // The name of the scalar string tensor that must be fed the path to which
  // the server graph should write the checkpoint file to be sent to the
  // client.
  string output_filepath_tensor_name = 2;
}
// Plan
// =====

// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. This will
// typically be split down into individual pieces for different production
// parts, e.g. server and client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to
  // be tensorflow.GraphDef. The TensorFlow graphs are stored in serialized
  // form for two reasons.
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs in the Federated Learning service.

  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
  // based server implementation, only server_graph_prepare_bytes and
  // server_graph_result_bytes will be used.

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // A savepoint to sync the server checkpoint with a persistent
  // storage system. The storage initially holds a seeded checkpoint
  // which can subsequently be read and updated by this savepoint.
  // Optional -- not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a
  // client should perform. It should be consistent with the `TensorflowSpec`
  // field in the `client_phase`. The actual type is expected to be
  // tensorflow.GraphDef. The TensorFlow graph is stored in serialized form
  // for two reasons.
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training.
  // It contains the same model information as the client_graph_bytes, but in
  // a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of client phase and server phase which are processed in
  // sync. The server execution defines how the results of a client
  // phase are aggregated, and how the checkpoints for clients are
  // generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that are persistent across different phases. This
  // includes, for example, counters that track how much work of
  // different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  //   version == 0 - Old plan without version field, containing b/65131070
  //   version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field
  // is set, but the `value` field is empty) then the client implementation
  // may choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and in this
  // case the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}

// Represents a client part of the plan of federated optimization.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The FlatBuffer used for TFLite training.
  // Whether "graph" or "tflite_graph" is used for training is up to the
  // client code, to allow for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}
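
// A hedged sketch of a ClientOnlyPlan for personalization (textproto). The
// graph bytes are placeholders, and the phase uses the LocalComputeIORouter;
// all names are hypothetical.
//
//   graph: "<serialized tensorflow.GraphDef>"
//   phase {
//     tensorflow_spec { target_node_names: "run_personalization" }
//     local_compute {
//       input_dir_tensor_name: "input_dir:0"
//       output_dir_tensor_name: "output_dir:0"
//     }
//   }
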
// Represents the cross-round aggregation portion for user-defined
// measurements. This is used by tools that process / analyze accumulator
// checkpoints after a round of computation, to achieve aggregation beyond a
// round.
message CrossRoundAggregationExecution {
  // Operation to run before reading the accumulator checkpoint.
  string init_op = 1;

  // Reads the accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation to merge the loaded checkpoint into the accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
  // to the user-defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when
  // loading metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}

message Measurement {
  // Name of a TensorFlow op to run to read/fetch the value of this
  // measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. Names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}
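
// An illustrative Measurement entry (textproto) as it might appear inside a
// CrossRoundAggregationExecution, mapping a TF read op to a user-defined
// signal name; the op and signal names are hypothetical.
//
//   measurements {
//     read_op_name: "read_num_client_updates:0"
//     name: "NumClientUpdates"
//   }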