/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
namespace xla {

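// Register an active module. Modules that share the same module id are
// grouped into one entry; each (hlo_module, buffer_assignment) pair is
// tracked as a separate instance of that entry.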
void XlaDebugInfoManager::RegisterModule(
    const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
    std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
  tensorflow::mutex_lock lock(mutex_);
  if (active_modules_.find(module_id) != active_modules_.end()) {
    active_modules_[module_id].instances.emplace_back(hlo_module,
                                                      buffer_assignment);
  } else {
    XlaModuleEntry m;
    m.module_id = module_id;
    m.instances.emplace_back(hlo_module, buffer_assignment);
    active_modules_[module_id] = std::move(m);
  }
}

// Unregister an active module. When the last instance with the same module
// id goes out of scope, we remove the entry from our database. However,
// during tracing we defer the cleanup until after serialization.
void XlaDebugInfoManager::UnregisterModule(
    const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
    std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
  tensorflow::mutex_lock lock(mutex_);
  CHECK(active_modules_.find(module_id) != active_modules_.end());
  XlaModuleEntry& active_module = active_modules_[module_id];
  auto instance_it =
      absl::c_find_if(active_module.instances, [&](XlaModuleInstance& e) {
        return e.hlo_module == hlo_module &&
               e.buffer_assignment == buffer_assignment;
      });

  CHECK(instance_it != active_module.instances.end());

  if (!tracing_active_) {
    active_module.instances.erase(instance_it);
    if (active_module.instances.empty()) {
      active_modules_.erase(module_id);
    }
  } else {
    instance_it->active = false;
  }
}

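// Records that an instance of the module has started running. The running
// count is used by StopTracing to determine which modules executed during
// the trace.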
void XlaDebugInfoManager::OnModuleStart(ModuleIdentifier module_id) {
  tensorflow::mutex_lock lock(mutex_);
  running_module_ids_[module_id]++;
}

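// Records that an instance of the module has stopped running. Outside of
// tracing, a module whose running count drops to zero is forgotten; during
// tracing it is kept so that StopTracing can still report it.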
void XlaDebugInfoManager::OnModuleStop(ModuleIdentifier module_id) {
  tensorflow::mutex_lock lock(mutex_);
  if (--running_module_ids_[module_id] == 0) {
    if (!tracing_active_) {
      running_module_ids_.erase(module_id);
    }
  }
}

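// Starts a tracing session. While tracing is active, unregistered modules
// and stopped runs are kept alive so their debug info can still be
// serialized at StopTracing time.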
void XlaDebugInfoManager::StartTracing() {
  tensorflow::mutex_lock lock(mutex_);
  tracing_active_ = true;
}

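// Ends the tracing session. If module_debug_info is non-null, it is filled
// with one HloProto (HLO module plus buffer assignment) per module that ran
// during the trace. Cleanup deferred by UnregisterModule and OnModuleStop is
// performed here.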
void XlaDebugInfoManager::StopTracing(
    std::vector<XlaModuleDebugInfo>* module_debug_info) {
  std::vector<XlaModuleEntry> modules_to_serialize;
  {
    tensorflow::mutex_lock lock(mutex_);
    if (!tracing_active_) return;
    tracing_active_ = false;
    for (const auto& running_module_id : running_module_ids_) {
      const ModuleIdentifier& module_id = running_module_id.first;
      if (active_modules_.find(module_id) == active_modules_.end()) {
        LOG(ERROR) << "Cannot find debug info for module: " << module_id;
        continue;
      }
      const XlaModuleEntry& active_module = active_modules_[module_id];

      // Copy the instance so that we can serialize without holding the lock.
      // All instances are equivalent from the perspective of symbolization.
      // We only use the first one.
      if (!active_module.instances.empty()) {
        XlaModuleEntry e;
        e.module_id = active_module.module_id;
        e.instances.push_back(active_module.instances[0]);
        modules_to_serialize.push_back(std::move(e));
      }
    }

    // Remove all running_module_ids whose reference count is zero.
    for (auto it = running_module_ids_.begin();
         it != running_module_ids_.end();) {
      if (it->second == 0) {
        running_module_ids_.erase(it++);
      } else {
        ++it;
      }
    }

    // Remove instances that are no longer active, then remove modules with
    // no remaining instances.
    for (auto it = active_modules_.begin(); it != active_modules_.end();) {
      auto& active_module = it->second;
      for (auto instance = active_module.instances.begin();
           instance != active_module.instances.end();) {
        if (instance->active) {
          ++instance;
        } else {
          instance = active_module.instances.erase(instance);
        }
      }

      if (active_module.instances.empty()) {
        active_modules_.erase(it++);
      } else {
        ++it;
      }
    }
  }

  if (module_debug_info) {
    module_debug_info->clear();
    for (const auto& m : modules_to_serialize) {
      XlaModuleDebugInfo info;
      info.module_id = m.module_id;
      // In the real world, hlo_module and buffer_assignment are always
      // non-null. Because constructing a BufferAssignmentProto is
      // inconvenient in tests, tests may leave it null, so guard against
      // that here.
      if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
        info.hlo_proto = absl::make_unique<HloProto>(
            MakeHloProto(*m.instances[0].hlo_module));
        *info.hlo_proto->mutable_buffer_assignment() =
            *m.instances[0].buffer_assignment;
      }
      module_debug_info->emplace_back(std::move(info));
    }
  }
}

}  // namespace xla