1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 namespace xla {
40
41 using absl::StrAppend;
42 using absl::StrCat;
43
shape() const44 const Shape& HloPosition::shape() const {
45 return ShapeUtil::GetSubshape(instruction->shape(), index);
46 }
47
ToString() const48 std::string HloPosition::ToString() const {
49 std::string index_str =
50 instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
51 return StrCat(instruction->name(), index_str);
52 }
53
operator <<(std::ostream & out,const HloPosition & position)54 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
55 out << position.ToString();
56 return out;
57 }
58
ToString() const59 std::string HloUse::ToString() const {
60 std::string index_str =
61 instruction->operand(operand_number)->shape().IsTuple()
62 ? (" " + operand_index.ToString())
63 : "";
64 return StrCat(instruction->name(), ", operand ", operand_number, index_str);
65 }
66
operator <<(std::ostream & out,const HloUse & use)67 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
68 out << use.ToString();
69 return out;
70 }
71
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)72 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
73 const ShapeIndex& index, bool is_phi)
74 : BufferValue(instruction, index, id), is_phi_(is_phi) {
75 // The defining position is always the first element in the positions_ vector.
76 positions_.push_back(HloPosition{instruction, index});
77 }
78
ToShortString() const79 std::string HloValue::ToShortString() const {
80 return absl::StrFormat(
81 "<%d %s%s%s%s>", id(), instruction()->name(),
82 instruction()->shape().IsTuple() ? index().ToString() : "",
83 is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
84 }
85
ToString(int indent) const86 std::string HloValue::ToString(int indent) const {
87 std::string indentation(indent, ' ');
88 std::string out =
89 StrCat(indentation, ToShortString(), "\n", indentation, " positions:\n");
90 for (const HloPosition& position : positions()) {
91 StrAppend(&out, indentation, " ", position.ToString(), "\n");
92 }
93 StrAppend(&out, indentation, " uses:\n");
94 for (const HloUse& use : GetUses()) {
95 StrAppend(&out, indentation, " ", use.ToString(), "\n");
96 }
97 StrAppend(&out, indentation, " from instruction:", instruction()->ToString(),
98 "\n");
99 return out;
100 }
101
102 namespace {
103
104 // Returns true if the instruction 'user' may use the value at the given
105 // ShapeIndex in the given operand. Generally, instruction which pass through
106 // values transparently without reading the value are not considered to use the
107 // value.
MayUseOperandValue(int64_t operand_number,const ShapeIndex & index,const HloInstruction * user)108 bool MayUseOperandValue(int64_t operand_number, const ShapeIndex& index,
109 const HloInstruction* user) {
110 switch (user->opcode()) {
111 case HloOpcode::kGetTupleElement:
112 case HloOpcode::kCopy:
113 // These instructions only access the top-level values of their
114 // operand. Non-top-level (nested) values are passed through
115 // transparently.
116 CHECK_EQ(operand_number, 0);
117 return index.empty();
118 case HloOpcode::kDomain:
119 case HloOpcode::kTuple:
120 // These instructions always pass through their operands transparently.
121 return false;
122
123 case HloOpcode::kCall:
124 case HloOpcode::kWhile:
125 // Although call and while instructions pass through their operands, they
126 // are considered uses.
127 return true;
128
129 default:
130 return true;
131 }
132 }
133
134 } // namespace
135
SetPositions(absl::Span<const HloPosition> positions)136 void HloValue::SetPositions(absl::Span<const HloPosition> positions) {
137 CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
138
139 // The positions must be unique and should not contain the defining position
140 // as this is added at construction time.
141 for (const HloPosition& position_a : positions) {
142 DCHECK_NE(position_a, defining_position());
143 for (const HloPosition& position_b : positions) {
144 if (&position_a != &position_b) {
145 DCHECK_NE(position_a, position_b);
146 }
147 }
148 }
149
150 positions_.insert(positions_.end(), positions.begin(), positions.end());
151 // Update liveout status of this HloValue.
152 live_out_of_module_ |=
153 IsRootOf(defining_instruction()->GetModule()->entry_computation());
154 }
155
ComputeUses(std::vector<HloUse> & uses) const156 void HloValue::ComputeUses(std::vector<HloUse>& uses) const {
157 // Gather the computation roots at which this value appears.
158 absl::flat_hash_set<HloInstruction*> root_positions;
159 for (const HloPosition& position : positions_) {
160 if (position.instruction->IsRoot()) {
161 root_positions.insert(position.instruction);
162 }
163 }
164
165 // Build vector of HloUses for the value.
166 for (const HloPosition& position : positions_) {
167 for (HloInstruction* user : position.instruction->users()) {
168 for (int64_t i = 0; i < user->operand_count(); ++i) {
169 if (user->operand(i) != position.instruction) {
170 continue;
171 }
172
173 // Root instructions of computations are considered to be uses whether
174 // or not the root instruction itself actually uses the value.
175 if (MayUseOperandValue(i, position.index, user) ||
176 root_positions.contains(user)) {
177 HloUse new_use{user, i, position.index};
178
179 // The new use must not already exist in uses.
180 for (const HloUse& use : uses) {
181 DCHECK_NE(use, new_use);
182 }
183
184 uses.push_back(std::move(new_use));
185 }
186 }
187 }
188 }
189 }
190
IsRootOf(const HloComputation * computation) const191 bool HloValue::IsRootOf(const HloComputation* computation) const {
192 return absl::c_any_of(positions_, [&](const HloPosition& position) {
193 return position.instruction->IsRoot() &&
194 position.instruction->parent() == computation;
195 });
196 }
197
operator <<(std::ostream & out,const HloValue & value)198 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
199 out << value.ToShortString();
200 return out;
201 }
202
HloValueSet(absl::Span<const HloValue * const> values)203 HloValueSet::HloValueSet(absl::Span<const HloValue* const> values)
204 : values_(values.begin(), values.end()) {
205 SortAndUniquifyValues();
206 }
207
HloValueSet(const absl::flat_hash_set<const HloValue * > & values)208 HloValueSet::HloValueSet(const absl::flat_hash_set<const HloValue*>& values)
209 : values_(values.begin(), values.end()) {
210 // Values are already unique, so only need to sort.
211 absl::c_sort(values_, HloValue::IdLessThan);
212 }
213
SortAndUniquifyValues()214 void HloValueSet::SortAndUniquifyValues() {
215 absl::c_sort(values_, HloValue::IdLessThan);
216 values_.erase(std::unique(values_.begin(), values_.end()), values_.end());
217 }
218
ToString() const219 std::string HloValueSet::ToString() const {
220 return StrCat("HloValueSet: ",
221 absl::StrJoin(values_, ", ",
222 [](std::string* result, const HloValue* value) {
223 result->append(value->ToShortString());
224 }));
225 }
226
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)227 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
228 HloValueSet union_set;
229 for (const HloValueSet* input : inputs) {
230 for (const HloValue* value : input->values()) {
231 union_set.values_.push_back(value);
232 }
233 }
234 union_set.SortAndUniquifyValues();
235 if (*this != union_set) {
236 *this = union_set;
237 return true;
238 }
239 return false;
240 }
241
AddValue(const HloValue * value)242 bool HloValueSet::AddValue(const HloValue* value) {
243 auto it = std::lower_bound(values_.begin(), values_.end(), value,
244 HloValue::IdLessThan);
245 if (it == values_.end() || (*it)->id() != value->id()) {
246 values_.insert(it, value);
247 return true;
248 }
249 return false; // already exists
250 }
251
operator <<(std::ostream & out,const HloValueSet & value_set)252 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
253 out << value_set.ToString();
254 return out;
255 }
256
IsAmbiguous() const257 bool InstructionValueSet::IsAmbiguous() const {
258 bool ambiguous = false;
259 for (auto& iter : *this) {
260 ambiguous |= iter.second.values().size() > 1;
261 }
262 return ambiguous;
263 }
264
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)265 bool InstructionValueSet::AssignUnionOf(
266 absl::Span<const InstructionValueSet* const> inputs) {
267 CHECK_GT(inputs.size(), 0);
268 for (int i = 1; i < inputs.size(); ++i) {
269 DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
270 }
271 bool changed = false;
272 for (auto& pair : *this) {
273 const ShapeIndex& index = pair.first;
274 HloValueSet& value_set = pair.second;
275
276 std::vector<const HloValueSet*> input_value_sets;
277 for (const InstructionValueSet* input : inputs) {
278 input_value_sets.push_back(&input->element(index));
279 }
280 changed |= value_set.AssignUnionOf(input_value_sets);
281 }
282
283 return changed;
284 }
285
AssignUnionOf(const InstructionValueSet & input,ShapeIndexView input_index)286 bool InstructionValueSet::AssignUnionOf(const InstructionValueSet& input,
287 ShapeIndexView input_index) {
288 bool changed = false;
289 for (auto& [index, value_set] : *this) {
290 ShapeIndex source_index(input_index);
291 for (auto i : index) {
292 source_index.push_back(i);
293 }
294 changed |= value_set.AssignUnionOf({&input.element(source_index)});
295 }
296
297 return changed;
298 }
299
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)300 std::ostream& operator<<(std::ostream& out,
301 const InstructionValueSet& instruction_value_set) {
302 out << instruction_value_set.ToString();
303 return out;
304 }
305
ToString() const306 std::string InstructionValueSet::ToString() const {
307 std::string out =
308 StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
309 ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
310 StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
311 });
312 return out;
313 }
314
315 } // namespace xla
316