/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
39 class HloRematerialization : public HloModulePass { 40 public: 41 using ShapeSizeFunction = std::function<int64(const Shape&)>; 42 43 using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>; 44 45 // Helper struct that communicates the before / after sizes for the 46 // rematerialization process. 47 struct RematerializationSizes { 48 int64 before_bytes; 49 int64 after_bytes; 50 }; 51 52 // Mode in which the rematerialization algorithm should be run. 53 enum class RematerializationMode { 54 kRecomputeOnly, // Only consider the kCompress RematStrategy. 55 kCompressOnly, // Only consider the kRecompute RematStrategy. 56 kRecomputeAndCompress // Consider both kRecompute and kRemat. 57 }; 58 59 // Enum to specify whether this rematerialization pass occurs before or after 60 // multi-output fusion. 61 enum class RematerializationPass { 62 kPreFusion, // Rematerialization pass before multi-output fusion. 63 kPostFusion // Rematerialization pass after multi-output fusion. 64 }; 65 DefaultCompactShapeFunction(const Shape & shape)66 static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } 67 68 // Constructor parameters: 69 // 70 // size_function: Function which returns the size in bytes of the top-level 71 // buffer of the given shape. 72 // 73 // memory_limit_bytes: The threshold number of bytes to reduce memory use to 74 // via rematerialization. Size of aliased outputs should be subtracted 75 // from this. 76 // 77 // sizes: Pointer to data structure which records the peak memory usage of 78 // the HLO module before/after rematerialization. Value are set during 79 // Run(). Can be nullptr. 80 // 81 // compact_shape_function: Function which returns the compact form of a 82 // shape. If nullptr is provided, an default identity function is used. 
83 explicit HloRematerialization( 84 const ShapeSizeFunction& size_function, int64 memory_limit_bytes, 85 RematerializationSizes* sizes, RematerializationPass pass_location, 86 CompactShapeFunction compact_shape_function = nullptr, 87 RematerializationMode mode = RematerializationMode::kRecomputeAndCompress) size_function_(size_function)88 : size_function_(size_function), 89 memory_limit_bytes_(memory_limit_bytes), 90 sizes_(sizes), 91 pass_location_(pass_location), 92 compact_shape_function_(compact_shape_function == nullptr 93 ? DefaultCompactShapeFunction 94 : std::move(compact_shape_function)), 95 mode_(mode) {} 96 ~HloRematerialization() override = default; 97 name()98 absl::string_view name() const override { return "rematerialization"; } 99 100 // Runs rematerialization on the given module. Returns whether the module was 101 // changed. Requires that the module has a schedule set 102 // (HloModule::has_schedule() is true) before running. Returns whether any 103 // instructions were rematerialized. If memory use is already below the limit 104 // specified in the constructor then no instructions are rematerialized and 105 // false is returned. 106 StatusOr<bool> Run(HloModule* module) override; 107 108 protected: 109 // Rematerializes instructions within the given computation. 'order' is the 110 // order in which the computation's instructions will be emitted in the 111 // backend. Rematerialized instructions will be added to the HLO computation 112 // and inserted into 'order'. 113 virtual StatusOr<bool> RematerializeComputation(HloComputation* computation, 114 HloSchedule* schedule, 115 int64 memory_limit_bytes); 116 117 // Computes and returns the peak memory used by the given computation. The 118 // peak memory is the maximum total size of all live HLO instruction values at 119 // any program point. 'order' is the order in which the HLO instructions will 120 // be emitted which is used to determine lifespans of HLO values. 
121 StatusOr<int64> ComputePeakMemory(const HloComputation* computation, 122 const HloInstructionSequence& order) const; 123 124 // Returns the peak memory usage of the called computations for the given 125 // instruction. Zero is returned if the instruction calls no computations. 126 StatusOr<int64> CalledComputationsMemoryUsage( 127 const HloInstruction* instruction) const; 128 129 // Selects an algorithm to use for HLO scheduling. 130 MemorySchedulerAlgorithm scheduler_algorithm_; 131 132 // Function which computes the size of the top-level buffer of a shape. 133 const ShapeSizeFunction size_function_; 134 135 // The threshold number of bytes to reduce memory use to via 136 // rematerialization. 137 const int64 memory_limit_bytes_; 138 139 // Pointer to data structure which records the peak memory usage of the HLO 140 // module before/after rematerialization 141 RematerializationSizes* sizes_; 142 143 // Specifies whether this rematerialization pass occurs before or after 144 // multi-output fusion. 145 RematerializationPass pass_location_; 146 147 // Converts a shape into compact form, returns the same shape if a shape is 148 // already considered compact. 149 const CompactShapeFunction compact_shape_function_; 150 151 // Call graph of the hlo_module. 152 std::unique_ptr<CallGraph> call_graph_; 153 154 // The peak memory usage of each computation. The map contains only those 155 // computations called from sequential context 156 // (CallContext::kSequential). These values are updated as rematerialization 157 // occurs. 158 absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_; 159 160 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; 161 162 // Set of computations which have had rematerialization 163 // applied. Rematerialization is only applied once per computation. 164 absl::flat_hash_set<const HloComputation*> rematerialized_computations_; 165 166 // Count of the total instructions rematerialized. 
167 int64 instructions_rematerialized_ = 0; 168 169 // Count of the net instructions added to the HLO module by 170 // rematerialization. This can be different than instructions_rematerialized_ 171 // because some rematerializations are effectively moves in the HLO 172 // schedule. In these cases, the rematerialization instruction replaces all 173 // uses of the original instruction and the original instruction is 174 // dead. Hence, no net instructions were added. 175 int64 net_instructions_added_ = 0; 176 177 RematerializationMode mode_; 178 }; 179 180 } // namespace xla 181 182 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ 183