/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
16
17 #include <iterator>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/strings/match.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/strings/substitute.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/data/rewrite_utils.h"
30 #include "tensorflow/core/data/service/common.h"
31 #include "tensorflow/core/data/service/common.pb.h"
32 #include "tensorflow/core/framework/dataset_options.pb.h"
33 #include "tensorflow/core/framework/graph.pb.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
37 #include "tensorflow/core/grappler/grappler_item.h"
38 #include "tensorflow/core/grappler/grappler_item_builder.h"
39 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
40 #include "tensorflow/core/grappler/optimizers/data/auto_shard.h"
41 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
42 #include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
43 #include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"
44 #include "tensorflow/core/platform/errors.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/platform/statusor.h"
47 #include "tensorflow/core/protobuf/data_service.pb.h"
48 #include "tensorflow/core/protobuf/meta_graph.pb.h"
49
50 namespace tensorflow {
51 namespace data {
52 namespace {
53
54 using ::tensorflow::data::experimental::AutoShardDatasetOp;
55
56 // Extracts the host from `address`.
GetHost(absl::string_view address)57 std::string GetHost(absl::string_view address) {
58 absl::string_view::size_type port_pos = address.find_last_of(':');
59 return std::string(address.substr(0, port_pos));
60 }
61
62 // Extracts the port from `address`. Returns nullopt if `address` does not
63 // specify a port.
GetPort(absl::string_view address)64 absl::optional<absl::string_view> GetPort(absl::string_view address) {
65 absl::string_view::size_type port_pos = address.find_last_of(':');
66 if (port_pos == absl::string_view::npos) {
67 return absl::nullopt;
68 }
69 return address.substr(port_pos + 1);
70 }
71
72 // A dynamic port has form %port% or %port_foo% that is to be replaced with the
73 // actual port.
HasDynamicPort(absl::string_view address)74 bool HasDynamicPort(absl::string_view address) {
75 absl::optional<absl::string_view> port = GetPort(address);
76 return port && absl::StartsWith(*port, "%port") && absl::EndsWith(*port, "%");
77 }
78
79 // Returns true if `config_address` has no port or a dynamic port (e.g.: %port%)
80 // and `worker_address` has an actual port (number of named port).
81 //
82 // For example, it returns true for the following cases:
83 //
84 // config_address worker_address
85 // ----------------------------------------------------------
86 // /worker/task/0 /worker/task/0:worker
87 // /worker/task/0:%port% /worker/task/0:10000
88 // /worker/task/0:%port_worker% /worker/task/0:worker
89 // /worker/task/0:%port_worker% /worker/task/0:10000
90 // localhost localhost:10000
91 // localhost:%port% localhost:10000
ShouldReplaceDynamicPort(absl::string_view config_address,absl::string_view worker_address)92 bool ShouldReplaceDynamicPort(absl::string_view config_address,
93 absl::string_view worker_address) {
94 return (!GetPort(config_address) || HasDynamicPort(config_address)) &&
95 GetPort(worker_address) &&
96 GetHost(config_address) == GetHost(worker_address);
97 }
98 } // namespace
99
Create(const TaskDef & task_def)100 StatusOr<AutoShardRewriter> AutoShardRewriter::Create(const TaskDef& task_def) {
101 TF_ASSIGN_OR_RETURN(
102 AutoShardPolicy auto_shard_policy,
103 ToAutoShardPolicy(task_def.processing_mode_def().sharding_policy()));
104 return AutoShardRewriter(auto_shard_policy, task_def.num_workers(),
105 task_def.worker_index());
106 }
107
ApplyAutoShardRewrite(const GraphDef & graph_def)108 StatusOr<GraphDef> AutoShardRewriter::ApplyAutoShardRewrite(
109 const GraphDef& graph_def) {
110 if (auto_shard_policy_ == AutoShardPolicy::OFF) {
111 return graph_def;
112 }
113
114 VLOG(2) << "Applying auto-shard policy "
115 << AutoShardPolicy_Name(auto_shard_policy_)
116 << ". Number of workers: " << num_workers_
117 << "; worker index: " << worker_index_ << ".";
118 grappler::AutoShard autoshard;
119 tensorflow::RewriterConfig::CustomGraphOptimizer config = GetRewriteConfig();
120 TF_RETURN_IF_ERROR(autoshard.Init(&config));
121
122 GraphDef input_graph = graph_def;
123 TF_ASSIGN_OR_RETURN(std::string dataset_node, GetDatasetNode(input_graph));
124 std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
125 GetGrapplerItem(&input_graph, &dataset_node, /*add_fake_sinks=*/false);
126
127 GraphDef rewritten_graph;
128 std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
129 tensorflow::grappler::VirtualCluster cluster(device_map);
130 grappler::AutoShard::OptimizationStats stats;
131 TF_RETURN_IF_ERROR(autoshard.OptimizeAndCollectStats(
132 &cluster, *grappler_item, &rewritten_graph, &stats));
133 return rewritten_graph;
134 }
135
// Constructs a rewriter that applies `auto_shard_policy` to shard datasets
// across `num_workers` workers, producing the shard for `worker_index`.
AutoShardRewriter::AutoShardRewriter(AutoShardPolicy auto_shard_policy,
                                     int64 num_workers, int64 worker_index)
    : auto_shard_policy_(auto_shard_policy),
      num_workers_(num_workers),
      worker_index_(worker_index) {}
141
142 tensorflow::RewriterConfig::CustomGraphOptimizer
GetRewriteConfig() const143 AutoShardRewriter::GetRewriteConfig() const {
144 tensorflow::RewriterConfig::CustomGraphOptimizer config;
145 config.set_name("tf-data-service-auto-shard");
146 (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumWorkers].set_i(
147 num_workers_);
148 (*config.mutable_parameter_map())[AutoShardDatasetOp::kIndex].set_i(
149 worker_index_);
150 (*config.mutable_parameter_map())[AutoShardDatasetOp::kAutoShardPolicy].set_i(
151 auto_shard_policy_);
152 (*config.mutable_parameter_map())[AutoShardDatasetOp::kNumReplicas].set_i(1);
153 return config;
154 }
155
ValidateWorker(absl::string_view worker_address) const156 Status WorkerIndexResolver::ValidateWorker(
157 absl::string_view worker_address) const {
158 if (worker_addresses_.empty()) {
159 return Status::OK();
160 }
161
162 for (absl::string_view config_address : worker_addresses_) {
163 if (config_address == worker_address ||
164 ShouldReplaceDynamicPort(config_address, worker_address)) {
165 return Status::OK();
166 }
167 }
168
169 return errors::FailedPrecondition(absl::Substitute(
170 "Failed to assign an index for worker $0. Configured workers list: [$1]. "
171 "The worker's address is not configured, or other workers are already "
172 "running at the configured host. If your worker has restarted, make sure "
173 "it runs at the same address and port.",
174 worker_address, absl::StrJoin(worker_addresses_, ", ")));
175 }
176
AddWorker(absl::string_view worker_address)177 void WorkerIndexResolver::AddWorker(absl::string_view worker_address) {
178 for (std::string& config_address : worker_addresses_) {
179 if (config_address == worker_address) {
180 return;
181 }
182 if (ShouldReplaceDynamicPort(config_address, worker_address)) {
183 config_address = std::string(worker_address);
184 return;
185 }
186 }
187 }
188
GetWorkerIndex(absl::string_view worker_address) const189 StatusOr<int64> WorkerIndexResolver::GetWorkerIndex(
190 absl::string_view worker_address) const {
191 const auto it = absl::c_find(worker_addresses_, worker_address);
192 if (it == worker_addresses_.cend()) {
193 return errors::NotFound(absl::Substitute(
194 "Failed to shard dataset in tf.data service: Worker $0 is not in the "
195 "workers list. Got workers list $1.",
196 worker_address, absl::StrJoin(worker_addresses_, ",")));
197 }
198 return std::distance(worker_addresses_.cbegin(), it);
199 }
200
201 } // namespace data
202 } // namespace tensorflow
203