/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_global_init.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/cc/framework/scope.h"
26 #include "tensorflow/cc/ops/tpu_configuration_ops.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/device_set.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/graph_runner.h"
33 #include "tensorflow/core/common_runtime/optimization_registry.h"
34 #include "tensorflow/core/common_runtime/session_factory.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/public/session.h"
40 #include "tensorflow/core/public/session_options.h"
41 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
42 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
43 #include "tensorflow/core/tpu/tpu_defs.h"
44 #include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {

ABSL_CONST_INIT static absl::Mutex global_init_tpu_mutex(absl::kConstInit);
static tpu::TopologyProto* global_tpu_topology
    ABSL_GUARDED_BY(global_init_tpu_mutex) = nullptr;

constexpr char kTaskSpec[] = "/job:localhost/replica:0/task:0";

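// Creates a DeviceMgr owning the TPU_SYSTEM devices for the local task.
// Fails with an Internal error if no TPU_SYSTEM device factory has been
// linked into the binary.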
Status CreateDeviceMgr(Env* env, std::unique_ptr<DeviceMgr>* device_mgr) {
  SessionOptions session_options;
  session_options.env = env;
  std::vector<std::unique_ptr<Device>> devices;
  DeviceFactory* device_factory = DeviceFactory::GetFactory(DEVICE_TPU_SYSTEM);
  if (device_factory == nullptr) {
    return errors::Internal("Unable to initialize DeviceFactory.");
  }
  TF_RETURN_IF_ERROR(
      device_factory->CreateDevices(session_options, kTaskSpec, &devices));
  *device_mgr = std::make_unique<DynamicDeviceMgr>(std::move(devices));
  return OkStatus();
}

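// Adds every device owned by `device_mgr` to `device_set`, using the first
// device listed as the client device.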
void DeviceSetFromDeviceMgr(const DeviceMgr& device_mgr,
                            DeviceSet* device_set) {
  int devices_added = 0;
  for (auto d : device_mgr.ListDevices()) {
    device_set->AddDevice(d);
    if (devices_added == 0) {
      device_set->set_client_device(d);
    }
    ++devices_added;
  }
}

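// Returns the name of the TPU_SYSTEM:0 device: the local device name when
// `job_name` is empty, otherwise the device on the given job.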
const std::string GetTPUSystemDevice(absl::string_view job_name) {
  if (job_name.empty()) {
    return DeviceNameUtils::LocalName(DEVICE_TPU_SYSTEM, 0);
  } else {
    return absl::StrCat("/job:", job_name, "/device:TPU_SYSTEM:0");
  }
}

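// Builds the TPU initialization graph: a single ConfigureDistributedTPU op
// named "InitializeTPUSystemGlobally", placed on the TPU system device of
// `job_name` and rewritten by DistributedTPUConfigurationRewritePass. The
// resulting graph is written to `graph_to_run`.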
Status ConstructDistributedInitializationGraph(absl::string_view job_name,
                                               const DeviceSet& device_set,
                                               Graph* graph_to_run) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphOptimizationPassOptions options;
  options.graph = &graph;
  options.device_set = &device_set;
  {
    Scope scope = Scope::NewRootScope();
    auto init_op = ops::ConfigureDistributedTPU(
        scope.WithOpName("InitializeTPUSystemGlobally")
            .WithDevice(GetTPUSystemDevice(job_name)),
        ops::ConfigureDistributedTPU::IsGlobalInit(true));
    TF_RETURN_IF_ERROR(scope.ToGraph(options.graph->get()));
  }
  DistributedTPUConfigurationRewritePass rewriter;
  TF_RETURN_IF_ERROR(rewriter.Run(options));

  // Graph doesn't update node defs after edges are added, which causes
  // node-def validation to fail in the executor. So we explicitly round-trip
  // through GraphDef to refresh the node defs.
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph({}, graph->ToGraphDefDebug(), graph_to_run));

  return OkStatus();
}

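// Runs `graph_to_run` in a newly created Session connected to
// `session_target`, fetching "InitializeTPUSystemGlobally:0" into `outputs`.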
Status InitializeFromSession(absl::string_view session_target,
                             const Graph* graph_to_run,
                             std::vector<Tensor>* outputs) {
  tensorflow::SessionOptions s_opts;
  s_opts.target = std::string(session_target);

  std::unique_ptr<tensorflow::Session> sess(tensorflow::NewSession(s_opts));

  GraphDef g_def;
  graph_to_run->ToGraphDef(&g_def);

  TF_RETURN_IF_ERROR(sess->Create(g_def));
  TF_RETURN_IF_ERROR(
      sess->Run({}, {"InitializeTPUSystemGlobally:0"}, {}, outputs));

  return OkStatus();
}

}  // namespace

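// Runs distributed TPU initialization once per process and caches the
// resulting topology: the first successful call stores the TopologyProto in
// `global_tpu_topology`, and subsequent calls return the cached value without
// re-running initialization. If `session_target` is empty, the initialization
// graph is run locally with a GraphRunner; otherwise it is run in a Session
// pointed at that target.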
Status InitializeTPUSystemGlobally(absl::string_view job_name,
                                   absl::string_view session_target,
                                   const DeviceSet& device_set, Env* env,
                                   tpu::TopologyProto* tpu_topology) {
  VLOG(1) << "InitializeTPUSystemGlobally";

  absl::MutexLock lock(&global_init_tpu_mutex);
  if (global_tpu_topology != nullptr) {
    *tpu_topology = *global_tpu_topology;
    return OkStatus();
  }

  std::unique_ptr<Graph> graph_to_run(new Graph(OpRegistry::Global()));

  DeviceNameUtils::ParsedName system_spec;
  Device* tpu_system_device;

  std::string task_spec =
      job_name.empty() ? kTaskSpec
                       : absl::StrCat("/job:", job_name, "/replica:0/task:0");
  // Placed here, well before its use, so we get a clear error if
  // TPU_SYSTEM_DEVICE hasn't been linked in; otherwise the failure surfaces
  // as a cryptic error much later.
  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
      task_spec, device_set, &system_spec, &tpu_system_device));

  TF_RETURN_IF_ERROR(ConstructDistributedInitializationGraph(
      job_name, device_set, graph_to_run.get()));

  std::vector<Tensor> outputs;
  // Conservatively keep running non-distributed initialization with
  // GraphRunner.
  // TODO(hthu): Re-evaluate the choice of using a session for running the
  // initialization graph, given that we need a session in distributed
  // initialization anyway.
  if (session_target.empty()) {
    GraphRunner graph_runner(tpu_system_device);
    TF_RETURN_IF_ERROR(graph_runner.Run(graph_to_run.get(), nullptr, {},
                                        {"InitializeTPUSystemGlobally:0"},
                                        &outputs));
  } else {
    TF_RETURN_IF_ERROR(
        InitializeFromSession(session_target, graph_to_run.get(), &outputs));
  }

  if (outputs.empty()) {
    return errors::Internal("No output from running TPU initialization.");
  }

  global_tpu_topology = new tpu::TopologyProto();
  if (!global_tpu_topology->ParseFromString(outputs[0].scalar<tstring>()())) {
    // Reset the cache so a failed parse is not served to later callers.
    delete global_tpu_topology;
    global_tpu_topology = nullptr;
    return errors::Internal(
        "Unable to parse output from running TPU initialization as a "
        "TopologyProto.");
  }

  *tpu_topology = *global_tpu_topology;
  return OkStatus();
}

// NOTE: Session would have been the obvious first choice to run the graph
// here, but we use a GraphRunner instead because Session creates a global
// EigenThreadPool based on the SessionOptions it receives the first time it
// runs. That would require callers to create the right options and pass them
// to this API to make it work correctly, which we felt was an onerous
// restriction to place on the API, so we went with the current approach.
Status InitializeTPUSystemGlobally(Env* env, tpu::TopologyProto* tpu_topology) {
  std::unique_ptr<DeviceMgr> device_mgr;
  TF_RETURN_IF_ERROR(CreateDeviceMgr(env, &device_mgr));
  DeviceSet device_set;
  DeviceSetFromDeviceMgr(*device_mgr, &device_set);

  return InitializeTPUSystemGlobally(/*job_name=*/absl::string_view(),
                                     /*session_target=*/absl::string_view(),
                                     device_set, env, tpu_topology);
}

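// Convenience overload: initializes the TPU system on the local host using
// the default Env and discards the resulting topology.
//
// Hypothetical caller sketch (illustrative only, not part of this file):
//
//   TF_CHECK_OK(tensorflow::InitializeTPUSystemGlobally());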
Status InitializeTPUSystemGlobally() {
  tensorflow::tpu::TopologyProto tpu_topology;
  return InitializeTPUSystemGlobally(tensorflow::Env::Default(), &tpu_topology);
}

}  // namespace tensorflow