• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
19 
20 namespace xla {
21 
RegisterModule(const ModuleIdentifier & module_id,std::shared_ptr<HloModule> hlo_module,std::shared_ptr<const BufferAssignmentProto> buffer_assignment)22 void XlaDebugInfoManager::RegisterModule(
23     const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
24     std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
25   tensorflow::mutex_lock lock(mutex_);
26   if (active_modules_.find(module_id) != active_modules_.end()) {
27     active_modules_[module_id].instances.emplace_back(hlo_module,
28                                                       buffer_assignment);
29   } else {
30     XlaModuleEntry m;
31     m.module_id = module_id;
32     m.instances.emplace_back(hlo_module, buffer_assignment);
33     active_modules_[module_id] = std::move(m);
34   }
35 }
36 
37 // Unregister an active module, when the last active module of the same
38 // module id is out of scope, we remove it from our database.
39 // However during tracing, we will defer the cleanup after serialization.
UnregisterModule(const ModuleIdentifier & module_id,std::shared_ptr<HloModule> hlo_module,std::shared_ptr<const BufferAssignmentProto> buffer_assignment)40 void XlaDebugInfoManager::UnregisterModule(
41     const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
42     std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
43   tensorflow::mutex_lock lock(mutex_);
44   CHECK(active_modules_.find(module_id) != active_modules_.end());
45   XlaModuleEntry& active_module = active_modules_[module_id];
46   auto instance_it =
47       absl::c_find_if(active_module.instances, [&](XlaModuleInstance& e) {
48         return e.hlo_module == hlo_module &&
49                e.buffer_assignment == buffer_assignment;
50       });
51 
52   CHECK(instance_it != active_module.instances.end());
53 
54   if (!tracing_active_) {
55     active_module.instances.erase(instance_it);
56     if (active_module.instances.empty()) {
57       active_modules_.erase(module_id);
58     }
59   } else {
60     instance_it->active = false;
61   }
62 }
63 
OnModuleStart(ModuleIdentifier module_id)64 void XlaDebugInfoManager::OnModuleStart(ModuleIdentifier module_id) {
65   tensorflow::mutex_lock lock(mutex_);
66   running_module_ids_[module_id]++;
67 }
68 
OnModuleStop(ModuleIdentifier module_id)69 void XlaDebugInfoManager::OnModuleStop(ModuleIdentifier module_id) {
70   tensorflow::mutex_lock lock(mutex_);
71   if (--running_module_ids_[module_id] == 0) {
72     if (!tracing_active_) {
73       running_module_ids_.erase(module_id);
74     }
75   }
76 }
77 
StartTracing()78 void XlaDebugInfoManager::StartTracing() {
79   tensorflow::mutex_lock lock(mutex_);
80   tracing_active_ = true;
81 }
82 
StopTracing(std::vector<XlaModuleDebugInfo> * module_debug_info)83 void XlaDebugInfoManager::StopTracing(
84     std::vector<XlaModuleDebugInfo>* module_debug_info) {
85   std::vector<XlaModuleEntry> modules_to_serialize;
86   {
87     tensorflow::mutex_lock lock(mutex_);
88     if (!tracing_active_) return;
89     tracing_active_ = false;
90     for (const auto& running_module_id : running_module_ids_) {
91       const ModuleIdentifier& module_id = running_module_id.first;
92       if (active_modules_.find(module_id) == active_modules_.end()) {
93         LOG(ERROR) << "Cannot find debug info for module: " << module_id;
94         continue;
95       }
96       const XlaModuleEntry& active_module = active_modules_[module_id];
97 
98       // Copy the instance so that we can serialize without holding the lock.
99       // All instances are equivalent from the perspective of symbolization.
100       // We only use the first one.
101       if (!active_module.instances.empty()) {
102         XlaModuleEntry e;
103         e.module_id = active_module.module_id;
104         e.instances.push_back(active_module.instances[0]);
105         modules_to_serialize.push_back(std::move(e));
106       }
107     }
108 
109     // Remove all running_module_ids which has a reference count equal to zero.
110     for (auto it = running_module_ids_.begin();
111          it != running_module_ids_.end();) {
112       if (it->second == 0) {
113         running_module_ids_.erase(it++);
114       } else {
115         ++it;
116       }
117     }
118 
119     // Remove all active modules which have an instance count equal to zero.
120     for (auto it = active_modules_.begin(); it != active_modules_.end();) {
121       auto& active_module = it->second;
122       for (auto instance = active_module.instances.begin();
123            instance != active_module.instances.end();) {
124         if (instance->active) {
125           ++instance;
126         } else {
127           instance = active_module.instances.erase(instance);
128         }
129       }
130 
131       if (active_module.instances.empty()) {
132         active_modules_.erase(it++);
133       } else {
134         ++it;
135       }
136     }
137   }
138 
139   if (module_debug_info) {
140     module_debug_info->clear();
141     for (const auto& m : modules_to_serialize) {
142       XlaModuleDebugInfo info;
143       info.module_id = m.module_id;
144       // In real world, hlo_module and buffer_assignment will always be
145       // non-nullptr. Due to the inconvenience of creation of buffer_assignment
146       // object in test, we set it to nullptr and guard this for it.
147       if (m.instances[0].hlo_module && m.instances[0].buffer_assignment) {
148         info.hlo_proto = absl::make_unique<HloProto>(
149             MakeHloProto(*m.instances[0].hlo_module));
150         *info.hlo_proto->mutable_buffer_assignment() =
151             *m.instances[0].buffer_assignment;
152       }
153       module_debug_info->emplace_back(std::move(info));
154     }
155   }
156 }
157 
158 }  // namespace xla
159