1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 18 19 #include <vector> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/status.h" 26 27 namespace xla { 28 29 class HloModule; 30 31 // Class representing a sequence of HLO instructions such as the sequential 32 // execution order of an HLO computation. 33 class HloInstructionSequence { 34 public: 35 HloInstructionSequence() = default; HloInstructionSequence(absl::Span<HloInstruction * const> instructions)36 explicit HloInstructionSequence( 37 absl::Span<HloInstruction* const> instructions) { 38 for (HloInstruction* instruction : instructions) { 39 push_back(instruction); 40 } 41 } 42 43 // Adds the instruction to the end of the sequence. push_back(HloInstruction * instruction)44 void push_back(HloInstruction* instruction) { 45 instruction_sequence_.push_back(instruction); 46 id_sequence_.push_back(instruction->unique_id()); 47 } 48 49 // Removes the instruction from the sequence. remove_instruction(HloInstruction * instruction)50 void remove_instruction(HloInstruction* instruction) { 51 auto instruction_it = std::find(instruction_sequence_.begin(), 52 instruction_sequence_.end(), instruction); 53 if (instruction_it != instruction_sequence_.end()) { 54 auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(), 55 instruction->unique_id()); 56 instruction_sequence_.erase(instruction_it); 57 id_sequence_.erase(id_it); 58 } 59 } 60 61 // Replaces the old instruction with the new instruction in the sequence. replace_instruction(HloInstruction * old_instruction,HloInstruction * new_instruction)62 void replace_instruction(HloInstruction* old_instruction, 63 HloInstruction* new_instruction) { 64 auto instruction_it = 65 std::find(instruction_sequence_.begin(), instruction_sequence_.end(), 66 old_instruction); 67 auto id_it = std::find(id_sequence_.begin(), id_sequence_.end(), 68 old_instruction->unique_id()); 69 CHECK(instruction_it != instruction_sequence_.end()) 70 << "Do not find instruction id " << old_instruction->unique_id(); 71 CHECK(id_it != id_sequence_.end()); 72 *instruction_it = new_instruction; 73 *id_it = new_instruction->unique_id(); 74 } 75 76 // Clears the sequence of all instructions. clear()77 void clear() { 78 instruction_sequence_.clear(); 79 id_sequence_.clear(); 80 } 81 size()82 int64 size() const { return instruction_sequence_.size(); } 83 84 // Returns the sequence of HLO instructions. instructions()85 const std::vector<HloInstruction*>& instructions() const { 86 return instruction_sequence_; 87 } 88 89 // Returns the unique IDs of the instructions in the sequence (in order). ids()90 const std::vector<int>& ids() const { return id_sequence_; } 91 92 private: 93 // The sequence as HloInstructions. 94 std::vector<HloInstruction*> instruction_sequence_; 95 96 // The sequence of HLO instructions, represented by their unique IDs. The 97 // sequence is stored as both HloInstructions and unique IDs because the 98 // sequence may be referenced after transformations to the HLO graph and HLO 99 // pointers can be invalidated or recycled in this process (see 100 // HloSchedule::Update). 101 std::vector<int> id_sequence_; 102 }; 103 104 // A class representing a sequential schedule of instructions for an HLO 105 // module. A complete HLO schedule contains an instruction sequence for every 106 // non-fusion computation in the HLO module. 107 class HloSchedule { 108 public: HloSchedule(const HloModule * module)109 explicit HloSchedule(const HloModule* module) : module_(module) {} 110 111 // (De)Serialize an HloSchedule to/from a HloScheduleProto. 112 static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, 113 const HloScheduleProto& proto); 114 StatusOr<HloScheduleProto> ToProto() const; 115 116 // Returns a reference to the sequence for the given computation. 117 const HloInstructionSequence& sequence( 118 const HloComputation* computation) const; 119 120 // Returns the sequence for the given computation. An empty sequence is 121 // created if none exists for the computation. 122 HloInstructionSequence& GetOrCreateSequence( 123 const HloComputation* computation); 124 125 // Sets the sequence for the given computation to the given sequence. 126 void set_sequence(const HloComputation* computation, 127 absl::Span<HloInstruction* const> sequence); 128 void set_sequence(const HloComputation* computation, 129 HloInstructionSequence sequence); 130 131 // Returns a map from HloComputation unique ID to instruction sequence. The 132 // map contains all sequences in the schedule. sequences()133 const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const { 134 return sequences_; 135 } 136 137 // Returns true if the schedule has a sequence for the given computation. is_computation_scheduled(const HloComputation * computation)138 bool is_computation_scheduled(const HloComputation* computation) const { 139 return sequences_.contains(computation->unique_id()); 140 } 141 142 // Removes the computation from the sequences. remove_computation(const HloComputation * computation)143 void remove_computation(const HloComputation* computation) { 144 auto it = sequences_.find(computation->unique_id()); 145 CHECK(it != sequences_.end()); 146 sequences_.erase(it); 147 } 148 149 // Removes the instruction from the computation's sequence. remove_instruction(const HloComputation * computation,HloInstruction * instruction)150 void remove_instruction(const HloComputation* computation, 151 HloInstruction* instruction) { 152 sequences_[computation->unique_id()].remove_instruction(instruction); 153 } 154 155 // Replaces the old instruction with the new instruction in the computation's 156 // sequence. replace_instruction(const HloComputation * computation,HloInstruction * old_instruction,HloInstruction * new_instruction)157 void replace_instruction(const HloComputation* computation, 158 HloInstruction* old_instruction, 159 HloInstruction* new_instruction) { 160 sequences_[computation->unique_id()].replace_instruction(old_instruction, 161 new_instruction); 162 } 163 164 // Updates the schedule such that it is (again) a valid schedule for the 165 // module. This is used to update a schedule after the HLO module has been 166 // transformed in some way. In general, the only transformations to the module 167 // for which a schedule can be updated is the addition or removal of 168 // instructions and removal of computations. Updating the schedule after new 169 // dependencies between existing instructions in the module is not supported 170 // and may result in an error status returned. 171 // 172 // Instructions in the module which also exist in the given schedule will 173 // remain in the same order in the updated schedule. Instructions which exist 174 // in the module but not in the given schedule will be placed as early as 175 // possible in the updated schedule. 176 Status Update(); 177 178 // Verifies that the given schedule is valid for the given module. 179 // Specifically, the schedule contains exactly the instructions in the 180 // non-fusion computations in the module and every dependency in the module is 181 // satisfied in the schedule. 182 Status Verify() const; 183 184 string ToString() const; 185 empty()186 bool empty() const { return sequences_.empty(); } 187 module()188 const HloModule* module() const { return module_; } 189 190 private: 191 // Updates the instruction sequence for the given computation. 192 Status UpdateComputationSchedule(const HloComputation* computation); 193 194 const HloModule* module_; 195 196 // A map from computation unique ID to instruction sequence. Unique IDs are 197 // used rather than HloComputation pointers because HLO pointers are not 198 // unique across HLO transformations because pointers may be recycled. 199 absl::flat_hash_map<int64, HloInstructionSequence> sequences_; 200 }; 201 202 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); 203 204 } // namespace xla 205 206 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 207