• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/service/call_graph.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
26 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
27 #include "tensorflow/compiler/xla/shape.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 
30 namespace xla {
31 
32 // HLO pass which rematerializes instructions to reduce peak memory use, where
33 // memory use is defined as the total size of all live HLO instruction
34 // values. Parameters and constants are included in memory use estimates.
35 //
36 // CSE will undo the effects of this optimization and should not be run after
37 // this pass. In general, this pass should be run very late, immediately before
38 // code generation.
39 class HloRematerialization : public HloModulePass {
40  public:
41   using ShapeSizeFunction = std::function<int64(const Shape&)>;
42 
43   using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
44 
45   // Helper struct that communicates the before / after sizes for the
46   // rematerialization process.
47   struct RematerializationSizes {
48     int64 before_bytes;
49     int64 after_bytes;
50   };
51 
52   // Mode in which the rematerialization algorithm should be run.
53   enum class RematerializationMode {
54     kRecomputeOnly,        // Only consider the kCompress RematStrategy.
55     kCompressOnly,         // Only consider the kRecompute RematStrategy.
56     kRecomputeAndCompress  // Consider both kRecompute and kRemat.
57   };
58 
59   // Enum to specify whether this rematerialization pass occurs before or after
60   // multi-output fusion.
61   enum class RematerializationPass {
62     kPreFusion,  // Rematerialization pass before multi-output fusion.
63     kPostFusion  // Rematerialization pass after multi-output fusion.
64   };
65 
DefaultCompactShapeFunction(const Shape & shape)66   static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
67 
68   // Constructor parameters:
69   //
70   //   size_function: Function which returns the size in bytes of the top-level
71   //     buffer of the given shape.
72   //
73   //   memory_limit_bytes: The threshold number of bytes to reduce memory use to
74   //     via rematerialization. Size of aliased outputs should be subtracted
75   //     from this.
76   //
77   //   sizes: Pointer to data structure which records the peak memory usage of
78   //     the HLO module before/after rematerialization. Value are set during
79   //     Run(). Can be nullptr.
80   //
81   //   compact_shape_function: Function which returns the compact form of a
82   //   shape. If nullptr is provided, an default identity function is used.
83   explicit HloRematerialization(
84       const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
85       RematerializationSizes* sizes, RematerializationPass pass_location,
86       CompactShapeFunction compact_shape_function = nullptr,
87       RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
size_function_(size_function)88       : size_function_(size_function),
89         memory_limit_bytes_(memory_limit_bytes),
90         sizes_(sizes),
91         pass_location_(pass_location),
92         compact_shape_function_(compact_shape_function == nullptr
93                                     ? DefaultCompactShapeFunction
94                                     : std::move(compact_shape_function)),
95         mode_(mode) {}
96   ~HloRematerialization() override = default;
97 
name()98   absl::string_view name() const override { return "rematerialization"; }
99 
100   // Runs rematerialization on the given module. Returns whether the module was
101   // changed. Requires that the module has a schedule set
102   // (HloModule::has_schedule() is true) before running. Returns whether any
103   // instructions were rematerialized. If memory use is already below the limit
104   // specified in the constructor then no instructions are rematerialized and
105   // false is returned.
106   StatusOr<bool> Run(HloModule* module) override;
107 
108  protected:
109   // Rematerializes instructions within the given computation. 'order' is the
110   // order in which the computation's instructions will be emitted in the
111   // backend. Rematerialized instructions will be added to the HLO computation
112   // and inserted into 'order'.
113   virtual StatusOr<bool> RematerializeComputation(HloComputation* computation,
114                                                   HloSchedule* schedule,
115                                                   int64 memory_limit_bytes);
116 
117   // Computes and returns the peak memory used by the given computation. The
118   // peak memory is the maximum total size of all live HLO instruction values at
119   // any program point. 'order' is the order in which the HLO instructions will
120   // be emitted which is used to determine lifespans of HLO values.
121   StatusOr<int64> ComputePeakMemory(const HloComputation* computation,
122                                     const HloInstructionSequence& order) const;
123 
124   // Returns the peak memory usage of the called computations for the given
125   // instruction. Zero is returned if the instruction calls no computations.
126   StatusOr<int64> CalledComputationsMemoryUsage(
127       const HloInstruction* instruction) const;
128 
129   // Selects an algorithm to use for HLO scheduling.
130   MemorySchedulerAlgorithm scheduler_algorithm_;
131 
132   // Function which computes the size of the top-level buffer of a shape.
133   const ShapeSizeFunction size_function_;
134 
135   // The threshold number of bytes to reduce memory use to via
136   // rematerialization.
137   const int64 memory_limit_bytes_;
138 
139   // Pointer to data structure which records the peak memory usage of the HLO
140   // module before/after rematerialization
141   RematerializationSizes* sizes_;
142 
143   // Specifies whether this rematerialization pass occurs before or after
144   // multi-output fusion.
145   RematerializationPass pass_location_;
146 
147   // Converts a shape into compact form, returns the same shape if a shape is
148   // already considered compact.
149   const CompactShapeFunction compact_shape_function_;
150 
151   // Call graph of the hlo_module.
152   std::unique_ptr<CallGraph> call_graph_;
153 
154   // The peak memory usage of each computation. The map contains only those
155   // computations called from sequential context
156   // (CallContext::kSequential). These values are updated as rematerialization
157   // occurs.
158   absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;
159 
160   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
161 
162   // Set of computations which have had rematerialization
163   // applied. Rematerialization is only applied once per computation.
164   absl::flat_hash_set<const HloComputation*> rematerialized_computations_;
165 
166   // Count of the total instructions rematerialized.
167   int64 instructions_rematerialized_ = 0;
168 
169   // Count of the net instructions added to the HLO module by
170   // rematerialization. This can be different than instructions_rematerialized_
171   // because some rematerializations are effectively moves in the HLO
172   // schedule. In these cases, the rematerialization instruction replaces all
173   // uses of the original instruction and the original instruction is
174   // dead. Hence, no net instructions were added.
175   int64 net_instructions_added_ = 0;
176 
177   RematerializationMode mode_;
178 };
179 
180 }  // namespace xla
181 
182 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
183