1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/strings/str_join.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace xla {
40
41 using absl::StrAppend;
42 using absl::StrCat;
43
shape() const44 const Shape& HloPosition::shape() const {
45 return ShapeUtil::GetSubshape(instruction->shape(), index);
46 }
47
ToString() const48 string HloPosition::ToString() const {
49 string index_str =
50 instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
51 return StrCat(instruction->name(), index_str);
52 }
53
operator <<(std::ostream & out,const HloPosition & position)54 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
55 out << position.ToString();
56 return out;
57 }
58
ToString() const59 string HloUse::ToString() const {
60 string index_str = instruction->operand(operand_number)->shape().IsTuple()
61 ? (" " + operand_index.ToString())
62 : "";
63 return StrCat(instruction->name(), ", operand ", operand_number, index_str);
64 }
65
operator <<(std::ostream & out,const HloUse & use)66 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
67 out << use.ToString();
68 return out;
69 }
70
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)71 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
72 const ShapeIndex& index, bool is_phi)
73 : BufferValue(instruction, index, id), is_phi_(is_phi) {
74 // The defining position is always the first element in the positions_ vector.
75 positions_.push_back(HloPosition{instruction, index});
76 }
77
operator ==(const HloValue & other) const78 bool HloValue::operator==(const HloValue& other) const {
79 bool equal = defining_instruction() == other.defining_instruction() &&
80 defining_index() == other.defining_index();
81 // If the values are equal they must both be phi (or non phi).
82 CHECK(!(equal && is_phi() != other.is_phi()));
83 return equal;
84 }
85
operator !=(const HloValue & other) const86 bool HloValue::operator!=(const HloValue& other) const {
87 return !(*this == other);
88 }
89
ToShortString() const90 string HloValue::ToShortString() const {
91 return absl::StrFormat(
92 "<%d %s%s%s%s>", id(), instruction()->name(),
93 instruction()->shape().IsTuple() ? index().ToString() : "",
94 is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
95 }
96
ToString(int indent) const97 string HloValue::ToString(int indent) const {
98 string indentation(indent, ' ');
99 string out =
100 StrCat(indentation, ToShortString(), "\n", indentation, " positions:\n");
101 for (const HloPosition& position : positions()) {
102 StrAppend(&out, indentation, " ", position.ToString(), "\n");
103 }
104 StrAppend(&out, indentation, " uses:\n");
105 for (const HloUse& use : uses()) {
106 StrAppend(&out, indentation, " ", use.ToString(), "\n");
107 }
108 return out;
109 }
110
111 namespace {
112
113 // Returns true if the instruction 'user' may use the value at the given
114 // ShapeIndex in the given operand. Generally, instruction which pass through
115 // values transparently without reading the value are not considered to use the
116 // value.
MayUseOperandValue(int64 operand_number,const ShapeIndex & index,const HloInstruction * user)117 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
118 const HloInstruction* user) {
119 switch (user->opcode()) {
120 case HloOpcode::kGetTupleElement:
121 case HloOpcode::kCopy:
122 // These instructions only access the top-level values of their
123 // operand. Non-top-level (nested) values are passed through
124 // transparently.
125 CHECK_EQ(operand_number, 0);
126 return index.empty();
127 case HloOpcode::kTupleSelect:
128 // Select does not use any nested elements of its selected-from operands
129 // (operand 1 and 2)
130 CHECK_GE(operand_number, 0);
131 CHECK_LE(operand_number, 2);
132 return operand_number == 0 || index.empty();
133
134 case HloOpcode::kDomain:
135 case HloOpcode::kTuple:
136 // These instructions always pass through their operands transparently.
137 return false;
138
139 case HloOpcode::kCall:
140 case HloOpcode::kWhile:
141 // Although call and while instructions pass through their operands, they
142 // are considered uses.
143 return true;
144
145 default:
146 return true;
147 }
148 }
149
150 } // namespace
151
SetPositionsAndComputeUses(absl::Span<const HloPosition> positions)152 void HloValue::SetPositionsAndComputeUses(
153 absl::Span<const HloPosition> positions) {
154 CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
155
156 // The positions must be unique and should not contain the defining position
157 // as this is added at construction time.
158 for (const HloPosition& position_a : positions) {
159 DCHECK_NE(position_a, defining_position());
160 for (const HloPosition& position_b : positions) {
161 if (&position_a != &position_b) {
162 DCHECK_NE(position_a, position_b);
163 }
164 }
165 }
166
167 positions_.insert(positions_.end(), positions.begin(), positions.end());
168
169 // Gather the computation roots at which this value appears.
170 absl::flat_hash_set<HloInstruction*> root_positions;
171 for (const HloPosition& position : positions_) {
172 if (position.instruction ==
173 position.instruction->parent()->root_instruction()) {
174 root_positions.insert(position.instruction);
175 }
176 }
177
178 // Build vector of HloUses for the value.
179 for (const HloPosition& position : positions_) {
180 for (HloInstruction* user : position.instruction->users()) {
181 for (int64 i = 0; i < user->operand_count(); ++i) {
182 if (user->operand(i) != position.instruction) {
183 continue;
184 }
185
186 // Root instructions of computations are considered to be uses whether
187 // or not the root instruction itself actually uses the value.
188 if (MayUseOperandValue(i, position.index, user) ||
189 ContainsKey(root_positions, user)) {
190 HloUse new_use{user, i, position.index};
191
192 // The new use must not already exist in uses_.
193 for (const HloUse& use : uses_) {
194 DCHECK_NE(use, new_use);
195 }
196
197 uses_.push_back(std::move(new_use));
198 }
199 }
200 }
201
202 // Update liveout status of this HloValue.
203 const HloModule& module = *position.instruction->parent()->parent();
204 if (position.instruction ==
205 module.entry_computation()->root_instruction()) {
206 live_out_of_module_ = true;
207 }
208 }
209 }
210
operator <<(std::ostream & out,const HloValue & value)211 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
212 out << value.ToShortString();
213 return out;
214 }
215
SortAndUniquifyValues()216 void HloValueSet::SortAndUniquifyValues() {
217 absl::c_sort(values_, HloValue::IdLessThan);
218 values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
219 values_.end());
220 }
221
ToString() const222 string HloValueSet::ToString() const {
223 return StrCat(
224 "HloValueSet: ",
225 absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) {
226 result->append(value->ToShortString());
227 }));
228 }
229
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)230 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
231 HloValueSet union_set;
232 for (const HloValueSet* input : inputs) {
233 for (const HloValue* value : input->values()) {
234 union_set.values_.push_back(value);
235 }
236 }
237 union_set.SortAndUniquifyValues();
238 if (*this != union_set) {
239 *this = union_set;
240 return true;
241 }
242 return false;
243 }
244
AddValue(const HloValue * value)245 bool HloValueSet::AddValue(const HloValue* value) {
246 auto it = std::lower_bound(values_.begin(), values_.end(), value,
247 HloValue::IdLessThan);
248 if (it == values_.end() || (*it)->id() != value->id()) {
249 values_.insert(it, value);
250 return true;
251 }
252 return false; // already exists
253 }
254
operator <<(std::ostream & out,const HloValueSet & value_set)255 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
256 out << value_set.ToString();
257 return out;
258 }
259
IsAmbiguous() const260 bool InstructionValueSet::IsAmbiguous() const {
261 bool ambiguous = false;
262 for (auto& iter : *this) {
263 ambiguous |= iter.second.values().size() > 1;
264 }
265 return ambiguous;
266 }
267
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)268 bool InstructionValueSet::AssignUnionOf(
269 absl::Span<const InstructionValueSet* const> inputs) {
270 CHECK_GT(inputs.size(), 0);
271 for (int i = 1; i < inputs.size(); ++i) {
272 DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
273 }
274 bool changed = false;
275 for (auto& pair : *this) {
276 const ShapeIndex& index = pair.first;
277 HloValueSet& value_set = pair.second;
278
279 std::vector<const HloValueSet*> input_value_sets;
280 for (const InstructionValueSet* input : inputs) {
281 input_value_sets.push_back(&input->element(index));
282 }
283 changed |= value_set.AssignUnionOf(input_value_sets);
284 }
285
286 return changed;
287 }
288
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)289 std::ostream& operator<<(std::ostream& out,
290 const InstructionValueSet& instruction_value_set) {
291 out << instruction_value_set.ToString();
292 return out;
293 }
294
ToString() const295 string InstructionValueSet::ToString() const {
296 string out =
297 StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
298 ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
299 StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
300 });
301 return out;
302 }
303
304 } // namespace xla
305