1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24
25 namespace xla {
26
27 class XlaDebugInfoManagerTestPeer {
28 public:
RegisterModule(ModuleIdentifier module_id,std::shared_ptr<const HloModule> hlo_module,std::shared_ptr<const BufferAssignmentProto> buffer_assignment)29 void RegisterModule(
30 ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
31 std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
32 return xla_debug_info_manager_.RegisterModule(module_id, hlo_module,
33 buffer_assignment);
34 }
35
UnregisterModule(ModuleIdentifier module_id)36 void UnregisterModule(ModuleIdentifier module_id) {
37 return xla_debug_info_manager_.UnregisterModule(module_id);
38 }
39
StartTracing()40 void StartTracing() { return xla_debug_info_manager_.StartTracing(); }
41
StopTracing()42 absl::flat_hash_set<ModuleIdentifier> StopTracing() {
43 std::vector<std::unique_ptr<HloProto>> module_debug_info;
44 xla_debug_info_manager_.StopTracing(&module_debug_info);
45 absl::flat_hash_set<ModuleIdentifier> module_ids;
46 for (const auto& hlo_proto : module_debug_info) {
47 module_ids.insert(hlo_proto->hlo_module().id());
48 }
49 return module_ids;
50 }
51
GetModuleIds()52 absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
53 absl::flat_hash_set<ModuleIdentifier> module_ids;
54 absl::MutexLock lock(&xla_debug_info_manager_.mutex_);
55 for (const auto& it : xla_debug_info_manager_.modules_) {
56 module_ids.insert(it.first);
57 }
58 return module_ids;
59 }
60
61 private:
62 XlaDebugInfoManager xla_debug_info_manager_;
63 };
64
65 namespace {
66
67 using ::testing::IsEmpty;
68 using ::testing::UnorderedElementsAre;
69
70 class XlaDebugInfoManagerTest : public HloTestBase {
71 protected:
72 struct DebugMetadata {
73 // We allow same id to be registered multiple times. we need unique id to
74 // know which program is referenced (such as in UnregisterProgram).
75 ModuleIdentifier unique_id;
76 std::shared_ptr<HloModule> module;
77 std::shared_ptr<BufferAssignmentProto> buffer_assignment;
78 };
79
80 // Return unique id of this module.
RegisterProgram(const std::string & module_name)81 ModuleIdentifier RegisterProgram(const std::string& module_name) {
82 DebugMetadata debug_info;
83 HloModuleConfig config;
84 debug_info.module = std::make_shared<HloModule>(module_name, config);
85 debug_info.buffer_assignment = nullptr;
86 ModuleIdentifier unique_id = debug_info.module->unique_id();
87 debug_info.unique_id = unique_id;
88 xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module,
89 debug_info.buffer_assignment);
90 external_references_.push_back(std::move(debug_info));
91 return unique_id;
92 }
93
UnregisterProgram(ModuleIdentifier unique_id)94 void UnregisterProgram(ModuleIdentifier unique_id) {
95 for (int i = 0; i < external_references_.size(); i++) {
96 if (external_references_[i].unique_id == unique_id) {
97 xla_debug_info_manager_.UnregisterModule(unique_id);
98 external_references_.erase(external_references_.begin() + i);
99 break;
100 }
101 }
102 }
103
GetModuleIds()104 absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
105 return xla_debug_info_manager_.GetModuleIds();
106 }
107
StartTrace()108 void StartTrace() { xla_debug_info_manager_.StartTracing(); }
109
StopTrace()110 absl::flat_hash_set<ModuleIdentifier> StopTrace() {
111 return xla_debug_info_manager_.StopTracing();
112 }
113
114 // Simulation of compilation cache.
115 std::vector<DebugMetadata> external_references_;
116
117 // Use an instance per test instead of singleton to avoid interferences.
118 XlaDebugInfoManagerTestPeer xla_debug_info_manager_;
119 };
120
121 // Test the cases where no trace session is involved.
TEST_F(XlaDebugInfoManagerTest, NoTraceBasic) {
  // Without an active trace, registration and unregistration are reflected
  // immediately in the set of module ids the manager tracks.
  auto program0 = RegisterProgram("program0");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0));

  auto program1 = RegisterProgram("program1");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0, program1));

  UnregisterProgram(program0);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
  UnregisterProgram(program1);
  // Matcher form (as in the other tests) prints leftover ids on failure.
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
134
TEST_F(XlaDebugInfoManagerTest, NoTraceDuplicateIds) {
  // Two registrations of the same module name are expected to be tracked under
  // distinct ids and to be removable independently of each other.
  const ModuleIdentifier program0_first = RegisterProgram("program0");
  const ModuleIdentifier program0_second = RegisterProgram("program0");
  const ModuleIdentifier program1 = RegisterProgram("program1");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(
                                  program0_first, program0_second, program1));

  UnregisterProgram(program1);
  EXPECT_THAT(GetModuleIds(),
              UnorderedElementsAre(program0_first, program0_second));

  UnregisterProgram(program0_first);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0_second));

  UnregisterProgram(program0_second);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
149
150 // Test the cases where an active trace session is involved.
TEST_F(XlaDebugInfoManagerTest, ActiveTrace) {
  // Programs registered before the trace (including a duplicated name).
  const ModuleIdentifier program0_first = RegisterProgram("program0");
  const ModuleIdentifier program0_second = RegisterProgram("program0");
  const ModuleIdentifier program1 = RegisterProgram("program1");

  // A program registered while the trace is active is reported alongside the
  // ones registered earlier.
  StartTrace();
  const ModuleIdentifier program2 = RegisterProgram("program2");
  EXPECT_THAT(StopTrace(), UnorderedElementsAre(program0_first, program0_second,
                                                 program1, program2));

  // A second trace with no intervening changes reports the same programs.
  StartTrace();
  EXPECT_THAT(StopTrace(), UnorderedElementsAre(program0_first, program0_second,
                                                 program1, program2));

  // With no trace active, each unregistration is reflected immediately.
  UnregisterProgram(program2);
  EXPECT_THAT(GetModuleIds(),
              UnorderedElementsAre(program0_first, program0_second, program1));

  UnregisterProgram(program0_first);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0_second, program1));

  UnregisterProgram(program0_second);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));

  UnregisterProgram(program1);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
175
TEST_F(XlaDebugInfoManagerTest, UnregisterDuringTrace) {
  auto program0A = RegisterProgram("program0");
  auto program0B = RegisterProgram("program0");  // duplicates
  auto program1 = RegisterProgram("program1");

  // Programs unregistered while a trace is active are still reported by that
  // trace, but are no longer tracked once it stops.
  StartTrace();
  UnregisterProgram(program1);
  UnregisterProgram(program0B);
  EXPECT_THAT(StopTrace(),
              UnorderedElementsAre(program0A, program0B, program1));
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A));

  // Unregistering the remaining program with no active trace must leave the
  // manager empty, as the other tests verify at their end.
  UnregisterProgram(program0A);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
190
191 } // namespace
192 } // namespace xla
193