• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 
25 namespace xla {
26 
27 class XlaDebugInfoManagerTestPeer {
28  public:
RegisterModule(ModuleIdentifier module_id,std::shared_ptr<const HloModule> hlo_module,std::shared_ptr<const BufferAssignmentProto> buffer_assignment)29   void RegisterModule(
30       ModuleIdentifier module_id, std::shared_ptr<const HloModule> hlo_module,
31       std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
32     return xla_debug_info_manager_.RegisterModule(module_id, hlo_module,
33                                                   buffer_assignment);
34   }
35 
UnregisterModule(ModuleIdentifier module_id)36   void UnregisterModule(ModuleIdentifier module_id) {
37     return xla_debug_info_manager_.UnregisterModule(module_id);
38   }
39 
StartTracing()40   void StartTracing() { return xla_debug_info_manager_.StartTracing(); }
41 
StopTracing()42   absl::flat_hash_set<ModuleIdentifier> StopTracing() {
43     std::vector<std::unique_ptr<HloProto>> module_debug_info;
44     xla_debug_info_manager_.StopTracing(&module_debug_info);
45     absl::flat_hash_set<ModuleIdentifier> module_ids;
46     for (const auto& hlo_proto : module_debug_info) {
47       module_ids.insert(hlo_proto->hlo_module().id());
48     }
49     return module_ids;
50   }
51 
GetModuleIds()52   absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
53     absl::flat_hash_set<ModuleIdentifier> module_ids;
54     absl::MutexLock lock(&xla_debug_info_manager_.mutex_);
55     for (const auto& it : xla_debug_info_manager_.modules_) {
56       module_ids.insert(it.first);
57     }
58     return module_ids;
59   }
60 
61  private:
62   XlaDebugInfoManager xla_debug_info_manager_;
63 };
64 
65 namespace {
66 
67 using ::testing::IsEmpty;
68 using ::testing::UnorderedElementsAre;
69 
70 class XlaDebugInfoManagerTest : public HloTestBase {
71  protected:
72   struct DebugMetadata {
73     // We allow same id to be registered multiple times. we need unique id to
74     // know which program is referenced (such as in UnregisterProgram).
75     ModuleIdentifier unique_id;
76     std::shared_ptr<HloModule> module;
77     std::shared_ptr<BufferAssignmentProto> buffer_assignment;
78   };
79 
80   // Return unique id of this module.
RegisterProgram(const std::string & module_name)81   ModuleIdentifier RegisterProgram(const std::string& module_name) {
82     DebugMetadata debug_info;
83     HloModuleConfig config;
84     debug_info.module = std::make_shared<HloModule>(module_name, config);
85     debug_info.buffer_assignment = nullptr;
86     ModuleIdentifier unique_id = debug_info.module->unique_id();
87     debug_info.unique_id = unique_id;
88     xla_debug_info_manager_.RegisterModule(unique_id, debug_info.module,
89                                            debug_info.buffer_assignment);
90     external_references_.push_back(std::move(debug_info));
91     return unique_id;
92   }
93 
UnregisterProgram(ModuleIdentifier unique_id)94   void UnregisterProgram(ModuleIdentifier unique_id) {
95     for (int i = 0; i < external_references_.size(); i++) {
96       if (external_references_[i].unique_id == unique_id) {
97         xla_debug_info_manager_.UnregisterModule(unique_id);
98         external_references_.erase(external_references_.begin() + i);
99         break;
100       }
101     }
102   }
103 
GetModuleIds()104   absl::flat_hash_set<ModuleIdentifier> GetModuleIds() {
105     return xla_debug_info_manager_.GetModuleIds();
106   }
107 
StartTrace()108   void StartTrace() { xla_debug_info_manager_.StartTracing(); }
109 
StopTrace()110   absl::flat_hash_set<ModuleIdentifier> StopTrace() {
111     return xla_debug_info_manager_.StopTracing();
112   }
113 
114   // Simulation of compilation cache.
115   std::vector<DebugMetadata> external_references_;
116 
117   // Use an instance per test instead of singleton to avoid interferences.
118   XlaDebugInfoManagerTestPeer xla_debug_info_manager_;
119 };
120 
121 // Test the cases where no trace session is involved.
// Without a trace session, the manager's module set tracks exactly the
// programs that are currently registered.
TEST_F(XlaDebugInfoManagerTest, NoTraceBasic) {
  auto program0 = RegisterProgram("program0");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0));

  auto program1 = RegisterProgram("program1");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0, program1));

  UnregisterProgram(program0);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program1));
  UnregisterProgram(program1);
  // Use the IsEmpty matcher (as the other tests do) so a failure prints the
  // unexpected ids instead of just "false".
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
134 
// Registering the same module name twice yields two distinct registrations,
// and each one has to be unregistered on its own.
TEST_F(XlaDebugInfoManagerTest, NoTraceDuplicateIds) {
  const ModuleIdentifier first_program0 = RegisterProgram("program0");
  const ModuleIdentifier second_program0 =
      RegisterProgram("program0");  // same name again
  const ModuleIdentifier lone_program1 = RegisterProgram("program1");
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(first_program0,
                                                   second_program0,
                                                   lone_program1));

  // Remove the registrations one by one; only the removed one disappears.
  UnregisterProgram(lone_program1);
  EXPECT_THAT(GetModuleIds(),
              UnorderedElementsAre(first_program0, second_program0));

  UnregisterProgram(first_program0);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(second_program0));

  UnregisterProgram(second_program0);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
