• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_value.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_module.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using absl::StrAppend;
42 using absl::StrCat;
43 
shape() const44 const Shape& HloPosition::shape() const {
45   return ShapeUtil::GetSubshape(instruction->shape(), index);
46 }
47 
ToString() const48 std::string HloPosition::ToString() const {
49   std::string index_str =
50       instruction->shape().IsTuple() ? (" " + index.ToString()) : "";
51   return StrCat(instruction->name(), index_str);
52 }
53 
operator <<(std::ostream & out,const HloPosition & position)54 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
55   out << position.ToString();
56   return out;
57 }
58 
ToString() const59 std::string HloUse::ToString() const {
60   std::string index_str =
61       instruction->operand(operand_number)->shape().IsTuple()
62           ? (" " + operand_index.ToString())
63           : "";
64   return StrCat(instruction->name(), ", operand ", operand_number, index_str);
65 }
66 
operator <<(std::ostream & out,const HloUse & use)67 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
68   out << use.ToString();
69   return out;
70 }
71 
HloValue(HloValue::Id id,HloInstruction * instruction,const ShapeIndex & index,bool is_phi)72 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
73                    const ShapeIndex& index, bool is_phi)
74     : BufferValue(instruction, index, id), is_phi_(is_phi) {
75   // The defining position is always the first element in the positions_ vector.
76   positions_.push_back(HloPosition{instruction, index});
77 }
78 
ToShortString() const79 std::string HloValue::ToShortString() const {
80   return absl::StrFormat(
81       "<%d %s%s%s%s>", id(), instruction()->name(),
82       instruction()->shape().IsTuple() ? index().ToString() : "",
83       is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : "");
84 }
85 
ToString(int indent) const86 std::string HloValue::ToString(int indent) const {
87   std::string indentation(indent, ' ');
88   std::string out =
89       StrCat(indentation, ToShortString(), "\n", indentation, " positions:\n");
90   for (const HloPosition& position : positions()) {
91     StrAppend(&out, indentation, "  ", position.ToString(), "\n");
92   }
93   StrAppend(&out, indentation, " uses:\n");
94   for (const HloUse& use : GetUses()) {
95     StrAppend(&out, indentation, "  ", use.ToString(), "\n");
96   }
97   StrAppend(&out, indentation, " from instruction:", instruction()->ToString(),
98             "\n");
99   return out;
100 }
101 
102 namespace {
103 
104 // Returns true if the instruction 'user' may use the value at the given
105 // ShapeIndex in the given operand. Generally, instruction which pass through
106 // values transparently without reading the value are not considered to use the
107 // value.
MayUseOperandValue(int64_t operand_number,const ShapeIndex & index,const HloInstruction * user)108 bool MayUseOperandValue(int64_t operand_number, const ShapeIndex& index,
109                         const HloInstruction* user) {
110   switch (user->opcode()) {
111     case HloOpcode::kGetTupleElement:
112     case HloOpcode::kCopy:
113       // These instructions only access the top-level values of their
114       // operand. Non-top-level (nested) values are passed through
115       // transparently.
116       CHECK_EQ(operand_number, 0);
117       return index.empty();
118     case HloOpcode::kDomain:
119     case HloOpcode::kTuple:
120       // These instructions always pass through their operands transparently.
121       return false;
122 
123     case HloOpcode::kCall:
124     case HloOpcode::kWhile:
125       // Although call and while instructions pass through their operands, they
126       // are considered uses.
127       return true;
128 
129     default:
130       return true;
131   }
132 }
133 
134 }  // namespace
135 
SetPositions(absl::Span<const HloPosition> positions)136 void HloValue::SetPositions(absl::Span<const HloPosition> positions) {
137   CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
138 
139   // The positions must be unique and should not contain the defining position
140   // as this is added at construction time.
141   for (const HloPosition& position_a : positions) {
142     DCHECK_NE(position_a, defining_position());
143     for (const HloPosition& position_b : positions) {
144       if (&position_a != &position_b) {
145         DCHECK_NE(position_a, position_b);
146       }
147     }
148   }
149 
150   positions_.insert(positions_.end(), positions.begin(), positions.end());
151   // Update liveout status of this HloValue.
152   live_out_of_module_ |=
153       IsRootOf(defining_instruction()->GetModule()->entry_computation());
154 }
155 
ComputeUses(std::vector<HloUse> & uses) const156 void HloValue::ComputeUses(std::vector<HloUse>& uses) const {
157   // Gather the computation roots at which this value appears.
158   absl::flat_hash_set<HloInstruction*> root_positions;
159   for (const HloPosition& position : positions_) {
160     if (position.instruction->IsRoot()) {
161       root_positions.insert(position.instruction);
162     }
163   }
164 
165   // Build vector of HloUses for the value.
166   for (const HloPosition& position : positions_) {
167     for (HloInstruction* user : position.instruction->users()) {
168       for (int64_t i = 0; i < user->operand_count(); ++i) {
169         if (user->operand(i) != position.instruction) {
170           continue;
171         }
172 
173         // Root instructions of computations are considered to be uses whether
174         // or not the root instruction itself actually uses the value.
175         if (MayUseOperandValue(i, position.index, user) ||
176             root_positions.contains(user)) {
177           HloUse new_use{user, i, position.index};
178 
179           // The new use must not already exist in uses.
180           for (const HloUse& use : uses) {
181             DCHECK_NE(use, new_use);
182           }
183 
184           uses.push_back(std::move(new_use));
185         }
186       }
187     }
188   }
189 }
190 
IsRootOf(const HloComputation * computation) const191 bool HloValue::IsRootOf(const HloComputation* computation) const {
192   return absl::c_any_of(positions_, [&](const HloPosition& position) {
193     return position.instruction->IsRoot() &&
194            position.instruction->parent() == computation;
195   });
196 }
197 
operator <<(std::ostream & out,const HloValue & value)198 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
199   out << value.ToShortString();
200   return out;
201 }
202 
HloValueSet(absl::Span<const HloValue * const> values)203 HloValueSet::HloValueSet(absl::Span<const HloValue* const> values)
204     : values_(values.begin(), values.end()) {
205   SortAndUniquifyValues();
206 }
207 
HloValueSet(const absl::flat_hash_set<const HloValue * > & values)208 HloValueSet::HloValueSet(const absl::flat_hash_set<const HloValue*>& values)
209     : values_(values.begin(), values.end()) {
210   // Values are already unique, so only need to sort.
211   absl::c_sort(values_, HloValue::IdLessThan);
212 }
213 
SortAndUniquifyValues()214 void HloValueSet::SortAndUniquifyValues() {
215   absl::c_sort(values_, HloValue::IdLessThan);
216   values_.erase(std::unique(values_.begin(), values_.end()), values_.end());
217 }
218 
ToString() const219 std::string HloValueSet::ToString() const {
220   return StrCat("HloValueSet: ",
221                 absl::StrJoin(values_, ", ",
222                               [](std::string* result, const HloValue* value) {
223                                 result->append(value->ToShortString());
224                               }));
225 }
226 
AssignUnionOf(absl::Span<const HloValueSet * const> inputs)227 bool HloValueSet::AssignUnionOf(absl::Span<const HloValueSet* const> inputs) {
228   HloValueSet union_set;
229   for (const HloValueSet* input : inputs) {
230     for (const HloValue* value : input->values()) {
231       union_set.values_.push_back(value);
232     }
233   }
234   union_set.SortAndUniquifyValues();
235   if (*this != union_set) {
236     *this = union_set;
237     return true;
238   }
239   return false;
240 }
241 
AddValue(const HloValue * value)242 bool HloValueSet::AddValue(const HloValue* value) {
243   auto it = std::lower_bound(values_.begin(), values_.end(), value,
244                              HloValue::IdLessThan);
245   if (it == values_.end() || (*it)->id() != value->id()) {
246     values_.insert(it, value);
247     return true;
248   }
249   return false;  // already exists
250 }
251 
operator <<(std::ostream & out,const HloValueSet & value_set)252 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
253   out << value_set.ToString();
254   return out;
255 }
256 
IsAmbiguous() const257 bool InstructionValueSet::IsAmbiguous() const {
258   bool ambiguous = false;
259   for (auto& iter : *this) {
260     ambiguous |= iter.second.values().size() > 1;
261   }
262   return ambiguous;
263 }
264 
AssignUnionOf(absl::Span<const InstructionValueSet * const> inputs)265 bool InstructionValueSet::AssignUnionOf(
266     absl::Span<const InstructionValueSet* const> inputs) {
267   CHECK_GT(inputs.size(), 0);
268   for (int i = 1; i < inputs.size(); ++i) {
269     DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
270   }
271   bool changed = false;
272   for (auto& pair : *this) {
273     const ShapeIndex& index = pair.first;
274     HloValueSet& value_set = pair.second;
275 
276     std::vector<const HloValueSet*> input_value_sets;
277     for (const InstructionValueSet* input : inputs) {
278       input_value_sets.push_back(&input->element(index));
279     }
280     changed |= value_set.AssignUnionOf(input_value_sets);
281   }
282 
283   return changed;
284 }
285 
AssignUnionOf(const InstructionValueSet & input,ShapeIndexView input_index)286 bool InstructionValueSet::AssignUnionOf(const InstructionValueSet& input,
287                                         ShapeIndexView input_index) {
288   bool changed = false;
289   for (auto& [index, value_set] : *this) {
290     ShapeIndex source_index(input_index);
291     for (auto i : index) {
292       source_index.push_back(i);
293     }
294     changed |= value_set.AssignUnionOf({&input.element(source_index)});
295   }
296 
297   return changed;
298 }
299 
operator <<(std::ostream & out,const InstructionValueSet & instruction_value_set)300 std::ostream& operator<<(std::ostream& out,
301                          const InstructionValueSet& instruction_value_set) {
302   out << instruction_value_set.ToString();
303   return out;
304 }
305 
ToString() const306 std::string InstructionValueSet::ToString() const {
307   std::string out =
308       StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
309   ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
310     StrAppend(&out, "  ", index.ToString(), " : ", value_set.ToString(), "\n");
311   });
312   return out;
313 }
314 
315 }  // namespace xla
316