1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
17
18 #include <algorithm>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/types.h"
26
27 namespace xla {
28
Create(HloComputation * computation,string domain_kind)29 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
30 HloComputation* computation, string domain_kind) {
31 auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
32 TF_RETURN_IF_ERROR(domain_map->Populate(computation));
33 return std::move(domain_map);
34 }
35
Create(HloModule * module,string domain_kind)36 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
37 HloModule* module, string domain_kind) {
38 auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
39 for (HloComputation* computation : module->computations()) {
40 TF_RETURN_IF_ERROR(domain_map->Populate(computation));
41 }
42 return std::move(domain_map);
43 }
44
InSameDomain(const HloInstruction * instruction1,const HloInstruction * instruction2) const45 bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
46 const HloInstruction* instruction2) const {
47 int64 domain_id1 = GetDomainId(instruction1);
48 int64 domain_id2 = GetDomainId(instruction2);
49 return domain_id1 >= 0 && domain_id1 == domain_id2;
50 }
51
GetDomainId(const HloInstruction * instruction) const52 int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
53 return FindOrDefault(instruction_to_domain_, instruction, -1);
54 }
55
GetDomainMetadataId(const HloInstruction * instruction) const56 int64 HloDomainMap::GetDomainMetadataId(
57 const HloInstruction* instruction) const {
58 return FindOrDie(domain_metadata_id_, instruction);
59 }
60
TryProcessEmptyDomain(HloInstruction * instruction)61 Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
62 TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
63 // We only check operands, so we are sure to not process the empty domain from
64 // both sides.
65 for (HloInstruction* operand : instruction->unique_operands()) {
66 if (IsDomainInstruction(operand)) {
67 auto domain = absl::make_unique<DomainMetadata::Domain>();
68 domain->enter_domains.insert(operand);
69 domain->exit_domains.insert(instruction);
70 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
71 }
72 }
73 if (instruction == instruction->parent()->root_instruction()) {
74 auto domain = absl::make_unique<DomainMetadata::Domain>();
75 domain->enter_domains.insert(instruction);
76 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
77 }
78 return Status::OK();
79 }
80
Populate(HloComputation * computation)81 Status HloDomainMap::Populate(HloComputation* computation) {
82 InstructionOrderMap instructions_post_order;
83 int64 count = 0;
84 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
85 instructions_post_order.insert(std::make_pair(instruction, count++));
86 }
87 for (HloInstruction* instruction : computation->instructions()) {
88 if (IsDomainInstruction(instruction)) {
89 // If this is a kDomain of the kind we are currently processing, check
90 // whether this is an "empty domain".
91 TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction));
92 continue;
93 }
94 int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
95 if (domain_id >= 0) {
96 // We have already processed this instruction.
97 continue;
98 }
99 TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
100 CreateDomain(instruction, instructions_post_order));
101 TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
102 }
103 TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
104 return Status::OK();
105 }
106
PopulateDomainMetadataMap()107 Status HloDomainMap::PopulateDomainMetadataMap() {
108 auto hash = [](const DomainMetadata* m) { return m->Hash(); };
109 auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
110 return a->Matches(*b);
111 };
112 absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash),
113 decltype(equal)>
114 domain_metadata(1024, hash, equal);
115
116 for (auto& domain : instruction_domains_) {
117 int64 domain_metadata_id = -1;
118 if (!domain->enter_domains.empty()) {
119 const HloInstruction* domain_instruction = *domain->enter_domains.begin();
120 domain_metadata_id =
121 domain_metadata
122 .insert({&domain_instruction->user_side_metadata(),
123 domain_metadata.size() + 1})
124 .first->second;
125 } else if (!domain->exit_domains.empty()) {
126 const HloInstruction* domain_instruction = *domain->exit_domains.begin();
127 domain_metadata_id =
128 domain_metadata
129 .insert({&domain_instruction->operand_side_metadata(),
130 domain_metadata.size() + 1})
131 .first->second;
132 } else {
133 domain_metadata_id = 0;
134 }
135 TF_RET_CHECK(domain_metadata_id >= 0);
136 for (HloInstruction* instruction : domain->instructions) {
137 domain_metadata_id_[instruction] = domain_metadata_id;
138 }
139 }
140 return Status::OK();
141 }
142
InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain)143 Status HloDomainMap::InsertDomain(
144 std::unique_ptr<DomainMetadata::Domain> domain) {
145 int64 domain_id = instruction_domains_.size();
146 instruction_domains_.push_back(std::move(domain));
147 for (HloInstruction* instruction : instruction_domains_.back()->reach_set) {
148 instruction_to_domain_[instruction] = domain_id;
149 }
150 return Status::OK();
151 }
152
ExpandDomain(HloInstruction * instruction,DomainMetadata::Domain * domain) const153 Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
154 DomainMetadata::Domain* domain) const {
155 std::vector<HloInstruction*> in_queue;
156 in_queue.push_back(instruction);
157 while (!in_queue.empty()) {
158 HloInstruction* current_instruction = in_queue.back();
159 in_queue.pop_back();
160 if (domain->reach_set.insert(current_instruction).second) {
161 // We should not be finding instructions with assigned domain here.
162 // If we assigned a domain to the instruction, it means that all the
163 // instructions reached by it, should have a domain as well.
164 int64 domain_id =
165 FindOrDefault(instruction_to_domain_, current_instruction, -1);
166 TF_RET_CHECK(domain_id < 0)
167 << "Instruction " << current_instruction->ToString()
168 << " already has domain " << domain_id;
169 for (HloInstruction* operand : current_instruction->operands()) {
170 if (IsDomainInstruction(operand)) {
171 // The reach set instruction is a user of the domain instruction
172 // (the instruction sees the kDomain as operand).
173 // IOW the dataflow enters the domain through the kDomain instruction.
174 domain->enter_domains.insert(operand);
175 } else {
176 in_queue.push_back(operand);
177 }
178 }
179 for (HloInstruction* user : current_instruction->users()) {
180 if (IsDomainInstruction(user)) {
181 // The reach set instruction is an operand of the domain instruction
182 // (the instruction sees the kDomain as user).
183 // IOW the dataflow exits the domain through the kDomain instruction.
184 domain->exit_domains.insert(user);
185 } else {
186 in_queue.push_back(user);
187 }
188 }
189 }
190 }
191 return Status::OK();
192 }
193
CreateDomain(HloInstruction * instruction,const InstructionOrderMap & instructions_order) const194 StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
195 HloInstruction* instruction,
196 const InstructionOrderMap& instructions_order) const {
197 auto domain = absl::make_unique<DomainMetadata::Domain>();
198 TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
199 domain->instructions =
200 MakeNonDomainInstructions(domain->reach_set, instructions_order);
201 return std::move(domain);
202 }
203
IsDomainInstruction(const HloInstruction * instruction) const204 bool HloDomainMap::IsDomainInstruction(
205 const HloInstruction* instruction) const {
206 if (instruction->opcode() != HloOpcode::kDomain) {
207 return false;
208 }
209 if (!domain_kind_.empty()) {
210 if (instruction->user_side_metadata().Kind() != domain_kind_) {
211 return false;
212 }
213 // Both user and operand side of the metadata must be of the same kind.
214 CHECK(instruction->operand_side_metadata().Kind() == domain_kind_)
215 << "Instruction " << instruction->ToString()
216 << " has mismatching metadata kinds";
217 }
218 return true;
219 }
220
221 /* static */ std::vector<HloInstruction*>
MakeNonDomainInstructions(const absl::flat_hash_set<HloInstruction * > & instruction_set,const InstructionOrderMap & instructions_order)222 HloDomainMap::MakeNonDomainInstructions(
223 const absl::flat_hash_set<HloInstruction*>& instruction_set,
224 const InstructionOrderMap& instructions_order) {
225 std::vector<HloInstruction*> instructions;
226 instructions.reserve(instruction_set.size());
227 for (HloInstruction* instruction : instruction_set) {
228 if (instruction->opcode() != HloOpcode::kDomain) {
229 instructions.push_back(instruction);
230 }
231 }
232 // sort instructions according to instructions_order
233 absl::c_sort(instructions,
234 [&instructions_order](HloInstruction* a, HloInstruction* b) {
235 return instructions_order.at(a) < instructions_order.at(b);
236 });
237 return instructions;
238 }
239
240 } // namespace xla
241