149 
150 // Test the cases where an active trace session is involved.
// A trace reports every registered module: those registered before the trace
// started as well as those registered while it was running. A second trace
// reports the same set again.
TEST_F(XlaDebugInfoManagerTest, ActiveTrace) {
  const ModuleIdentifier first_program0 = RegisterProgram("program0");
  const ModuleIdentifier second_program0 =
      RegisterProgram("program0");  // same name again
  const ModuleIdentifier lone_program1 = RegisterProgram("program1");

  // First trace: one program is registered mid-trace.
  StartTrace();
  const ModuleIdentifier late_program2 = RegisterProgram("program2");
  EXPECT_THAT(StopTrace(),
              UnorderedElementsAre(first_program0, second_program0,
                                   lone_program1, late_program2));

  // Second trace with no changes reports the same modules.
  StartTrace();
  EXPECT_THAT(StopTrace(),
              UnorderedElementsAre(first_program0, second_program0,
                                   lone_program1, late_program2));

  // Outside of a trace, unregistering removes modules immediately.
  UnregisterProgram(late_program2);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(
                                  first_program0, second_program0,
                                  lone_program1));

  UnregisterProgram(first_program0);
  EXPECT_THAT(GetModuleIds(),
              UnorderedElementsAre(second_program0, lone_program1));

  UnregisterProgram(second_program0);
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(lone_program1));

  UnregisterProgram(lone_program1);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
175 
// Programs unregistered while a trace is active are still reported by that
// trace, and are no longer held by the manager once the trace stops.
TEST_F(XlaDebugInfoManagerTest, UnregisterDuringTrace) {
  auto program0A = RegisterProgram("program0");
  auto program0B = RegisterProgram("program0");  // duplicates
  auto program1 = RegisterProgram("program1");

  StartTrace();
  UnregisterProgram(program1);
  UnregisterProgram(program0B);
  EXPECT_THAT(StopTrace(),
              UnorderedElementsAre(program0A, program0B, program1));
  EXPECT_THAT(GetModuleIds(), UnorderedElementsAre(program0A));

  // With no trace active, unregistering the last program must leave the
  // manager empty; previously this final step was not checked.
  UnregisterProgram(program0A);
  EXPECT_THAT(GetModuleIds(), IsEmpty());
}
190 
191 }  // namespace
192 }  // namespace xla
193