1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_module.h"
19 #include "tensorflow/compiler/xla/shape_util.h"
20 #include "tensorflow/core/platform/logging.h"
21
22 namespace xla {
23
instructions_to_modify(const HloComputation * computation)24 std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify(
25 const HloComputation* computation) {
26 std::vector<HloInstruction*> instruction_list;
27
28 switch (location_) {
29 case HloReducePrecisionOptions::OP_INPUTS:
30 case HloReducePrecisionOptions::OP_OUTPUTS:
31 case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
32 for (auto* instruction : computation->instructions()) {
33 VLOG(4) << "Visited instruction: " << instruction->ToString();
34 if (instruction_filter_function_(instruction)) {
35 instruction_list.push_back(instruction);
36 }
37 }
38 break;
39
40 case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
41 case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
42 for (auto* instruction : computation->instructions()) {
43 VLOG(4) << "Visited instruction: " << instruction->ToString();
44 if (instruction->opcode() != HloOpcode::kFusion) {
45 continue;
46 }
47 for (auto* fused_instruction :
48 instruction->fused_instructions_computation()->instructions()) {
49 VLOG(4) << "Checking sub-instruction: "
50 << fused_instruction->ToString();
51 if (instruction_filter_function_(fused_instruction)) {
52 instruction_list.push_back(instruction);
53 break;
54 }
55 }
56 }
57 break;
58
59 default:
60 break;
61 }
62 VLOG(1) << "Found " << instruction_list.size()
63 << " candidate instruction(s) for reduce-precision insertion";
64
65 return instruction_list;
66 }
67
insert_after(HloInstruction * instruction)68 StatusOr<bool> ReducePrecisionInsertion::insert_after(
69 HloInstruction* instruction) {
70 // Check that this isn't already an equivalent operation.
71 if (is_redundant(instruction)) {
72 VLOG(2) << "Skipped: instruction is already an equivalent"
73 " reduce-precision instruction:"
74 << instruction->ToString();
75 return false;
76 }
77
78 // Check that we haven't already inserted an equivalent reduce-precision
79 // operation after this instruction. (The zero-user case occurs when this is
80 // the root instruction.)
81 if (instruction->user_count() > 0) {
82 bool redundant_followers = true;
83 for (HloInstruction* user : instruction->users()) {
84 if (!is_redundant(user)) {
85 redundant_followers = false;
86 break;
87 }
88 }
89 if (redundant_followers) {
90 VLOG(2) << "Skipped: instruction already followed by equivalent"
91 " reduce-precision instructions";
92 return false;
93 }
94 }
95
96 HloInstruction* reduced = instruction->parent()->AddInstruction(
97 HloInstruction::CreateReducePrecision(instruction->shape(), instruction,
98 exponent_bits_, mantissa_bits_));
99 TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced));
100 return true;
101 }
102
insert_on_inputs(const std::vector<HloInstruction * > & instructions)103 StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
104 const std::vector<HloInstruction*>& instructions) {
105 bool computation_changed = false;
106 for (auto instruction : instructions) {
107 VLOG(2) << "Adding reduce-precision operation to inputs of instruction: "
108 << instruction->ToString();
109 for (int64 i = 0; i < instruction->operand_count(); i++) {
110 HloInstruction* operand = instruction->mutable_operand(i);
111 VLOG(2) << "Adding to operand " << i << ": " << operand;
112
113 if (!is_valid_shape(operand->shape())) {
114 VLOG(2) << "Skipped: value is not of type F32";
115 continue;
116 }
117
118 if (is_redundant(operand)) {
119 VLOG(2) << "Skipped: operand is already an equivalent reduce-precision"
120 " instruction";
121 continue;
122 }
123
124 if (instruction->opcode() == HloOpcode::kFusion &&
125 (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
126 instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
127 // Insert the reduce-precision operation inside the fusion computation,
128 // after the corresponding parameter instruction.
129 TF_ASSIGN_OR_RETURN(
130 bool instruction_changed,
131 insert_after(instruction->fused_instructions_computation()
132 ->parameter_instruction(i)));
133 computation_changed |= instruction_changed;
134 } else {
135 // Look for an existing reduce-precision operation on the operand. (We
136 // need to be careful not to create a loop, though!)
137 HloInstruction* reduced = nullptr;
138 for (auto& user : operand->users()) {
139 if (user != instruction &&
140 user->opcode() == HloOpcode::kReducePrecision &&
141 user->exponent_bits() == exponent_bits_ &&
142 user->mantissa_bits() == mantissa_bits_) {
143 reduced = user;
144 break;
145 }
146 }
147 // If there wasn't an existing reduce-precision operation, create one.
148 if (!reduced) {
149 reduced = instruction->parent()->AddInstruction(
150 HloInstruction::CreateReducePrecision(
151 operand->shape(), operand, exponent_bits_, mantissa_bits_));
152 }
153 // Insert the reduce-precision operation before the operand.
154 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced));
155 computation_changed = true;
156 }
157 }
158 }
159
160 return computation_changed;
161 }
162
insert_on_outputs(const std::vector<HloInstruction * > & instructions)163 StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
164 const std::vector<HloInstruction*>& instructions) {
165 bool computation_changed = false;
166 for (const auto& instruction : instructions) {
167 VLOG(2) << "Adding reduce-precision operation to output of instruction: "
168 << instruction->ToString();
169
170 if (!is_valid_shape(instruction->shape())) {
171 VLOG(2) << "Skipped: value is not of type F32";
172 continue;
173 }
174
175 if (instruction->opcode() == HloOpcode::kFusion &&
176 (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
177 instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
178 // Insert the reduce-precision operation as the last operation inside
179 // the fusion computation.
180 HloInstruction* fusion_root = instruction->fused_expression_root();
181 VLOG(2) << "Inserting new operation after existing fusion root: "
182 << fusion_root->ToString();
183
184 TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root));
185 computation_changed |= instruction_changed;
186 } else {
187 // Insert the reduce-precision operation after the instruction.
188 TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction));
189 computation_changed |= instruction_changed;
190 }
191 }
192
193 return computation_changed;
194 }
195
Run(HloModule * module)196 StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
197 bool changed = false;
198 VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
199
200 for (auto* computation : module->MakeNonfusionComputations()) {
201 StatusOr<bool> computation_changed;
202 switch (location_) {
203 case HloReducePrecisionOptions::OP_INPUTS:
204 case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
205 computation_changed = ReducePrecisionInsertion::insert_on_inputs(
206 instructions_to_modify(computation));
207 break;
208
209 case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
210 case HloReducePrecisionOptions::OP_OUTPUTS:
211 case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
212 computation_changed = ReducePrecisionInsertion::insert_on_outputs(
213 instructions_to_modify(computation));
214 break;
215 default:
216 break;
217 }
218 TF_RETURN_IF_ERROR(computation_changed.status());
219
220 if (computation_changed.ValueOrDie()) {
221 changed = true;
222 VLOG(3) << "Computation after reduce-precision insertion:";
223 XLA_VLOG_LINES(3, computation->ToString());
224 } else {
225 VLOG(3) << "Computation " << computation->name() << " unchanged";
226 }
227 }
228
229 return changed;
230 }
231
232 ReducePrecisionInsertion::InstructionFilterFunction
make_filter_function(const HloReducePrecisionOptions & reduce_precision_options)233 ReducePrecisionInsertion::make_filter_function(
234 const HloReducePrecisionOptions& reduce_precision_options) {
235 // Implement the filter function with a lookup table.
236 std::vector<bool> opcode_filter(HloOpcodeCount(), false);
237 for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
238 opcode_filter[opcode] = true;
239 }
240 if (reduce_precision_options.opname_substrings_to_suffix_size() == 0) {
241 return [opcode_filter](const HloInstruction* instruction) {
242 return opcode_filter[static_cast<unsigned int>(instruction->opcode())];
243 };
244 } else {
245 std::vector<string> opname_substrings;
246 for (const auto& substring :
247 reduce_precision_options.opname_substrings_to_suffix()) {
248 opname_substrings.push_back(substring);
249 }
250 return [opcode_filter,
251 opname_substrings](const HloInstruction* instruction) {
252 if (!opcode_filter[static_cast<unsigned int>(instruction->opcode())]) {
253 return false;
254 }
255 const auto& opname = instruction->metadata().op_name();
256 for (const auto& substring : opname_substrings) {
257 if (opname.find(substring) != string::npos) {
258 return true;
259 }
260 }
261 return false;
262 };
263 }
264 }
265
// Builds an HloReducePrecisionOptions message from its individual settings.
//
// `opcode_filter_function` is evaluated once for every opcode value in
// [0, HloOpcodeCount()); each accepted value is added to opcodes_to_suffix.
// Every entry of `opname_substring_list` is added to
// opname_substrings_to_suffix.
HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
    const HloReducePrecisionOptions::Location location, const int exponent_bits,
    const int mantissa_bits,
    const std::function<bool(HloOpcode)>& opcode_filter_function,
    const std::vector<string>& opname_substring_list) {
  HloReducePrecisionOptions result;
  result.set_location(location);
  result.set_exponent_bits(exponent_bits);
  result.set_mantissa_bits(mantissa_bits);

  // Record the numeric value of each opcode the caller's filter accepts.
  for (uint32_t opcode_value = 0; opcode_value < HloOpcodeCount();
       ++opcode_value) {
    const HloOpcode candidate = static_cast<HloOpcode>(opcode_value);
    if (!opcode_filter_function(candidate)) {
      continue;
    }
    result.add_opcodes_to_suffix(opcode_value);
  }

  for (const auto& substring : opname_substring_list) {
    result.add_opname_substrings_to_suffix(substring);
  }
  return result;
}
285
AddPasses(HloPassPipeline * pipeline,const DebugOptions & debug_options,const PassTiming pass_timing)286 bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline,
287 const DebugOptions& debug_options,
288 const PassTiming pass_timing) {
289 bool passes_added = false;
290 for (const auto& pass_options :
291 debug_options.hlo_reduce_precision_options()) {
292 bool add_pass;
293 switch (pass_options.location()) {
294 case HloReducePrecisionOptions::OP_INPUTS:
295 case HloReducePrecisionOptions::OP_OUTPUTS:
296 add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION;
297 break;
298 case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
299 case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
300 case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
301 add_pass = pass_timing == PassTiming::AFTER_FUSION;
302 break;
303 default:
304 add_pass = false;
305 }
306 if (add_pass) {
307 pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
308 passes_added = true;
309 }
310 }
311 return passes_added;
312 }
313
314 } // namespace xla
315