/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
18 
#include <algorithm>
#include <cstdint>
#include <ostream>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"

30 namespace xla {
31 
32 class HloModule;
33 
34 // Class representing a sequence of HLO instructions such as the sequential
35 // execution order of an HLO computation.
36 class HloInstructionSequence {
37  public:
38   HloInstructionSequence() = default;
HloInstructionSequence(absl::Span<HloInstruction * const> instructions)39   explicit HloInstructionSequence(
40       absl::Span<HloInstruction* const> instructions) {
41     for (HloInstruction* instruction : instructions) {
42       push_back(instruction);
43     }
44   }
45 
46   // Adds the instruction to the end of the sequence.
push_back(HloInstruction * instruction)47   void push_back(HloInstruction* instruction) {
48     instruction_sequence_.push_back(instruction);
49     id_sequence_.push_back(instruction->unique_id());
50   }
51 
52   // Removes the instruction from the sequence.
remove_instruction(HloInstruction * instruction)53   void remove_instruction(HloInstruction* instruction) {
54     auto instruction_it = std::find(instruction_sequence_.begin(),
55                                     instruction_sequence_.end(), instruction);
56     if (instruction_it != instruction_sequence_.end()) {
57       auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(),
58                              instruction->unique_id());
59       instruction_sequence_.erase(instruction_it);
60       id_sequence_.erase(id_it);
61     }
62   }
63 
64   // Replaces the old instruction with the new instruction in the sequence.
replace_instruction(HloInstruction * old_instruction,HloInstruction * new_instruction)65   void replace_instruction(HloInstruction* old_instruction,
66                            HloInstruction* new_instruction) {
67     auto instruction_it =
68         std::find(instruction_sequence_.begin(), instruction_sequence_.end(),
69                   old_instruction);
70     auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(),
71                            old_instruction->unique_id());
72     CHECK(instruction_it != instruction_sequence_.end())
73         << "Do not find instruction id " << old_instruction->unique_id();
74     CHECK(id_it != id_sequence_.end());
75     *instruction_it = new_instruction;
76     *id_it = new_instruction->unique_id();
77   }
78 
79   // Clears the sequence of all instructions.
clear()80   void clear() {
81     instruction_sequence_.clear();
82     id_sequence_.clear();
83   }
84 
size()85   int64_t size() const { return instruction_sequence_.size(); }
86 
87   // Returns the sequence of HLO instructions.
instructions()88   const std::vector<HloInstruction*>& instructions() const {
89     return instruction_sequence_;
90   }
91 
92   // Returns the unique IDs of the instructions in the sequence (in order).
ids()93   const std::vector<int>& ids() const { return id_sequence_; }
94 
95  private:
96   // The sequence as HloInstructions.
97   std::vector<HloInstruction*> instruction_sequence_;
98 
99   // The sequence of HLO instructions, represented by their unique IDs. The
100   // sequence is stored as both HloInstructions and unique IDs because the
101   // sequence may be referenced after transformations to the HLO graph and HLO
102   // pointers can be invalidated or recycled in this process (see
103   // HloSchedule::Update).
104   std::vector<int> id_sequence_;
105 };
106 
107 // A class representing a sequential schedule of instructions for an HLO
108 // module. A complete HLO schedule contains an instruction sequence for every
109 // non-fusion computation in the HLO module.
110 class HloSchedule {
111  public:
HloSchedule(const HloModule * module)112   explicit HloSchedule(const HloModule* module) : module_(module) {}
113 
114   // (De)Serialize an HloSchedule to/from a HloScheduleProto.
115   static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
116                                                const HloScheduleProto& proto);
117   StatusOr<HloScheduleProto> ToProto() const;
118 
119   // Returns a reference to the sequence for the given computation.
120   const HloInstructionSequence& sequence(
121       const HloComputation* computation) const;
122 
123   // Returns the sequence for the given computation. An empty sequence is
124   // created if none exists for the computation.
125   HloInstructionSequence& GetOrCreateSequence(
126       const HloComputation* computation);
127 
128   // Sets the sequence for the given computation to the given sequence.
129   void set_sequence(const HloComputation* computation,
130                     absl::Span<HloInstruction* const> sequence);
131   void set_sequence(const HloComputation* computation,
132                     HloInstructionSequence sequence);
133 
134   // Returns a map from HloComputation unique ID to instruction sequence. The
135   // map contains all sequences in the schedule.
sequences()136   const absl::flat_hash_map<int64_t, HloInstructionSequence>& sequences()
137       const {
138     return sequences_;
139   }
140 
141   absl::flat_hash_map<std::string, int64_t> num_sequences_by_execution_thread()
142       const;
143 
144   // Returns true if the schedule has a sequence for the given computation.
is_computation_scheduled(const HloComputation * computation)145   bool is_computation_scheduled(const HloComputation* computation) const {
146     return sequences_.contains(computation->unique_id());
147   }
148 
149   // Removes the computation from the sequences.
remove_computation(const HloComputation * computation)150   void remove_computation(const HloComputation* computation) {
151     auto it = sequences_.find(computation->unique_id());
152     CHECK(it != sequences_.end());
153     sequences_.erase(it);
154     execution_threads_.erase(computation->unique_id());
155   }
156 
157   // Removes the instruction from the computation's sequence.
remove_instruction(const HloComputation * computation,HloInstruction * instruction)158   void remove_instruction(const HloComputation* computation,
159                           HloInstruction* instruction) {
160     sequences_[computation->unique_id()].remove_instruction(instruction);
161   }
162 
163   // Replaces the old instruction with the new instruction in the computation's
164   // sequence.
replace_instruction(const HloComputation * computation,HloInstruction * old_instruction,HloInstruction * new_instruction)165   void replace_instruction(const HloComputation* computation,
166                            HloInstruction* old_instruction,
167                            HloInstruction* new_instruction) {
168     sequences_[computation->unique_id()].replace_instruction(old_instruction,
169                                                              new_instruction);
170   }
171 
172   // Updates the schedule for specified threads such that it is (again) a valid
173   // schedule for the module. This is used to update a schedule after the HLO
174   // module has been transformed in some way. In general, the only
175   // transformations to the module for which a schedule can be updated is the
176   // addition or removal of instructions and removal of computations. Updating
177   // the schedule after new dependencies between existing instructions in the
178   // module is not supported and may result in an error status returned.
179   //
180   // Instructions in the module which also exist in the given schedule will
181   // remain in the same order in the updated schedule. Instructions which exist
182   // in the module but not in the given schedule will be placed as early as
183   // possible in the updated schedule.
184   Status Update(
185       const absl::flat_hash_set<absl::string_view>& execution_threads = {});
186 
187   // Verifies that the given schedule is valid for the given module.
188   // Specifically, the schedule contains exactly the instructions in the
189   // non-fusion computations in the module and every dependency in the module is
190   // satisfied in the schedule.
191   Status Verify() const;
192 
193   std::string ToString() const;
194 
empty()195   bool empty() const { return sequences_.empty(); }
196 
module()197   const HloModule* module() const { return module_; }
198 
199  private:
200   // Updates the instruction sequence for the given computation.
201   Status UpdateComputationSchedule(const HloComputation* computation);
202 
203   const HloModule* module_;
204 
205   // A map from computation unique ID to instruction sequence. Unique IDs are
206   // used rather than HloComputation pointers because HLO pointers are not
207   // unique across HLO transformations because pointers may be recycled.
208   absl::flat_hash_map<int64_t, HloInstructionSequence> sequences_;
209 
210   // A corresponding map of `sequences_`, mapping the computation unique ID
211   // included in the shedule to execution threads. We need to store this since
212   // sometimes, computation could be removed while we still need the execution
213   // thread info for the remaining sequences.
214   absl::flat_hash_map<int64_t, std::string> execution_threads_;
215 };
216 
217 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
218 
219 }  // namespace xla
220 
221 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
222