• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
18 
#include <functional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
31 
32 namespace xla {
33 
34 // A memory scheduler computes an execution sequence for the HLO instructions in
35 // 'computation' that minimizes peak memory, given a points-to analysis result
36 // that describes buffer aliasing, together with a target-specific size function
37 // that maps a tensor's logical size to its padded size.
38 typedef std::function<StatusOr<HloInstructionSequence>(
39     HloComputation*, const TuplePointsToAnalysis&,
40     const LogicalBuffer::SizeFunction&,
41     const absl::flat_hash_map<const HloComputation*, int64>&)>
42     MemorySchedulerAlgorithm;
43 
44 // List scheduler
45 StatusOr<HloInstructionSequence> ListMemoryScheduler(
46     HloComputation* computation,
47     const TuplePointsToAnalysis& points_to_analysis,
48     const LogicalBuffer::SizeFunction& size_function,
49     const absl::flat_hash_map<const HloComputation*, int64>&
50         memory_by_computation);
51 
52 // DFS-order scheduler
53 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
54     HloComputation* computation,
55     const TuplePointsToAnalysis& points_to_analysis,
56     const LogicalBuffer::SizeFunction& size_function,
57     const absl::flat_hash_map<const HloComputation*, int64>&
58         memory_by_computation);
59 
60 // Naive Post Order scheduler
61 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
62     HloComputation* computation,
63     const TuplePointsToAnalysis& points_to_analysis,
64     const LogicalBuffer::SizeFunction& size_function,
65     const absl::flat_hash_map<const HloComputation*, int64>&
66         memory_by_computation);
67 
68 // The default scheduling algorithm. Runs both the list scheduler
69 // and the DFS scheduler, and chooses whichever returns a lower min-memory,
70 // not accounting for fragmentation.
71 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
72     HloComputation* computation,
73     const TuplePointsToAnalysis& points_to_analysis,
74     const LogicalBuffer::SizeFunction& size_function,
75     const absl::flat_hash_map<const HloComputation*, int64>&
76         memory_by_computation);
77 
78 // Returns an HloSchedule which seeks to minimize the memory required for
79 // the computation. size_function is the function returning the number of bytes
80 // required for a LogicalBuffer.
81 StatusOr<HloSchedule> ScheduleModule(
82     HloModule* module, const LogicalBuffer::SizeFunction& size_function,
83     const MemorySchedulerAlgorithm& algorithm = {});
84 
85 // Computes the schedule for a single computation.
86 // Currently only used by the GPU backend.
87 StatusOr<HloInstructionSequence> ScheduleComputation(
88     HloComputation* computation,
89     const LogicalBuffer::SizeFunction& size_function);
90 
91 // A pass which schedules the HLO instructions in a module. The HloModule's
92 // schedule field is set to the resulting HloSchedule using
93 // HloModule::set_schedule.
94 class HloMemoryScheduler : public HloModulePass {
95  public:
96   // size_function is the function returning the number of bytes required for a
97   // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
98   // specified, then DefaultMemoryScheduler is used.
99   HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
100                      const MemorySchedulerAlgorithm& algorithm = {});
101   ~HloMemoryScheduler() override = default;
name()102   absl::string_view name() const override { return "hlo-memory-scheduler"; }
103 
104   StatusOr<bool> Run(HloModule* module) override;
105 
106  private:
107   LogicalBuffer::SizeFunction size_function_;
108   MemorySchedulerAlgorithm algorithm_;
109 };
110 
111 // A pass which produces a naive, but correct schedule. The schedule is produced
112 // using a DFS traversal of the graph with no attempt to minimize memory use.
113 class HloTrivialScheduler : public HloModulePass {
114  public:
name()115   absl::string_view name() const override { return "hlo-trivial-scheduler"; }
116 
117   StatusOr<bool> Run(HloModule* module) override;
118 };
119 
120 // A trivial pass which clears the schedule currently set on the
121 // HloModule. After this pass runs HloModule::has_schedule will return false.
122 class HloDescheduler : public HloModulePass {
123  public:
124   HloDescheduler() = default;
125   ~HloDescheduler() override = default;
name()126   absl::string_view name() const override { return "hlo-descheduler"; }
127 
128   StatusOr<bool> Run(HloModule* module) override;
129 };
130 
131 }  // namespace xla
132 
133 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
134