/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
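//
// A minimal usage sketch (illustrative assumptions: the ShapeUtil::ByteSizeOf
// size function, the 1 GiB limit, and the pipeline name are examples, not
// prescribed values):
//
//   HloPassPipeline pipeline("late-passes");
//   // ... all other passes, including CSE, run first ...
//   HloRematerialization::RematerializationSizes sizes;
//   pipeline.AddPass<HloRematerialization>(
//       [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, 8); },
//       /*memory_limit_bytes=*/int64{1} << 30, &sizes,
//       HloRematerialization::RematerializationPass::kPostFusion,
//       /*block_size_limit=*/1);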
class HloRematerialization : public HloModulePass {
 public:
  using ShapeSizeFunction = std::function<int64(const Shape&)>;
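
  // For example, a typical size function defers to ShapeUtil (a sketch; the
  // pointer size of 8 bytes is an illustrative assumption):
  //
  //   HloRematerialization::ShapeSizeFunction size_fn =
  //       [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, 8); };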

  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;

  // Helper struct that records the peak memory usage of the module before and
  // after rematerialization.
  struct RematerializationSizes {
    int64 before_bytes = -1;
    int64 after_bytes = -1;
  };
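
  // For example, a caller might report the effect of the pass after Run() (a
  // sketch; the fields are only meaningful once Run() has completed):
  //
  //   HloRematerialization::RematerializationSizes sizes;
  //   // ... construct the pass with &sizes and run it ...
  //   VLOG(1) << "rematerialization saved "
  //           << sizes.before_bytes - sizes.after_bytes << " bytes";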

  // Mode in which the rematerialization algorithm should be run.
  enum class RematerializationMode {
    kRecomputeOnly,        // Only consider the kRecompute RematStrategy.
    kCompressOnly,         // Only consider the kCompress RematStrategy.
    kRecomputeAndCompress  // Consider both kRecompute and kCompress.
  };

  // Enum to specify whether this rematerialization pass occurs before or after
  // multi-output fusion.
  enum class RematerializationPass {
    kPreFusion,  // Rematerialization pass before multi-output fusion.
    kPostFusion  // Rematerialization pass after multi-output fusion.
  };

  static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }

  // Constructor parameters:
  //
  //   size_function: Function which returns the size in bytes of the top-level
  //     buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use to
  //     via rematerialization. The size of aliased outputs should be subtracted
  //     from this.
  //
  //   sizes: Pointer to a data structure which records the peak memory usage of
  //     the HLO module before/after rematerialization. Values are set during
  //     Run(). Can be nullptr.
  //
  //   compact_shape_function: Function which returns the compact form of a
  //     shape. If nullptr is provided, a default identity function is used.
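  //
  // Example construction (a sketch; the size function, 1 GiB limit, and block
  // size are illustrative assumptions):
  //
  //   HloRematerialization remat(
  //       [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, 8); },
  //       /*memory_limit_bytes=*/int64{1} << 30, /*sizes=*/nullptr,
  //       HloRematerialization::RematerializationPass::kPreFusion,
  //       /*block_size_limit=*/1);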
  explicit HloRematerialization(
      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
      RematerializationSizes* sizes, RematerializationPass pass_location,
      int block_size_limit,
      CompactShapeFunction compact_shape_function = nullptr,
      RematerializationMode mode = RematerializationMode::kRecomputeAndCompress,
      int64 min_remat_size = 0)
      : size_function_(size_function),
        memory_limit_bytes_(memory_limit_bytes),
        sizes_(sizes),
        pass_location_(pass_location),
        block_size_limit_(block_size_limit),
        compact_shape_function_(compact_shape_function == nullptr
                                    ? DefaultCompactShapeFunction
                                    : std::move(compact_shape_function)),
        mode_(mode),
        min_remat_size_(min_remat_size) {}
  ~HloRematerialization() override = default;

  absl::string_view name() const override { return "rematerialization"; }

  // Runs rematerialization on the given module. Requires that the module has a
  // schedule set (HloModule::has_schedule() is true) before running. Returns
  // whether any instructions were rematerialized (and hence whether the module
  // was changed). If memory use is already below the limit specified in the
  // constructor, no instructions are rematerialized and false is returned.
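  //
  // For example (a sketch; error handling elided, 'size_fn' assumed to be a
  // ShapeSizeFunction, ScheduleModule comes from hlo_memory_scheduler.h):
  //
  //   TF_ASSIGN_OR_RETURN(HloSchedule schedule, ScheduleModule(module, size_fn));
  //   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  //   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));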
  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Rematerializes instructions within the given computation. 'schedule' holds
  // the order in which the computation's instructions will be emitted in the
  // backend. Rematerialized instructions will be added to the HLO computation
  // and inserted into 'schedule'.
  virtual StatusOr<bool> RematerializeComputation(HloComputation* computation,
                                                  HloSchedule* schedule,
                                                  int64 memory_limit_bytes,
                                                  int64 min_remat_size);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values at
  // any program point. 'order' is the order in which the HLO instructions will
  // be emitted, which is used to determine lifespans of HLO values.
  StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
                                    const HloInstructionSequence& order) const;

  // Returns the peak memory usage of the called computations for the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // Selects an algorithm to use for HLO scheduling.
  MemorySchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // The threshold number of bytes to reduce memory use to via
  // rematerialization.
  const int64 memory_limit_bytes_;

  // Pointer to a data structure which records the peak memory usage of the HLO
  // module before/after rematerialization.
  RematerializationSizes* sizes_;

  // Specifies whether this rematerialization pass occurs before or after
  // multi-output fusion.
  RematerializationPass pass_location_;

  // Maximum number of consecutive instructions to consider for
  // rematerialization.
  int block_size_limit_;

  // Converts a shape into compact form; returns the same shape if the shape is
  // already considered compact.
  const CompactShapeFunction compact_shape_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization
  // applied. Rematerialization is only applied once per computation.
  absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can differ from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialization instruction replaces all
  // uses of the original instruction and the original instruction is
  // dead. Hence, no net instructions were added.
  int64 net_instructions_added_ = 0;

  // Size of the largest block that has been rematerialized. This is actually an
  // upper bound (within a factor of 2) on the block size.
  int max_rematerialized_block_size_ = 0;

  RematerializationMode mode_;

  int64 min_remat_size_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_