• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
17 
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 
29 namespace tensorflow {
30 namespace data {
31 
32 // Rewrites the dataset graph by applying an auto-shard policy.
33 class AutoShardRewriter {
34  public:
35   // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if
36   // the sharding policy is not a valid auto-shard policy.
37   static StatusOr<AutoShardRewriter> Create(const TaskDef& task_def);
38 
39   // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns
40   // the same graph as `graph_def`. Otherwise, returns the re-written graph.
41   StatusOr<GraphDef> ApplyAutoShardRewrite(const GraphDef& graph_def);
42 
43  private:
44   AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64 num_workers,
45                     int64 worker_index);
46 
47   // Creates a rewrite config based on the auto-shard policy.
48   tensorflow::RewriterConfig::CustomGraphOptimizer GetRewriteConfig() const;
49 
50   const AutoShardPolicy auto_shard_policy_;
51   const int64 num_workers_;
52   const int64 worker_index_;
53 };
54 
55 // Maps a worker to its index, given a list of workers. For example, suppose
56 // `worker_addresses` contains
57 //   /worker/task/0:worker, /worker/task/1:worker, /worker/task/2:worker,
58 // then
59 //   /worker/task/0:worker maps to index 0,
60 //   /worker/task/1:worker maps to index 1,
61 //   /worker/task/2:worker maps to index 2.
62 // This is useful for deterministically sharding a dataset among a fixed set of
63 // tf.data service workers.
64 class WorkerIndexResolver {
65  public:
66   // Constructs a `WorkerIndexResolver` to generate worker indexes according to
67   // the specified worker addresses. The worker addresses can be "host" or
68   // "host:port", where "port" is a number, named port, or "%port%" to be
69   // replaced with the actual port.
70   template <class T>
WorkerIndexResolver(const T & worker_addresses)71   explicit WorkerIndexResolver(const T& worker_addresses)
72       : worker_addresses_(worker_addresses.cbegin(), worker_addresses.cend()) {}
73 
74   // Validates `worker_address`. Returns an error if the `worker_addresses` list
75   // is non-empty and `worker_address` is not specified in the worker addresses
76   // list (with optional port replacement).
77   Status ValidateWorker(absl::string_view worker_address) const;
78 
79   // Processes a worker at address `worker_address`. Its index can be retrieved
80   // by calling `GetWorkerIndex`.
81   void AddWorker(absl::string_view worker_address);
82 
83   // Returns the worker index for the worker at `worker_address`. Returns a
84   // NotFound error if the worker is not registered.
85   StatusOr<int64> GetWorkerIndex(absl::string_view worker_address) const;
86 
87  private:
88   std::vector<std::string> worker_addresses_;
89 };
90 
91 }  // namespace data
92 }  // namespace tensorflow
93 
94 #endif  // TENSORFLOW_CORE_DATA_SERVICE_AUTO_SHARD_REWRITER_H_
95