/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/auto_shard_rewriter.h"

#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/rewrite_utils.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
#include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace data {
namespace {

using ::tensorflow::data::experimental::AutoShardDatasetOp;

// Extracts the host from `address`.
std::string GetHost(absl::string_view address) {
  absl::string_view::size_type port_pos = address.find_last_of(':');
  return std::string(address.substr(0, port_pos));
}

// Extracts the port from `address`. Returns nullopt if `address` does not
// specify a port.
absl::optional<absl::string_view> GetPort(absl::string_view address) {
  absl::string_view::size_type port_pos = address.find_last_of(':');
  if (port_pos == absl::string_view::npos) {
    return absl::nullopt;
  }
  return address.substr(port_pos + 1);
}

// A dynamic port has the form %port% or %port_foo% and is to be replaced with
// the actual port.
bool HasDynamicPort(absl::string_view address) {
  absl::optional<absl::string_view> port = GetPort(address);
  return port && absl::StartsWith(*port, "%port") && absl::EndsWith(*port, "%");
}

// Returns true if `config_address` has no port or a dynamic port (e.g.
// %port%) and `worker_address` has an actual port (a port number or a named
// port).
//
// For example, it returns true for the following cases:
//
//  config_address                    worker_address
//  ----------------------------------------------------------
//  /worker/task/0                    /worker/task/0:worker
//  /worker/task/0:%port%             /worker/task/0:10000
//  /worker/task/0:%port_worker%      /worker/task/0:worker
//  /worker/task/0:%port_worker%      /worker/task/0:10000
//  localhost                         localhost:10000
//  localhost:%port%                  localhost:10000
bool ShouldReplaceDynamicPort(absl::string_view config_address,
                              absl::string_view worker_address) {
  return (!GetPort(config_address) || HasDynamicPort(config_address)) &&
         GetPort(worker_address) &&
         GetHost(config_address) == GetHost(worker_address);
}
}  // namespace

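// Creates an `AutoShardRewriter` from `task_def`, which supplies the sharding
// policy, the number of workers, and this worker's index. Returns an error if
// the configured sharding policy does not map to an `AutoShardPolicy`.
//
// Illustrative usage (a sketch, assuming a populated `TaskDef task_def` and a
// dataset graph `GraphDef graph`):
//
//   TF_ASSIGN_OR_RETURN(AutoShardRewriter rewriter,
//                       AutoShardRewriter::Create(task_def));
//   TF_ASSIGN_OR_RETURN(GraphDef sharded_graph,
//                       rewriter.ApplyAutoShardRewrite(graph));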
StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
  TF_ASSIGN_OR_RETURN(
      AutoShardPolicy auto_shard_policy,
      ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
  return AutoShardRewriter(auto_shard_policy, task_def.num_workers(),
                           task_def.worker_index());
}

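// Applies the AutoShard grappler optimization to `graph_def`, sharding the
// dataset across `num_workers_` workers with this worker reading shard
// `worker_index_`. Returns the input graph unchanged when the policy is OFF.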
StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
    const GraphDef& graph_def) {
  if (auto_shard_policy_ == AutoShardPolicy::OFF) {
    return graph_def;
  }

  VLOG(2) << "Applying auto-shard policy "
          << AutoShardPolicy_Name(auto_shard_policy_)
          << ". Number of workers: " << num_workers_
          << "; worker index: " << worker_index_ << ".";
  grappler::AutoShard autoshard;
  tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig();
  TF_RETURN_IF_ERROR(autoshard.Init(&config));

  GraphDef input_graph = graph_def;
  TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph));
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false);

  GraphDef rewritten_graph;
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);
  grappler::AutoShard::OptimizationStats stats;
  TF_RETURN_IF_ERROR(autoshard.OptimizeAndCollectStats(
      &cluster, *grappler_item, &rewritten_graph, &stats));
  return rewritten_graph;
}

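// Stores the sharding policy and worker topology; prefer `Create`, which
// validates the configured policy before constructing a rewriter.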
AutoShardRewriter::AutoShardRewriter(AutoShardPolicy auto_shard_policy,
                                     int64 num_workers, int64 worker_index)
    : auto_shard_policy_(auto_shard_policy),
      num_workers_(num_workers),
      worker_index_(worker_index) {}

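// Builds the CustomGraphOptimizer config consumed by the AutoShard optimizer:
// the worker count, this worker's index, the sharding policy, and a fixed
// replica count of 1.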
tensorflow::RewriterConfig::CustomGraphOptimizer
AutoShardRewriter::GetRewriteConfig() const {
  tensorflow::RewriterConfig::CustomGraphOptimizer config;
  config.set_name("tf-data-service-auto-shard");
  (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumWorkers].set_i(
      num_workers_);
  (*config.mutable_parameter_map())[AutoShardDatasetOp::kIndex].set_i(
      worker_index_);
  (*config.mutable_parameter_map())[AutoShardDatasetOp::kAutoShardPolicy].set_i(
      auto_shard_policy_);
  (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumReplicas].set_i(1);
  return config;
}

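// Checks that `worker_address` can be assigned an index: it must match a
// configured worker address exactly, or match a configured host whose port is
// dynamic or unspecified. An empty configured list accepts any worker.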
Status WorkerIndexResolver::ValidateWorker(
    absl::string_view worker_address) const {
  if (worker_addresses_.empty()) {
    return Status::OK();
  }

  for (absl::string_view config_address : worker_addresses_) {
    if (config_address == worker_address ||
        ShouldReplaceDynamicPort(config_address, worker_address)) {
      return Status::OK();
    }
  }

  return errors::FailedPrecondition(absl::Substitute(
      "Failed to assign an index for worker $0. Configured workers list: [$1]. "
      "The worker's address is not configured, or other workers are already "
      "running at the configured host. If your worker has restarted, make sure "
      "it runs at the same address and port.",
      worker_address, absl::StrJoin(worker_addresses_, ", ")));
}

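// Records `worker_address` in the configured list by replacing the first
// matching entry whose port is dynamic or unspecified. Exact matches and
// unmatched addresses leave the list unchanged.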
void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
  for (std::string& config_address : worker_addresses_) {
    if (config_address == worker_address) {
      return;
    }
    if (ShouldReplaceDynamicPort(config_address, worker_address)) {
      config_address = std::string(worker_address);
      return;
    }
  }
}

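// Returns the position of `worker_address` in the configured workers list, or
// NotFound if the address is absent. This index selects which shard of the
// dataset the worker reads.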
StatusOr<int64> WorkerIndexResolver::GetWorkerIndex(
    absl::string_view worker_address) const {
  const auto it = absl::c_find(worker_addresses_, worker_address);
  if (it == worker_addresses_.cend()) {
    return errors::NotFound(absl::Substitute(
        "Failed to shard dataset in tf.data service: Worker $0 is not in the "
        "workers list. Got workers list $1.",
        worker_address, absl::StrJoin(worker_addresses_, ",")));
  }
  return std::distance(worker_addresses_.cbegin(), it);
}

}  // namespace data
}  // namespace tensorflow