1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/compiler/xla/map_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace xla {
39
40 using absl::StrAppend;
41 using absl::StrCat;
42
shape() const43 const Shape& HloPosition::shape() const {
44 return ShapeUtil::GetSubshape(instruction->shape(), index);
45 }
46
ToString() const47 string HloPosition::ToString() const {
48 string index_str =
49 instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
50 return StrCat(instruction->name(), index_str);
51 }
52
operator <<(std::ostream & out,const HloPosition & position)53 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
54 out << position.ToString();
55 return out;
56 }
57
ToString() const58 string HloUse::ToString() const {
59 string index_str = instruction->operand(operand_number)->shape().IsTuple()
60 ? (" " + operand_index.ToString())
61 : "";
62 return StrCat(instruction->name(), ", operand ", operand_number, index_str);
63 }
64
operator <<(std::ostream & out,const HloUse & use)65 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
66 out << use.ToString();
67 return out;
68 }
69
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)70 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
71 const ShapeIndex& index, bool is_phi)
72 : BufferValue(instruction, index, id), is_phi_(is_phi) {
73 // The defining position is always the first element in the positions_ vector.
74 positions_.push_back(HloPosition{instruction, index});
75 }
76
operator ==(const HloValue & other) const77 bool HloValue::operator==(const HloValue& other) const {
78 bool equal = defining_instruction() == other.defining_instruction() &&
79 defining_index() == other.defining_index();
80 // If the values are equal they most both be phi (or non phi).
81 CHECK(!(equal && is_phi() != other.is_phi()));
82 return equal;
83 }
84
operator !=(const HloValue & other) const85 bool HloValue::operator!=(const HloValue& other) const {
86 return !(*this == other);
87 }
88
ToShortString() const89 string HloValue::ToShortString() const {
90 string index_str = defining_instruction()->shape().IsTuple()
91 ? defining_index().ToString()
92 : "";
93 return StrCat(id(), " ", is_phi_ ? "PHI " : "",
94 defining_instruction()->name(), index_str);
95 }
96
ToString(int indent) const97 string HloValue::ToString(int indent) const {
98 string indentation(indent, ' ');
99 string out = StrCat(indentation, ToShortString(), ", positions:\n");
100 for (const HloPosition& position : positions()) {
101 StrAppend(&out, indentation, " ", position.ToString(), "\n");
102 }
103 StrAppend(&out, indentation, " uses:\n");
104 for (const HloUse& use : uses()) {
105 StrAppend(&out, indentation, " ", use.ToString(), "\n");
106 }
107 return out;
108 }
109
110 namespace {
111
112 // Returns true if the instruction 'user' may use the value at the given
113 // ShapeIndex in the given operand. Generally, instruction which pass through
114 // values transparently without reading the value are not considered to use the
115 // value.
MayUseOperandValue(int64 operand_number,const ShapeIndex & index,const HloInstruction * user)116 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
117 const HloInstruction* user) {
118 switch (user->opcode()) {
119 case HloOpcode::kGetTupleElement:
120 case HloOpcode::kCopy:
121 // These instructions only access the top-level values of their
122 // operand. Non-top-level (nested) values are passed through
123 // transparently.
124 CHECK_EQ(operand_number, 0);
125 return index.empty();
126 case HloOpcode::kTupleSelect:
127 // Select does not use any nested elements of its selected-from operands
128 // (operand 1 and 2)
129 CHECK_GE(operand_number, 0);
130 CHECK_LE(operand_number, 2);
131 return operand_number == 0 || index.empty();
132
133 case HloOpcode::kDomain:
134 case HloOpcode::kTuple:
135 // These instructions always pass through their operands transparently.
136 return false;
137
138 case HloOpcode::kCall:
139 case HloOpcode::kWhile:
140 // Although call and while instructions pass through their operands, they
141 // are considered uses.
142 return true;
143
144 default:
145 return true;
146 }
147 }
148
149 } // namespace
150
SetPositionsAndComputeUses(absl::Span<const HloPosition> positions)151 void HloValue::SetPositionsAndComputeUses(
152 absl::Span<const HloPosition> positions) {
153 CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
154
155 // The positions must be unique and should not contain the defining position
156 // as this is added at construction time.
157 for (const HloPosition& position_a : positions) {
158 DCHECK_NE(position_a, defining_position());
159 for (const HloPosition& position_b : positions) {
160 if (&position_a != &position_b) {
161 DCHECK_NE(position_a, position_b);
162 }
163 }
164 }
165
166 positions_.insert(positions_.end(), positions.begin(), positions.end());
167
168 // Gather the computation roots at which this value appears.
169 absl::flat_hash_set<HloInstruction*> root_positions;
170 for (const HloPosition& position : positions_) {
171 if (position.instruction ==
172 position.instruction->parent()->root_instruction()) {
173 root_positions.insert(position.instruction);
174 }
175 }
176
177 // Build vector of HloUses for the value.
178 for (const HloPosition& position : positions_) {
179 for (HloInstruction* user : position.instruction->users()) {
180 for (int64 operand_number : user->OperandIndices(position.instruction)) {
181 // Root instructions of computations are considered to be uses whether
182 // or not the root instruction itself actually uses the value.
183 if (MayUseOperandValue(operand_number, position.index, user) ||
184 ContainsKey(root_positions, user)) {
185 HloUse new_use{user, operand_number, position.index};
186
187 // The new use must not already exist in uses_.
188 for (const HloUse& use : uses_) {
189 DCHECK_NE(use, new_use);
190 }
191
192 uses_.push_back(std::move(new_use));
193 }
194 }
195 }
196
197 // Update liveout status of this HloValue.
198 const HloModule& module = *position.instruction->parent()->parent();
199 if (position.instruction ==
200 module.entry_computation()->root_instruction()) {
201 live_out_of_module_ = true;
202 }
203 }
204 }
205
operator <<(std::ostream & out,const HloValue & value)206 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
207 out << value.ToShortString();
208 return out;
209 }
210
SortAndUniquifyValues()211 void HloValueSet::SortAndUniquifyValues() {
212 absl::c_sort(values_, HloValue::IdLessThan);
213 values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
214 values_.end());
215 }
216
ToString() const217 string HloValueSet::ToString() const {
218 return StrCat(
219 "HloValueSet: ",
220 absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
221 result->append(value->ToShortString());
222 }));
223 }
224
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)225 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
226 HloValueSet union_set;
227 for (const HloValueSet* input : inputs) {
228 for (const HloValue* value : input->values()) {
229 union_set.values_.push_back(value);
230 }
231 }
232 union_set.SortAndUniquifyValues();
233 if (*this != union_set) {
234 *this = union_set;
235 return true;
236 }
237 return false;
238 }
239
AddValue(const HloValue * value)240 bool HloValueSet::AddValue(const HloValue* value) {
241 auto it = std::lower_bound(values_.begin(), values_.end(), value,
242 HloValue::IdLessThan);
243 if (it == values_.end() || (*it)->id() != value->id()) {
244 values_.insert(it, value);
245 return true;
246 }
247 return false; // already exists
248 }
249
operator <<(std::ostream & out,const HloValueSet & value_set)250 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
251 out << value_set.ToString();
252 return out;
253 }
254
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)255 bool InstructionValueSet::AssignUnionOf(
256 absl::Span<const InstructionValueSet* const> inputs) {
257 CHECK_GT(inputs.size(), 0);
258 for (int i = 1; i < inputs.size(); ++i) {
259 DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
260 }
261 bool changed = false;
262 for (auto& pair : *this) {
263 const ShapeIndex& index = pair.first;
264 HloValueSet& value_set = pair.second;
265
266 std::vector<const HloValueSet*> input_value_sets;
267 for (const InstructionValueSet* input : inputs) {
268 input_value_sets.push_back(&input->element(index));
269 }
270 changed |= value_set.AssignUnionOf(input_value_sets);
271 }
272
273 return changed;
274 }
275
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)276 std::ostream& operator<<(std::ostream& out,
277 const InstructionValueSet& instruction_value_set) {
278 out << instruction_value_set.ToString();
279 return out;
280 }
281
ToString() const282 string InstructionValueSet::ToString() const {
283 string out =
284 StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
285 ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
286 StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
287 });
288 return out;
289 }
290
291 } // namespace xla
292