1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/core/errors.h"
35
36 namespace xla {
37
38 // Checks whether instr is or transitively contains an instruction that we
39 // shouldn't fold.
40 //
41 // Specifically, we don't fold kRng or kAfterAll instructions:
42 //
43 // - kRng is already marked as side-effecting and so is skipped elsewhere, but
44 // we check for it here. Even kRng weren't side-effecting and took an
45 // explicit seed, we *still* wouldn't want to constant-fold it, because the
46 // evaluator's handling of rng is not guaranteed to be identical to any
47 // particular backend's rng.
48 //
49 // - kAfterAll needs to be skipped because a kAfterAll op with no args can
50 // currently materialize a token "out of thin air". TODO(b/110532604):
51 // Remove this check once AfterAll requires at least one operand, in which
52 // case constant folding will be impossible.
IsOrContainsIllegalInstr(const HloInstruction * instr)53 static bool IsOrContainsIllegalInstr(const HloInstruction* instr) {
54 if (instr->opcode() == HloOpcode::kAfterAll ||
55 instr->opcode() == HloOpcode::kRng) {
56 return true;
57 }
58 for (const HloComputation* c : instr->called_computations()) {
59 if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) {
60 return true;
61 }
62 }
63 return false;
64 }
65
66 /*static*/ std::atomic<int64_t> HloConstantFolding::slow_op_counter_{0};
67
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)68 StatusOr<bool> HloConstantFolding::Run(
69 HloModule* module,
70 const absl::flat_hash_set<absl::string_view>& execution_threads) {
71 // Limit the constant folding to 0 iterations to skip folding loops. This
72 // retains the behavior from before while loop support in HloEvaluator and may
73 // be revised.
74 auto evaluator = std::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
75 // fast-path lets us e.g. use Eigen for matmuls.
76 evaluator->set_use_fast_path(true);
77
78 bool changed = false;
79
80 for (auto* computation :
81 module->MakeNonfusionComputations(execution_threads)) {
82 for (auto* instruction : computation->MakeInstructionPostOrder()) {
83 // Skip dead code.
84 if (instruction->IsDead()) {
85 continue;
86 }
87
88 // We only handle instructions where
89 //
90 // - at least one operand is a constant, and
91 // - all other operands are either constants or broadcast(constant).
92 //
93 // Why this particular set of rules around broadcasts?
94 //
95 // - We don't want to fold broadcast(constant) on its own, because in
96 // general it's "simpler" to remember that it's a broadcast. Also,
97 // algsimp will fold an all-one-value constant into a broadcast, so
98 // we'd just end up fighting with it.
99 //
100 // - We don't want to fold an op where all operands are broadcasts of
101 // constants, because algsimp will transform op(broadcast(constant) =>
102 // broadcast(op(constant)). Then we can constant-fold the smaller op.
103 //
104 // - So the only remaining case is where some but not all operands are
105 // broadcasts of constants, e.g. op(constant, broadcast(constant)).
106 //
107 if (!absl::c_any_of(instruction->operands(),
108 [](const HloInstruction* operand) {
109 return operand->opcode() == HloOpcode::kConstant;
110 }) ||
111 !absl::c_all_of(
112 instruction->operands(), [](const HloInstruction* operand) {
113 return operand->opcode() == HloOpcode::kConstant ||
114 (operand->opcode() == HloOpcode::kBroadcast &&
115 operand->operand(0)->opcode() == HloOpcode::kConstant);
116 })) {
117 continue;
118 }
119
120 // Don't fold Constant, Parameter, and Tuple instructions. Tuple
121 // constants are not directly supported by any backends, hence folding
122 // Tuple is not useful and would in fact be expanded back into kTuple by
123 // Algebraic Simplifier.
124 //
125 // (We do allow folding subcomputations that contain these instructions.)
126 if (instruction->opcode() == HloOpcode::kParameter ||
127 instruction->opcode() == HloOpcode::kConstant ||
128 instruction->opcode() == HloOpcode::kTuple) {
129 continue;
130 }
131
132 // Broadcasts dramatically increase the size of constants, which is often
133 // detrimental to performance and memory capacity, so do not fold
134 // broadcasts.
135 if (instruction->opcode() == HloOpcode::kBroadcast ||
136 instruction->opcode() == HloOpcode::kIota) {
137 continue;
138 }
139
140 // Don't fold across async execution thread if it's not supposed to be
141 // changed by this pass.
142 if (instruction->IsAsynchronous() &&
143 instruction->async_execution_thread() !=
144 instruction->parent()->execution_thread()) {
145 continue;
146 }
147
148 // Do not fold FFT. Evaluating it may significantly increase compile time.
149 if (instruction->opcode() == HloOpcode::kFft) {
150 continue;
151 }
152
153 // Check for instructions that we can't fold even if they appear inside of
154 // a subcomputation (e.g. a kCall).
155 if (IsOrContainsIllegalInstr(instruction)) {
156 continue;
157 }
158
159 // Don't constant-fold side-effecting instructions or instructions which
160 // contain side-effecting instructions.
161 if (instruction->HasSideEffect()) {
162 continue;
163 }
164
165 // Don't constant fold unless it's a net positive or the output is small.
166 if (instruction->shape().IsArray()) {
167 int64_t elements_in_removed_operands = 0;
168 for (HloInstruction* operand : instruction->operands()) {
169 if (operand->user_count() == 1 && operand->shape().IsArray()) {
170 elements_in_removed_operands +=
171 ShapeUtil::ElementsIn(operand->shape());
172 }
173 }
174 int64_t elements_in_constant =
175 ShapeUtil::ElementsIn(instruction->shape());
176
177 static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000;
178 if (elements_in_constant > elements_in_removed_operands &&
179 elements_in_constant > kMaximumConstantSizeElements) {
180 continue;
181 }
182 }
183
184 VLOG(5) << "Constant folding: " << instruction->ToString();
185
186 absl::Duration slow_timeout =
187 absl::Seconds(uint64_t{1} << slow_op_counter_.load());
188 // We cannot call `instruction->ToString() within the callback, because
189 // the instruction may be modified and invalidated in place, and ToString
190 // will fail if the compilation is slow. We probably do not want to
191 // call `ToString()` for all the instructions, thus, we only display the
192 // name by default.
193 std::string instruction_msg;
194 if (VLOG_IS_ON(4)) {
195 instruction_msg = instruction->ToString();
196 } else {
197 instruction_msg =
198 absl::StrCat(instruction->name(),
199 " (displaying the full instruction incurs a runtime "
200 "overhead. Raise your logging level to 4 or above).");
201 }
202 SlowOperationAlarm slow_alarm(slow_timeout, [instruction_msg = std::move(
203 instruction_msg),
204 slow_timeout] {
205 const bool ndebug =
206 #if NDEBUG
207 true;
208 #else
209 false;
210 #endif
211 absl::string_view explanation_msg =
212 ndebug
213 ? "This isn't necessarily a bug; constant-folding is "
214 "inherently a trade-off between compilation time and speed "
215 "at runtime. XLA has some guards that attempt to keep "
216 "constant folding from taking too long, but fundamentally "
217 "you'll always be able to come up with an input program that "
218 "takes a long time.\n\n"
219 "If you'd like to file a bug, run with envvar "
220 "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results."
221 : "XLA was built without compiler optimizations, which can be "
222 "slow. Try rebuilding with -c opt.";
223 return absl::StrFormat(
224 "Constant folding an instruction is taking > %s:\n\n"
225 " %s\n\n" // instruction->name() or instruction->ToString()
226 "%s", // explanation_msg
227 absl::FormatDuration(slow_timeout), instruction_msg,
228 explanation_msg);
229 });
230
231 // Currently we skip unimplemented operations.
232 // TODO(b/35975797): Fold constant computations for more operations.
233 Literal result;
234 if (!evaluator->TryEvaluate(
235 instruction, &result,
236 /*recursively_evaluate_nonconstant_operands=*/true)) {
237 VLOG(2) << "Constant folding failed for instruction: "
238 << instruction->ToString();
239 continue;
240 }
241
242 slow_alarm.cancel();
243 if (slow_alarm.fired()) {
244 slow_op_counter_++;
245 }
246
247 VLOG(4) << "Constant folded: " << instruction->ToString();
248
249 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
250 instruction, HloInstruction::CreateConstant(std::move(result))));
251 changed = true;
252 }
253 }
254 return changed;
255 }
256
257 } // namespace xla
258