/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Checks whether instr is or transitively contains an instruction that we
// shouldn't fold.
//
// Specifically, we don't fold kRng or kAfterAll instructions:
//
//  - kRng is already marked as side-effecting and so is skipped elsewhere, but
//    we check for it here.  Even if kRng weren't side-effecting and took an
//    explicit seed, we *still* wouldn't want to constant-fold it, because the
//    evaluator's handling of rng is not guaranteed to be identical to any
//    particular backend's rng.
//
//  - kAfterAll needs to be skipped because a kAfterAll op with no args can
//    currently materialize a token "out of thin air".  TODO(b/110532604):
//    Remove this check once AfterAll requires at least one operand, in which
//    case constant folding will be impossible.
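//
// For instance (illustrative example, not an exhaustive list): a kCall whose
// called computation contains a zero-operand kAfterAll is rejected by this
// check, even though the kCall itself looks foldable from the outside.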
static bool IsOrContainsIllegalInstr(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kAfterAll ||
      instr->opcode() == HloOpcode::kRng) {
    return true;
  }
  for (const HloComputation* c : instr->called_computations()) {
    if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) {
      return true;
    }
  }
  return false;
}

/*static*/ std::atomic<int64_t> HloConstantFolding::slow_op_counter_{0};

StatusOr<bool> HloConstantFolding::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Limit the constant folding to 0 iterations to skip folding loops. This
  // retains the behavior from before while loop support in HloEvaluator and may
  // be revised.
  auto evaluator = std::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
  // fast-path lets us e.g. use Eigen for matmuls.
  evaluator->set_use_fast_path(true);

  bool changed = false;

  for (auto* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      // Skip dead code.
      if (instruction->IsDead()) {
        continue;
      }

      // We only handle instructions where
      //
      //  - at least one operand is a constant, and
      //  - all other operands are either constants or broadcast(constant).
      //
      // Why this particular set of rules around broadcasts?
      //
      //  - We don't want to fold broadcast(constant) on its own, because in
      //    general it's "simpler" to remember that it's a broadcast.  Also,
      //    algsimp will fold an all-one-value constant into a broadcast, so
      //    we'd just end up fighting with it.
      //
      //  - We don't want to fold an op where all operands are broadcasts of
      //    constants, because algsimp will transform op(broadcast(constant)) =>
      //    broadcast(op(constant)).  Then we can constant-fold the smaller op.
      //
      //  - So the only remaining case is where some but not all operands are
      //    broadcasts of constants, e.g. op(constant, broadcast(constant)).
      //
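      // For example (illustrative pseudo-HLO, not from the original comment):
      //   add(constant(f32[4]), broadcast(constant(f32[])))
      // qualifies for folding here, whereas an op whose operands are all
      // broadcasts of constants is left for algsimp to canonicalize first.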
      if (!absl::c_any_of(instruction->operands(),
                          [](const HloInstruction* operand) {
                            return operand->opcode() == HloOpcode::kConstant;
                          }) ||
          !absl::c_all_of(
              instruction->operands(), [](const HloInstruction* operand) {
                return operand->opcode() == HloOpcode::kConstant ||
                       (operand->opcode() == HloOpcode::kBroadcast &&
                        operand->operand(0)->opcode() == HloOpcode::kConstant);
              })) {
        continue;
      }

      // Don't fold Constant, Parameter, and Tuple instructions.  Tuple
      // constants are not directly supported by any backend, hence folding
      // Tuple is not useful and would in fact be expanded back into kTuple by
      // Algebraic Simplifier.
      //
      // (We do allow folding subcomputations that contain these instructions.)
      if (instruction->opcode() == HloOpcode::kParameter ||
          instruction->opcode() == HloOpcode::kConstant ||
          instruction->opcode() == HloOpcode::kTuple) {
        continue;
      }

      // Broadcasts dramatically increase the size of constants, which is often
      // detrimental to performance and memory capacity, so do not fold
      // broadcasts.
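      // (Illustrative sizing example, not from the original comment: folding a
      // broadcast of a single f32 scalar to f32[1024,1024] would turn 4 bytes
      // of constant data into roughly 4 MiB.)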
      if (instruction->opcode() == HloOpcode::kBroadcast ||
          instruction->opcode() == HloOpcode::kIota) {
        continue;
      }

      // Don't fold across an async execution thread if it's not supposed to be
      // changed by this pass.
      if (instruction->IsAsynchronous() &&
          instruction->async_execution_thread() !=
              instruction->parent()->execution_thread()) {
        continue;
      }

      // Do not fold FFT. Evaluating it may significantly increase compile time.
      if (instruction->opcode() == HloOpcode::kFft) {
        continue;
      }

      // Check for instructions that we can't fold even if they appear inside of
      // a subcomputation (e.g. a kCall).
      if (IsOrContainsIllegalInstr(instruction)) {
        continue;
      }

      // Don't constant-fold side-effecting instructions or instructions which
      // contain side-effecting instructions.
      if (instruction->HasSideEffect()) {
        continue;
      }

      // Don't constant fold unless it's a net positive or the output is small.
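      // (Illustrative worked example: if folding would materialize a constant
      // with 100M elements while freeing only a 1M-element single-use operand,
      // the fold both grows the module and exceeds the 45M-element cap below,
      // so it is skipped.)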
      if (instruction->shape().IsArray()) {
        int64_t elements_in_removed_operands = 0;
        for (HloInstruction* operand : instruction->operands()) {
          if (operand->user_count() == 1 && operand->shape().IsArray()) {
            elements_in_removed_operands +=
                ShapeUtil::ElementsIn(operand->shape());
          }
        }
        int64_t elements_in_constant =
            ShapeUtil::ElementsIn(instruction->shape());

        static const int64_t kMaximumConstantSizeElements = 45 * 1000 * 1000;
        if (elements_in_constant > elements_in_removed_operands &&
            elements_in_constant > kMaximumConstantSizeElements) {
          continue;
        }
      }

      VLOG(5) << "Constant folding: " << instruction->ToString();

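      // The alarm timeout starts at 1 second and doubles (2s, 4s, 8s, ...)
      // each time an earlier fold in this process has already tripped the
      // alarm, since slow_op_counter_ is incremented below whenever the alarm
      // fires.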
      absl::Duration slow_timeout =
          absl::Seconds(uint64_t{1} << slow_op_counter_.load());
      // We cannot call `instruction->ToString()` within the callback, because
      // by the time the alarm fires the instruction may have been modified and
      // invalidated in place, and ToString would then fail. We probably do not
      // want to call `ToString()` for all the instructions, thus, we only
      // display the name by default.
      std::string instruction_msg;
      if (VLOG_IS_ON(4)) {
        instruction_msg = instruction->ToString();
      } else {
        instruction_msg =
            absl::StrCat(instruction->name(),
                         " (displaying the full instruction incurs a runtime "
                         "overhead. Raise your logging level to 4 or above).");
      }
      SlowOperationAlarm slow_alarm(slow_timeout, [instruction_msg = std::move(
                                                       instruction_msg),
                                                   slow_timeout] {
        const bool ndebug =
#if NDEBUG
            true;
#else
            false;
#endif
        absl::string_view explanation_msg =
            ndebug
                ? "This isn't necessarily a bug; constant-folding is "
                  "inherently a trade-off between compilation time and speed "
                  "at runtime.  XLA has some guards that attempt to keep "
                  "constant folding from taking too long, but fundamentally "
                  "you'll always be able to come up with an input program that "
                  "takes a long time.\n\n"
                  "If you'd like to file a bug, run with envvar "
                  "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results."
                : "XLA was built without compiler optimizations, which can be "
                  "slow.  Try rebuilding with -c opt.";
        return absl::StrFormat(
            "Constant folding an instruction is taking > %s:\n\n"
            "  %s\n\n"  // instruction->name() or instruction->ToString()
            "%s",       // explanation_msg
            absl::FormatDuration(slow_timeout), instruction_msg,
            explanation_msg);
      });

      // Currently we skip unimplemented operations.
      // TODO(b/35975797): Fold constant computations for more operations.
      Literal result;
      if (!evaluator->TryEvaluate(
              instruction, &result,
              /*recursively_evaluate_nonconstant_operands=*/true)) {
        VLOG(2) << "Constant folding failed for instruction: "
                << instruction->ToString();
        continue;
      }

      slow_alarm.cancel();
      if (slow_alarm.fired()) {
        slow_op_counter_++;
      }

      VLOG(4) << "Constant folded: " << instruction->ToString();

      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(std::move(result))));
      changed = true;
    }
  }
  return changed;
}
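
// Usage sketch (illustrative only): this pass is normally added to an
// HloPassPipeline rather than invoked directly, along the lines of
//
//   HloPassPipeline pipeline("constant-folding");
//   pipeline.AddPass<HloConstantFolding>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));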

}  // namespace xla