/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <functional>
#include <memory>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
class HloRematerialization : public HloModulePass {
 public:
  // Function which returns the size in bytes of the top-level buffer of the
  // given shape.
  using ShapeSizeFunction = std::function<int64(const Shape&)>;

  // Function which returns the compact form of the given shape.
  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;

  // Helper struct that communicates the before / after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64 before_bytes = -1;
    int64 after_bytes = -1;
  };

  // Mode in which the rematerialization algorithm should be run.
  enum class RematerializationMode {
    kRecomputeOnly,        // Only consider the kRecompute RematStrategy.
    kCompressOnly,         // Only consider the kCompress RematStrategy.
    kRecomputeAndCompress  // Consider both kRecompute and kCompress.
  };

  // Enum to specify whether this rematerialization pass occurs before or after
  // multi-output fusion.
  enum class RematerializationPass {
    kPreFusion,   // Rematerialization pass before multi-output fusion.
    kPostFusion   // Rematerialization pass after multi-output fusion.
  };

  static Shape DefaultCompactShapeFunction(const Shape& shape) {
    return shape;
  }

  // Constructor parameters:
  //
  //   size_function: Function which returns the size in bytes of the top-level
  //     buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use
  //     to via rematerialization. The size of aliased outputs should be
  //     subtracted from this.
  //
  //   sizes: Pointer to a data structure which records the peak memory usage
  //     of the HLO module before/after rematerialization. Values are set
  //     during Run(). Can be nullptr.
  //
  //   compact_shape_function: Function which returns the compact form of a
  //     shape. If nullptr is provided, a default identity function is used.
  explicit HloRematerialization(
      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
      RematerializationSizes* sizes, RematerializationPass pass_location,
      int block_size_limit,
      CompactShapeFunction compact_shape_function = nullptr,
      RematerializationMode mode = RematerializationMode::kRecomputeAndCompress,
      int64 min_remat_size = 0)
      : size_function_(size_function),
        memory_limit_bytes_(memory_limit_bytes),
        sizes_(sizes),
        pass_location_(pass_location),
        block_size_limit_(block_size_limit),
        compact_shape_function_(compact_shape_function == nullptr
                                    ? DefaultCompactShapeFunction
                                    : std::move(compact_shape_function)),
        mode_(mode),
        min_remat_size_(min_remat_size) {}
  ~HloRematerialization() override = default;

  absl::string_view name() const override { return "rematerialization"; }
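  // Example: a minimal usage sketch (an illustration, not part of this class's
  // API). It assumes an HloModule* `module` that already has a schedule (e.g.,
  // produced by HloMemoryScheduler); the size function, the 16 GiB limit, and
  // the block size of 1 are arbitrary choices for the example.
  //
  //   HloRematerialization::RematerializationSizes sizes;
  //   HloRematerialization remat(
  //       [](const Shape& shape) {
  //         return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  //       },
  //       /*memory_limit_bytes=*/int64{16} << 30, &sizes,
  //       HloRematerialization::RematerializationPass::kPostFusion,
  //       /*block_size_limit=*/1);
  //   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));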
  // Runs rematerialization on the given module. Requires that the module has
  // a schedule set (HloModule::has_schedule() is true) before running. Returns
  // whether any instructions were rematerialized. If memory use is already
  // below the limit specified in the constructor, then no instructions are
  // rematerialized and false is returned.
  StatusOr<bool> Run(HloModule* module) override;

 protected:
  // Rematerializes instructions within the given computation. The
  // computation's sequence in 'schedule' gives the order in which its
  // instructions will be emitted in the backend. Rematerialized instructions
  // are added to the HLO computation and inserted into that sequence.
  virtual StatusOr<bool> RematerializeComputation(HloComputation* computation,
                                                  HloSchedule* schedule,
                                                  int64 memory_limit_bytes,
                                                  int64 min_remat_size);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of the HLO
  // values.
  StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
                                    const HloInstructionSequence& order) const;
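  // The sketch below illustrates the live-range accounting behind the peak
  // memory computation. It is a simplified, hypothetical version, not this
  // class's implementation: it ignores subcomputations and buffer aliasing,
  // and it assumes a LastUses(inst, order) helper that returns the values
  // whose final use is 'inst'.
  //
  //   int64 live_bytes = 0;
  //   int64 peak_bytes = 0;
  //   for (const HloInstruction* inst : order.instructions()) {
  //     live_bytes += size_function_(inst->shape());  // value defined here
  //     peak_bytes = std::max(peak_bytes, live_bytes);
  //     for (const HloInstruction* dead : LastUses(inst, order)) {
  //       live_bytes -= size_function_(dead->shape());  // value dies here
  //     }
  //   }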
  // Returns the peak memory usage of the computations called by the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // Selects an algorithm to use for HLO scheduling.
  MemorySchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // The threshold number of bytes to reduce memory use to via
  // rematerialization.
  const int64 memory_limit_bytes_;

  // Pointer to a data structure which records the peak memory usage of the
  // HLO module before/after rematerialization.
  RematerializationSizes* sizes_;

  // Specifies whether this rematerialization pass occurs before or after
  // multi-output fusion.
  RematerializationPass pass_location_;

  // Maximum number of consecutive instructions to consider for
  // rematerialization.
  int block_size_limit_;

  // Converts a shape into its compact form; returns the same shape if the
  // shape is already considered compact.
  const CompactShapeFunction compact_shape_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization applied.
  // Rematerialization is only applied once per computation.
  absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can differ from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule: the rematerialized instruction replaces all uses of the
  // original instruction, the original instruction becomes dead, and hence no
  // net instructions were added.
  int64 net_instructions_added_ = 0;

  // Size of the largest block that has been rematerialized. This is an upper
  // bound (within a factor of 2) on the actual block size.
  int max_rematerialized_block_size_ = 0;

  RematerializationMode mode_;

  int64 min_remat_size_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_