1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 18 19 #include <vector> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_schedule.h" 26 #include "tensorflow/compiler/xla/status.h" 27 28 namespace xla { 29 30 class HloModule; 31 32 // Class representing a sequence of HLO instructions such as the sequential 33 // execution order of an HLO computation. 34 class HloInstructionSequence { 35 public: 36 HloInstructionSequence() = default; HloInstructionSequence(absl::Span<HloInstruction * const> instructions)37 explicit HloInstructionSequence( 38 absl::Span<HloInstruction* const> instructions) { 39 for (HloInstruction* instruction : instructions) { 40 push_back(instruction); 41 } 42 } 43 44 // Adds the instruction to the end of the sequence. push_back(HloInstruction * instruction)45 void push_back(HloInstruction* instruction) { 46 instruction_sequence_.push_back(instruction); 47 id_sequence_.push_back(instruction->unique_id()); 48 } 49 50 // Clears the sequence of all instructions. clear()51 void clear() { 52 instruction_sequence_.clear(); 53 id_sequence_.clear(); 54 } 55 size()56 int64 size() const { return instruction_sequence_.size(); } 57 58 // Returns the sequence of HLO instructions. instructions()59 const std::vector<HloInstruction*>& instructions() const { 60 return instruction_sequence_; 61 } 62 63 // Returns the unique IDs of the instructions in the sequence (in order). ids()64 const std::vector<int>& ids() const { return id_sequence_; } 65 66 private: 67 // The sequence as HloInstructions. 68 std::vector<HloInstruction*> instruction_sequence_; 69 70 // The sequence of HLO instructions, represented by their unique IDs. The 71 // sequence is stored as both HloInstructions and unique IDs because the 72 // sequence may be referenced after transformations to the HLO graph and HLO 73 // pointers can be invalidated or recycled in this process (see 74 // HloSchedule::Update). 75 std::vector<int> id_sequence_; 76 }; 77 78 // A class representing a sequential schedule of instructions for an HLO 79 // module. A complete HLO schedule contains an instruction sequence for every 80 // non-fusion computation in the HLO module. 81 class HloSchedule { 82 public: HloSchedule(const HloModule * module)83 explicit HloSchedule(const HloModule* module) : module_(module) {} 84 85 // (De)Serialize an HloSchedule to/from a HloScheduleProto. 86 static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, 87 const HloScheduleProto& proto); 88 StatusOr<HloScheduleProto> ToProto() const; 89 90 // Returns a reference to the sequence for the given computation. 91 const HloInstructionSequence& sequence( 92 const HloComputation* computation) const; 93 94 // Returns the sequence for the given computation. An empty sequence is 95 // created if none exists for the computation. 96 HloInstructionSequence& GetOrCreateSequence( 97 const HloComputation* computation); 98 99 // Sets the sequence for the given computation to the given sequence. 100 void set_sequence(const HloComputation* computation, 101 absl::Span<HloInstruction* const> sequence); 102 void set_sequence(const HloComputation* computation, 103 HloInstructionSequence sequence); 104 105 // Returns a map from HloComputation unique ID to instruction sequence. The 106 // map contains all sequences in the schedule. sequences()107 const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const { 108 return sequences_; 109 } 110 111 // Returns true if the schedule has a sequence for the given computation. is_computation_scheduled(const HloComputation * computation)112 bool is_computation_scheduled(const HloComputation* computation) const { 113 return sequences_.contains(computation->unique_id()); 114 } 115 116 // Updates the schedule such that it is (again) a valid schedule for the 117 // module. This is used to update a schedule after the HLO module has been 118 // transformed in some way. In general, the only transformations to the module 119 // for which a schedule can be updated is the addition or removal of 120 // instructions and removal of computations. Updating the schedule after new 121 // dependencies between existing instructions in the module is not supported 122 // and may result in an error status returned. 123 // 124 // Instructions in the module which also exist in the given schedule will 125 // remain in the same order in the updated schedule. Instructions which exist 126 // in the module but not in the given schedule will be placed as early as 127 // possible in the updated schedule. 128 Status Update(); 129 130 // Verifies that the given schedule is valid for the given module. 131 // Specifically, the schedule contains exactly the instructions in the 132 // non-fusion computations in the module and every dependency in the module is 133 // satisfied in the schedule. 134 Status Verify() const; 135 136 string ToString() const; 137 empty()138 bool empty() const { return sequences_.empty(); } 139 module()140 const HloModule* module() const { return module_; } 141 142 private: 143 // Updates the instruction sequence for the given computation. 144 Status UpdateComputationSchedule(const HloComputation* computation); 145 146 const HloModule* module_; 147 148 // A map from computation unique ID to instruction sequence. Unique IDs are 149 // used rather than HloComputation pointers because HLO pointers are not 150 // unique across HLO transformations because pointers may be recycled. 151 absl::flat_hash_map<int64, HloInstructionSequence> sequences_; 152 }; 153 154 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); 155 156 } // namespace xla 157 158 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ 159