/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_

#include <functional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// A memory scheduler computes an execution sequence for the HLO instructions in
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
38 typedef std::function<StatusOr<HloInstructionSequence>( 39 HloComputation*, const TuplePointsToAnalysis&, 40 const LogicalBuffer::SizeFunction&, 41 const absl::flat_hash_map<const HloComputation*, int64>&)> 42 MemorySchedulerAlgorithm; 43 44 // List scheduler 45 StatusOr<HloInstructionSequence> ListMemoryScheduler( 46 HloComputation* computation, 47 const TuplePointsToAnalysis& points_to_analysis, 48 const LogicalBuffer::SizeFunction& size_function, 49 const absl::flat_hash_map<const HloComputation*, int64>& 50 memory_by_computation); 51 52 // DFS-order scheduler 53 StatusOr<HloInstructionSequence> DFSMemoryScheduler( 54 HloComputation* computation, 55 const TuplePointsToAnalysis& points_to_analysis, 56 const LogicalBuffer::SizeFunction& size_function, 57 const absl::flat_hash_map<const HloComputation*, int64>& 58 memory_by_computation); 59 60 // Naive Post Order scheduler 61 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( 62 HloComputation* computation, 63 const TuplePointsToAnalysis& points_to_analysis, 64 const LogicalBuffer::SizeFunction& size_function, 65 const absl::flat_hash_map<const HloComputation*, int64>& 66 memory_by_computation); 67 68 // The default scheduling algorithm. Runs both the list scheduler 69 // and the DFS scheduler, and chooses whichever returns a lower min-memory, 70 // not accounting for fragmentation. 71 StatusOr<HloInstructionSequence> DefaultMemoryScheduler( 72 HloComputation* computation, 73 const TuplePointsToAnalysis& points_to_analysis, 74 const LogicalBuffer::SizeFunction& size_function, 75 const absl::flat_hash_map<const HloComputation*, int64>& 76 memory_by_computation); 77 78 // Returns an HloSchedule which seeks to minimize the memory required for 79 // the computation. size_function is the function returning the number of bytes 80 // required for a LogicalBuffer. 
81 StatusOr<HloSchedule> ScheduleModule( 82 HloModule* module, const LogicalBuffer::SizeFunction& size_function, 83 const MemorySchedulerAlgorithm& algorithm = {}); 84 85 // Computes the schedule for a single computation. 86 // Currently only used by the GPU backend. 87 StatusOr<HloInstructionSequence> ScheduleComputation( 88 HloComputation* computation, 89 const LogicalBuffer::SizeFunction& size_function); 90 91 // A pass which schedules the HLO instructions in a module. The HloModule's 92 // schedule field is set to the resulting HloSchedule using 93 // HloModule::set_schedule. 94 class HloMemoryScheduler : public HloModulePass { 95 public: 96 // size_function is the function returning the number of bytes required for a 97 // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not 98 // specified, then DefaultMemoryScheduler is used. 99 HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, 100 const MemorySchedulerAlgorithm& algorithm = {}); 101 ~HloMemoryScheduler() override = default; name()102 absl::string_view name() const override { return "hlo-memory-scheduler"; } 103 104 StatusOr<bool> Run(HloModule* module) override; 105 106 private: 107 LogicalBuffer::SizeFunction size_function_; 108 MemorySchedulerAlgorithm algorithm_; 109 }; 110 111 // A pass which produces a naive, but correct schedule. The schedule is produced 112 // using a DFS traversal of the graph with no attempt to minimize memory use. 113 class HloTrivialScheduler : public HloModulePass { 114 public: name()115 absl::string_view name() const override { return "hlo-trivial-scheduler"; } 116 117 StatusOr<bool> Run(HloModule* module) override; 118 }; 119 120 // A trivial pass which clears the schedule currently set on the 121 // HloModule. After this pass runs HloModule::has_schedule will return false. 
122 class HloDescheduler : public HloModulePass { 123 public: 124 HloDescheduler() = default; 125 ~HloDescheduler() override = default; name()126 absl::string_view name() const override { return "hlo-descheduler"; } 127 128 StatusOr<bool> Run(HloModule* module) override; 129 }; 130 131 } // namespace xla 132 133 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ 134