1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/algorithm/container.h" 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/strings/string_view.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/status.h" 29 30 namespace xla { 31 32 class HloModule; 33 34 // Class representing a sequence of HLO instructions such as the sequential 35 // execution order of an HLO computation. 36 class HloInstructionSequence { 37 public: 38 HloInstructionSequence() = default; HloInstructionSequence(absl::Span<HloInstruction * const> instructions)39 explicit HloInstructionSequence( 40 absl::Span<HloInstruction* const> instructions) { 41 for (HloInstruction* instruction : instructions) { 42 push_back(instruction); 43 } 44 } 45 46 // Adds the instruction to the end of the sequence. push_back(HloInstruction * instruction)47 void push_back(HloInstruction* instruction) { 48 instruction_sequence_.push_back(instruction); 49 id_sequence_.push_back(instruction->unique_id()); 50 } 51 52 // Removes the instruction from the sequence. remove_instruction(HloInstruction * instruction)53 void remove_instruction(HloInstruction* instruction) { 54 auto instruction_it = std::find(instruction_sequence_.begin(), 55 instruction_sequence_.end(), instruction); 56 if (instruction_it != instruction_sequence_.end()) { 57 auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(), 58 instruction->unique_id()); 59 instruction_sequence_.erase(instruction_it); 60 id_sequence_.erase(id_it); 61 } 62 } 63 64 // Replaces the old instruction with the new instruction in the sequence. replace_instruction(HloInstruction * old_instruction,HloInstruction * new_instruction)65 void replace_instruction(HloInstruction* old_instruction, 66 HloInstruction* new_instruction) { 67 auto instruction_it = 68 std::find(instruction_sequence_.begin(), instruction_sequence_.end(), 69 old_instruction); 70 auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(), 71 old_instruction->unique_id()); 72 CHECK(instruction_it != instruction_sequence_.end()) 73 << "Do not find instruction id " << old_instruction->unique_id(); 74 CHECK(id_it != id_sequence_.end()); 75 *instruction_it = new_instruction; 76 *id_it = new_instruction->unique_id(); 77 } 78 79 // Clears the sequence of all instructions. clear()80 void clear() { 81 instruction_sequence_.clear(); 82 id_sequence_.clear(); 83 } 84 size()85 int64_t size() const { return instruction_sequence_.size(); } 86 87 // Returns the sequence of HLO instructions. instructions()88 const std::vector<HloInstruction*>& instructions() const { 89 return instruction_sequence_; 90 } 91 92 // Returns the unique IDs of the instructions in the sequence (in order). ids()93 const std::vector<int>& ids() const { return id_sequence_; } 94 95 private: 96 // The sequence as HloInstructions. 97 std::vector<HloInstruction*> instruction_sequence_; 98 99 // The sequence of HLO instructions, represented by their unique IDs. The 100 // sequence is stored as both HloInstructions and unique IDs because the 101 // sequence may be referenced after transformations to the HLO graph and HLO 102 // pointers can be invalidated or recycled in this process (see 103 // HloSchedule::Update). 104 std::vector<int> id_sequence_; 105 }; 106 107 // A class representing a sequential schedule of instructions for an HLO 108 // module. A complete HLO schedule contains an instruction sequence for every 109 // non-fusion computation in the HLO module. 110 class HloSchedule { 111 public: HloSchedule(const HloModule * module)112 explicit HloSchedule(const HloModule* module) : module_(module) {} 113 114 // (De)Serialize an HloSchedule to/from a HloScheduleProto. 115 static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, 116 const HloScheduleProto& proto); 117 StatusOr<HloScheduleProto> ToProto() const; 118 119 // Returns a reference to the sequence for the given computation. 120 const HloInstructionSequence& sequence( 121 const HloComputation* computation) const; 122 123 // Returns the sequence for the given computation. An empty sequence is 124 // created if none exists for the computation. 125 HloInstructionSequence& GetOrCreateSequence( 126 const HloComputation* computation); 127 128 // Sets the sequence for the given computation to the given sequence. 129 void set_sequence(const HloComputation* computation, 130 absl::Span<HloInstruction* const> sequence); 131 void set_sequence(const HloComputation* computation, 132 HloInstructionSequence sequence); 133 134 // Returns a map from HloComputation unique ID to instruction sequence. The 135 // map contains all sequences in the schedule. sequences()136 const absl::flat_hash_map<int64_t, HloInstructionSequence>& sequences() 137 const { 138 return sequences_; 139 } 140 141 absl::flat_hash_map<std::string, int64_t> num_sequences_by_execution_thread() 142 const; 143 144 // Returns true if the schedule has a sequence for the given computation. is_computation_scheduled(const HloComputation * computation)145 bool is_computation_scheduled(const HloComputation* computation) const { 146 return sequences_.contains(computation->unique_id()); 147 } 148 149 // Removes the computation from the sequences. remove_computation(const HloComputation * computation)150 void remove_computation(const HloComputation* computation) { 151 auto it = sequences_.find(computation->unique_id()); 152 CHECK(it != sequences_.end()); 153 sequences_.erase(it); 154 execution_threads_.erase(computation->unique_id()); 155 } 156 157 // Removes the instruction from the computation's sequence. remove_instruction(const HloComputation * computation,HloInstruction * instruction)158 void remove_instruction(const HloComputation* computation, 159 HloInstruction* instruction) { 160 sequences_[computation->unique_id()].remove_instruction(instruction); 161 } 162 163 // Replaces the old instruction with the new instruction in the computation's 164 // sequence. replace_instruction(const HloComputation * computation,HloInstruction * old_instruction,HloInstruction * new_instruction)165 void replace_instruction(const HloComputation* computation, 166 HloInstruction* old_instruction, 167 HloInstruction* new_instruction) { 168 sequences_[computation->unique_id()].replace_instruction(old_instruction, 169 new_instruction); 170 } 171 172 // Updates the schedule for specified threads such that it is (again) a valid 173 // schedule for the module. This is used to update a schedule after the HLO 174 // module has been transformed in some way. In general, the only 175 // transformations to the module for which a schedule can be updated is the 176 // addition or removal of instructions and removal of computations. Updating 177 // the schedule after new dependencies between existing instructions in the 178 // module is not supported and may result in an error status returned. 179 // 180 // Instructions in the module which also exist in the given schedule will 181 // remain in the same order in the updated schedule. Instructions which exist 182 // in the module but not in the given schedule will be placed as early as 183 // possible in the updated schedule. 184 Status Update( 185 const absl::flat_hash_set<absl::string_view>& execution_threads = {}); 186 187 // Verifies that the given schedule is valid for the given module. 188 // Specifically, the schedule contains exactly the instructions in the 189 // non-fusion computations in the module and every dependency in the module is 190 // satisfied in the schedule. 191 Status Verify() const; 192 193 std::string ToString() const; 194 empty()195 bool empty() const { return sequences_.empty(); } 196 module()197 const HloModule* module() const { return module_; } 198 199 private: 200 // Updates the instruction sequence for the given computation. 201 Status UpdateComputationSchedule(const HloComputation* computation); 202 203 const HloModule* module_; 204 205 // A map from computation unique ID to instruction sequence. Unique IDs are 206 // used rather than HloComputation pointers because HLO pointers are not 207 // unique across HLO transformations because pointers may be recycled. 208 absl::flat_hash_map<int64_t, HloInstructionSequence> sequences_; 209 210 // A corresponding map of `sequences_`, mapping the computation unique ID 211 // included in the shedule to execution threads. We need to store this since 212 // sometimes, computation could be removed while we still need the execution 213 // thread info for the remaining sequences. 214 absl::flat_hash_map<int64_t, std::string> execution_threads_; 215 }; 216 217 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); 218 219 } // namespace xla 220 221 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 222