/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_global_init.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/tpu_configuration_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {

ABSL_CONST_INIT static absl::Mutex global_init_tpu_mutex(absl::kConstInit);
static tpu::TopologyProto* global_tpu_topology
    ABSL_GUARDED_BY(global_init_tpu_mutex) = nullptr;

constexpr char kTaskSpec[] = "/job:localhost/replica:0/task:0";

Status CreateDeviceMgr(Env* env, std::unique_ptr<DeviceMgr>* device_mgr) {
  SessionOptions session_options;
  session_options.env = env;
  std::vector<std::unique_ptr<Device>> devices;
  DeviceFactory* device_factory = DeviceFactory::GetFactory(DEVICE_TPU_SYSTEM);
  if (device_factory == nullptr) {
    return errors::Internal("Unable to initialize DeviceFactory.");
  }
  TF_RETURN_IF_ERROR(
      device_factory->CreateDevices(session_options, kTaskSpec, &devices));
  *device_mgr = std::make_unique<DynamicDeviceMgr>(std::move(devices));
  return OkStatus();
}

void DeviceSetFromDeviceMgr(const DeviceMgr& device_mgr,
                            DeviceSet* device_set) {
  int devices_added = 0;
  for (auto d : device_mgr.ListDevices()) {
    device_set->AddDevice(d);
    if (devices_added == 0) {
      device_set->set_client_device(d);
    }
    ++devices_added;
  }
}

const std::string GetTPUSystemDevice(absl::string_view job_name) {
  if (job_name.empty()) {
    return DeviceNameUtils::LocalName(DEVICE_TPU_SYSTEM, 0);
  } else {
    return absl::StrCat("/job:", job_name, "/device:TPU_SYSTEM:0");
  }
}
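
// For illustration (comment added here, not part of the original file):
// with an empty job name this returns the local device name,
// "/device:TPU_SYSTEM:0"; with job_name == "worker" it returns
// "/job:worker/device:TPU_SYSTEM:0".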

Status ConstructDistributedInitializationGraph(absl::string_view job_name,
                                               const DeviceSet& device_set,
                                               Graph* graph_to_run) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphOptimizationPassOptions options;
  options.graph = &graph;
  options.device_set = &device_set;
  {
    Scope scope = Scope::NewRootScope();
    auto init_op = ops::ConfigureDistributedTPU(
        scope.WithOpName("InitializeTPUSystemGlobally")
            .WithDevice(GetTPUSystemDevice(job_name)),
        ops::ConfigureDistributedTPU::IsGlobalInit(true));
    TF_RETURN_IF_ERROR(scope.ToGraph(options.graph->get()));
  }
  DistributedTPUConfigurationRewritePass rewriter;
  TF_RETURN_IF_ERROR(rewriter.Run(options));

  // Graph doesn't update the node-defs after adding edges, which causes
  // node-def validation to fail in the executor. So we explicitly do a
  // round-trip through GraphDef so that node-defs are updated.
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph({}, graph->ToGraphDefDebug(), graph_to_run));

  return OkStatus();
}

Status InitializeFromSession(absl::string_view session_target,
                             const Graph* graph_to_run,
                             std::vector<Tensor>* outputs) {
  tensorflow::SessionOptions s_opts;
  s_opts.target = std::string(session_target);

  std::unique_ptr<tensorflow::Session> sess(tensorflow::NewSession(s_opts));

  GraphDef g_def;
  graph_to_run->ToGraphDef(&g_def);

  TF_RETURN_IF_ERROR(sess->Create(g_def));
  TF_RETURN_IF_ERROR(
      sess->Run({}, {"InitializeTPUSystemGlobally:0"}, {}, outputs));

  return OkStatus();
}

}  // namespace

Status InitializeTPUSystemGlobally(absl::string_view job_name,
                                   absl::string_view session_target,
                                   const DeviceSet& device_set, Env* env,
                                   tpu::TopologyProto* tpu_topology) {
  VLOG(1) << "InitializeTpuSystemGlobally";

  absl::MutexLock lock(&global_init_tpu_mutex);
  if (global_tpu_topology != nullptr) {
    *tpu_topology = *global_tpu_topology;
    return OkStatus();
  }

  std::unique_ptr<Graph> graph_to_run(new Graph(OpRegistry::Global()));

  DeviceNameUtils::ParsedName system_spec;
  Device* tpu_system_device;

  std::string task_spec =
      job_name.empty() ? kTaskSpec
                       : absl::StrCat("/job:", job_name, "/replica:0/task:0");
  // Placed here, well before its use, to get a sane error if TPU_SYSTEM_DEVICE
  // hasn't been linked in. Otherwise we may get a cryptic error down the line.
  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
      task_spec, device_set, &system_spec, &tpu_system_device));

  TF_RETURN_IF_ERROR(ConstructDistributedInitializationGraph(
      job_name, device_set, graph_to_run.get()));

  std::vector<Tensor> outputs;
  // Being a bit conservative here: run non-distributed initialization with
  // GraphRunner.
  // TODO(hthu): Re-evaluate the choice of using a session for running the
  // initialization graph, given that we need a session in distributed
  // initialization anyway.
  if (session_target.empty()) {
    GraphRunner graph_runner(tpu_system_device);
    TF_RETURN_IF_ERROR(graph_runner.Run(graph_to_run.get(), nullptr, {},
                                        {"InitializeTPUSystemGlobally:0"},
                                        &outputs));
  } else {
    TF_RETURN_IF_ERROR(
        InitializeFromSession(session_target, graph_to_run.get(), &outputs));
  }

  if (outputs.empty()) {
    return errors::Internal("No output from running TPU initialization.");
  }

  global_tpu_topology = new tpu::TopologyProto();
  if (!global_tpu_topology->ParseFromString(outputs[0].scalar<tstring>()())) {
    return errors::Internal(
        "Unable to parse output from running TPU initialization as a "
        "TopologyProto.");
  }

  *tpu_topology = *global_tpu_topology;
  return OkStatus();
}
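
// Example call for the distributed case (an illustrative sketch, not part of
// the original file; the job name and gRPC target below are hypothetical
// placeholders, and `device_set` must already contain the TPU_SYSTEM device
// for /job:worker/replica:0/task:0):
//
//   tensorflow::tpu::TopologyProto topology;
//   TF_CHECK_OK(tensorflow::InitializeTPUSystemGlobally(
//       "worker", "grpc://tpu-worker-0:8470", device_set,
//       tensorflow::Env::Default(), &topology));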

// NOTE: Session would have been the obvious first choice to run the graph
// here, but instead we use a GraphRunner because Session creates a global
// EigenThreadPool based on the SessionOptions it receives the first time it
// runs. This means that we would need to create the right options and pass
// them to this API to make it work correctly. We felt that was an onerous
// restriction to place on the API, so we went with the current approach.
Status InitializeTPUSystemGlobally(Env* env, tpu::TopologyProto* tpu_topology) {
  std::unique_ptr<DeviceMgr> device_mgr;
  TF_RETURN_IF_ERROR(CreateDeviceMgr(env, &device_mgr));
  DeviceSet device_set;
  DeviceSetFromDeviceMgr(*device_mgr, &device_set);

  return InitializeTPUSystemGlobally(/*job_name=*/absl::string_view(),
                                     /*session_target=*/absl::string_view(),
                                     device_set, env, tpu_topology);
}

Status InitializeTPUSystemGlobally() {
  tensorflow::tpu::TopologyProto tpu_topology;
  return InitializeTPUSystemGlobally(tensorflow::Env::Default(), &tpu_topology);
}
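
// Example usage of the single-host overload (an illustrative sketch, not part
// of the original file; assumes the TPU system device factory has been linked
// into the binary):
//
//   tensorflow::tpu::TopologyProto topology;
//   TF_CHECK_OK(tensorflow::InitializeTPUSystemGlobally(
//       tensorflow::Env::Default(), &topology));
//   LOG(INFO) << "TPU mesh dims: " << topology.mesh_shape_size();
//
// Repeated calls are cheap: the parsed topology is cached behind
// global_init_tpu_mutex and returned directly on subsequent invocations.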

}  // namespace tensorflow