/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_

#include <functional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Postprocessor of the HloInstructionSequence. This is an opt-in
// postprocessing hook for a MemorySchedulerAlgorithm, used to enforce certain
// HLO schedule constraints desired for custom-calls.
using MemorySchedulerPostprocessor =
    std::function<HloInstructionSequence(const HloInstructionSequence&)>;
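
// Example MemorySchedulerPostprocessor (an illustrative sketch, not an API
// provided by this header): hoist all parameter instructions to the front of
// the sequence. Parameters have no operands, so the reordering cannot violate
// data dependencies.
//
//   MemorySchedulerPostprocessor hoist_parameters =
//       [](const HloInstructionSequence& sequence) {
//         HloInstructionSequence result;
//         std::vector<HloInstruction*> non_parameters;
//         for (HloInstruction* instruction : sequence.instructions()) {
//           if (instruction->opcode() == HloOpcode::kParameter) {
//             result.push_back(instruction);
//           } else {
//             non_parameters.push_back(instruction);
//           }
//         }
//         for (HloInstruction* instruction : non_parameters) {
//           result.push_back(instruction);
//         }
//         return result;
//       };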

// A memory scheduler computes an execution sequence for the HLO instructions in
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size. peak_memory (if not
// nullptr) is set to the peak memory of the resulting schedule according to the
// HeapSimulator.
//
// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
typedef std::function<StatusOr<HloInstructionSequence>(
    HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
    const LogicalBuffer::SizeFunction&,
    const absl::flat_hash_map<const HloComputation*, int64>&,
    const MemorySchedulerPostprocessor&,
    /*peak_memory*/ int64*)>
    MemorySchedulerAlgorithm;

// Scheduler for the entire module.
typedef std::function<StatusOr<HloSchedule>(
    HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
    const LogicalBuffer::SizeFunction&,
    /*peak_memory*/ int64*)>
    ModuleSchedulerAlgorithm;

// Lift a computation scheduler into a module scheduler by calling the
// computation scheduler on all computations in a module.
ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
    const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {});
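
// Example (usage sketch): the list scheduler declared below can be lifted into
// a module scheduler, optionally paired with a postprocessor.
//
//   ModuleSchedulerAlgorithm module_algorithm =
//       ComputationSchedulerToModuleScheduler(ListMemoryScheduler,
//                                             /*postprocessor=*/{});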

// List scheduler
StatusOr<HloInstructionSequence> ListMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64* peak_memory);

// DFS-order scheduler
StatusOr<HloInstructionSequence> DFSMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64* peak_memory);

// Naive post-order scheduler
StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64* peak_memory);

// The default scheduling algorithm. Runs the list scheduler, the DFS scheduler,
// and the post-order scheduler, and chooses whichever of the three returns the
// lowest minimum memory, not accounting for fragmentation. peak_memory (if not
// nullptr) is set to the peak memory of the resulting schedule according to the
// HeapSimulator.
StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64* peak_memory);

StatusOr<HloSchedule> DefaultModuleScheduler(
    HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function, int64* peak_memory);

// Returns an HloSchedule which seeks to minimize the memory required for the
// module. size_function is the function returning the number of bytes required
// for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak
// memory (according to the HeapSimulator) of all computations in the module.
StatusOr<HloSchedule> ScheduleModule(
    HloModule* module, const LogicalBuffer::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm = {},
    int64* peak_memory = nullptr);
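
// Example call to ScheduleModule (a usage sketch; the size function shown is
// an assumption of a plain ShapeUtil::ByteSizeOf-based backend, and `module`
// is an HloModule* in scope):
//
//   auto size_function = [](const BufferValue& buffer) {
//     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//   };
//   int64 peak_memory = -1;
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       ScheduleModule(module, size_function,
//                                      /*algorithm=*/{}, &peak_memory));
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));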

// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
StatusOr<HloInstructionSequence> ScheduleComputation(
    HloComputation* computation,
    const LogicalBuffer::SizeFunction& size_function,
    const MemorySchedulerPostprocessor& postprocessor);
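
// Example call to ScheduleComputation (usage sketch; `size_function` as in the
// ScheduleModule example above). Unlike ScheduleModule, this returns only the
// instruction sequence for one computation and does not set a schedule on the
// module.
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstructionSequence sequence,
//       ScheduleComputation(computation, size_function, /*postprocessor=*/{}));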

// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
class HloMemoryScheduler : public HloModulePass {
 public:
  // size_function is the function returning the number of bytes required for a
  // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
  // specified, then DefaultMemoryScheduler is used.
  HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
                     const ModuleSchedulerAlgorithm& algorithm = {});

  ~HloMemoryScheduler() override = default;

  absl::string_view name() const override { return "hlo-memory-scheduler"; }

  StatusOr<bool> Run(HloModule* module) override;

 private:
  LogicalBuffer::SizeFunction size_function_;

  ModuleSchedulerAlgorithm algorithm_;
};
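
// Example of running HloMemoryScheduler as a pass (a sketch; HloPassPipeline
// comes from hlo_pass_pipeline.h, which this header does not include, and
// `size_function` is as in the ScheduleModule example above):
//
//   HloPassPipeline pipeline("scheduling");
//   pipeline.AddPass<HloMemoryScheduler>(
//       size_function,
//       ComputationSchedulerToModuleScheduler(ListMemoryScheduler));
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));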

// A pass which produces a naive, but correct schedule. The schedule is produced
// using a DFS traversal of the graph with no attempt to minimize memory use.
class HloTrivialScheduler : public HloModulePass {
 public:
  absl::string_view name() const override { return "hlo-trivial-scheduler"; }

  StatusOr<bool> Run(HloModule* module) override;
};

// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs HloModule::has_schedule will return false.
class HloDescheduler : public HloModulePass {
 public:
  HloDescheduler() = default;
  ~HloDescheduler() override = default;
  absl::string_view name() const override { return "hlo-descheduler"; }

  StatusOr<bool> Run(HloModule* module) override;
};
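
// Example (usage sketch): after a scheduling pass has run, HloDescheduler
// clears the module's schedule again.
//
//   HloDescheduler descheduler;
//   TF_ASSIGN_OR_RETURN(bool changed, descheduler.Run(module));
//   CHECK(!module->has_schedule());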

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_