• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_module.h"
19 #include "tensorflow/compiler/xla/shape_util.h"
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace xla {
23 
instructions_to_modify(const HloComputation * computation)24 std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify(
25     const HloComputation* computation) {
26   std::vector<HloInstruction*> instruction_list;
27 
28   switch (location_) {
29     case HloReducePrecisionOptions::OP_INPUTS:
30     case HloReducePrecisionOptions::OP_OUTPUTS:
31     case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
32       for (auto* instruction : computation->instructions()) {
33         VLOG(4) << "Visited instruction: " << instruction->ToString();
34         if (instruction_filter_function_(instruction)) {
35           instruction_list.push_back(instruction);
36         }
37       }
38       break;
39 
40     case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
41     case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
42       for (auto* instruction : computation->instructions()) {
43         VLOG(4) << "Visited instruction: " << instruction->ToString();
44         if (instruction->opcode() != HloOpcode::kFusion) {
45           continue;
46         }
47         for (auto* fused_instruction :
48              instruction->fused_instructions_computation()->instructions()) {
49           VLOG(4) << "Checking sub-instruction: "
50                   << fused_instruction->ToString();
51           if (instruction_filter_function_(fused_instruction)) {
52             instruction_list.push_back(instruction);
53             break;
54           }
55         }
56       }
57       break;
58 
59     default:
60       break;
61   }
62   VLOG(1) << "Found " << instruction_list.size()
63           << " candidate instruction(s) for reduce-precision insertion";
64 
65   return instruction_list;
66 }
67 
insert_after(HloInstruction * instruction)68 StatusOr<bool> ReducePrecisionInsertion::insert_after(
69     HloInstruction* instruction) {
70   // Check that this isn't already an equivalent operation.
71   if (is_redundant(instruction)) {
72     VLOG(2) << "Skipped: instruction is already an equivalent"
73                " reduce-precision instruction:"
74             << instruction->ToString();
75     return false;
76   }
77 
78   // Check that we haven't already inserted an equivalent reduce-precision
79   // operation after this instruction.  (The zero-user case occurs when this is
80   // the root instruction.)
81   if (instruction->user_count() > 0) {
82     bool redundant_followers = true;
83     for (HloInstruction* user : instruction->users()) {
84       if (!is_redundant(user)) {
85         redundant_followers = false;
86         break;
87       }
88     }
89     if (redundant_followers) {
90       VLOG(2) << "Skipped: instruction already followed by equivalent"
91                  " reduce-precision instructions";
92       return false;
93     }
94   }
95 
96   HloInstruction* reduced = instruction->parent()->AddInstruction(
97       HloInstruction::CreateReducePrecision(instruction->shape(), instruction,
98                                             exponent_bits_, mantissa_bits_));
99   TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced));
100   return true;
101 }
102 
insert_on_inputs(const std::vector<HloInstruction * > & instructions)103 StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
104     const std::vector<HloInstruction*>& instructions) {
105   bool computation_changed = false;
106   for (auto instruction : instructions) {
107     VLOG(2) << "Adding reduce-precision operation to inputs of instruction: "
108             << instruction->ToString();
109     for (int64 i = 0; i < instruction->operand_count(); i++) {
110       HloInstruction* operand = instruction->mutable_operand(i);
111       VLOG(2) << "Adding to operand " << i << ": " << operand;
112 
113       if (!is_valid_shape(operand->shape())) {
114         VLOG(2) << "Skipped: value is not of type F32";
115         continue;
116       }
117 
118       if (is_redundant(operand)) {
119         VLOG(2) << "Skipped: operand is already an equivalent reduce-precision"
120                    " instruction";
121         continue;
122       }
123 
124       if (instruction->opcode() == HloOpcode::kFusion &&
125           (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
126            instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
127         // Insert the reduce-precision operation inside the fusion computation,
128         // after the corresponding parameter instruction.
129         TF_ASSIGN_OR_RETURN(
130             bool instruction_changed,
131             insert_after(instruction->fused_instructions_computation()
132                              ->parameter_instruction(i)));
133         computation_changed |= instruction_changed;
134       } else {
135         // Look for an existing reduce-precision operation on the operand.  (We
136         // need to be careful not to create a loop, though!)
137         HloInstruction* reduced = nullptr;
138         for (auto& user : operand->users()) {
139           if (user != instruction &&
140               user->opcode() == HloOpcode::kReducePrecision &&
141               user->exponent_bits() == exponent_bits_ &&
142               user->mantissa_bits() == mantissa_bits_) {
143             reduced = user;
144             break;
145           }
146         }
147         // If there wasn't an existing reduce-precision operation, create one.
148         if (!reduced) {
149           reduced = instruction->parent()->AddInstruction(
150               HloInstruction::CreateReducePrecision(
151                   operand->shape(), operand, exponent_bits_, mantissa_bits_));
152         }
153         // Insert the reduce-precision operation before the operand.
154         TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced));
155         computation_changed = true;
156       }
157     }
158   }
159 
160   return computation_changed;
161 }
162 
insert_on_outputs(const std::vector<HloInstruction * > & instructions)163 StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
164     const std::vector<HloInstruction*>& instructions) {
165   bool computation_changed = false;
166   for (const auto& instruction : instructions) {
167     VLOG(2) << "Adding reduce-precision operation to output of instruction: "
168             << instruction->ToString();
169 
170     if (!is_valid_shape(instruction->shape())) {
171       VLOG(2) << "Skipped: value is not of type F32";
172       continue;
173     }
174 
175     if (instruction->opcode() == HloOpcode::kFusion &&
176         (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
177          instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
178       // Insert the reduce-precision operation as the last operation inside
179       // the fusion computation.
180       HloInstruction* fusion_root = instruction->fused_expression_root();
181       VLOG(2) << "Inserting new operation after existing fusion root: "
182               << fusion_root->ToString();
183 
184       TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root));
185       computation_changed |= instruction_changed;
186     } else {
187       // Insert the reduce-precision operation after the instruction.
188       TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction));
189       computation_changed |= instruction_changed;
190     }
191   }
192 
193   return computation_changed;
194 }
195 
Run(HloModule * module)196 StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
197   bool changed = false;
198   VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
199 
200   for (auto* computation : module->MakeNonfusionComputations()) {
201     StatusOr<bool> computation_changed;
202     switch (location_) {
203       case HloReducePrecisionOptions::OP_INPUTS:
204       case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
205         computation_changed = ReducePrecisionInsertion::insert_on_inputs(
206             instructions_to_modify(computation));
207         break;
208 
209       case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
210       case HloReducePrecisionOptions::OP_OUTPUTS:
211       case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
212         computation_changed = ReducePrecisionInsertion::insert_on_outputs(
213             instructions_to_modify(computation));
214         break;
215       default:
216         break;
217     }
218     TF_RETURN_IF_ERROR(computation_changed.status());
219 
220     if (computation_changed.ValueOrDie()) {
221       changed = true;
222       VLOG(3) << "Computation after reduce-precision insertion:";
223       XLA_VLOG_LINES(3, computation->ToString());
224     } else {
225       VLOG(3) << "Computation " << computation->name() << " unchanged";
226     }
227   }
228 
229   return changed;
230 }
231 
232 ReducePrecisionInsertion::InstructionFilterFunction
make_filter_function(const HloReducePrecisionOptions & reduce_precision_options)233 ReducePrecisionInsertion::make_filter_function(
234     const HloReducePrecisionOptions& reduce_precision_options) {
235   // Implement the filter function with a lookup table.
236   std::vector<bool> opcode_filter(HloOpcodeCount(), false);
237   for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
238     opcode_filter[opcode] = true;
239   }
240   if (reduce_precision_options.opname_substrings_to_suffix_size() == 0) {
241     return [opcode_filter](const HloInstruction* instruction) {
242       return opcode_filter[static_cast<unsigned int>(instruction->opcode())];
243     };
244   } else {
245     std::vector<string> opname_substrings;
246     for (const auto& substring :
247          reduce_precision_options.opname_substrings_to_suffix()) {
248       opname_substrings.push_back(substring);
249     }
250     return [opcode_filter,
251             opname_substrings](const HloInstruction* instruction) {
252       if (!opcode_filter[static_cast<unsigned int>(instruction->opcode())]) {
253         return false;
254       }
255       const auto& opname = instruction->metadata().op_name();
256       for (const auto& substring : opname_substrings) {
257         if (opname.find(substring) != string::npos) {
258           return true;
259         }
260       }
261       return false;
262     };
263   }
264 }
265 
// Assembles an HloReducePrecisionOptions proto from its parts.
//
// Every opcode value in [0, HloOpcodeCount()) accepted by
// `opcode_filter_function` is added to opcodes_to_suffix, in ascending order.
// The entries of `opname_substring_list` are copied into
// opname_substrings_to_suffix, in order.
HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
    const HloReducePrecisionOptions::Location location, const int exponent_bits,
    const int mantissa_bits,
    const std::function<bool(HloOpcode)>& opcode_filter_function,
    const std::vector<string>& opname_substring_list) {
  HloReducePrecisionOptions result;
  result.set_location(location);
  result.set_exponent_bits(exponent_bits);
  result.set_mantissa_bits(mantissa_bits);

  const auto opcode_count = HloOpcodeCount();
  for (uint32_t value = 0; value < opcode_count; ++value) {
    const bool selected = opcode_filter_function(static_cast<HloOpcode>(value));
    if (selected) {
      result.add_opcodes_to_suffix(value);
    }
  }

  for (const auto& substring : opname_substring_list) {
    result.add_opname_substrings_to_suffix(substring);
  }
  return result;
}
285 
AddPasses(HloPassPipeline * pipeline,const DebugOptions & debug_options,const PassTiming pass_timing)286 bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline,
287                                          const DebugOptions& debug_options,
288                                          const PassTiming pass_timing) {
289   bool passes_added = false;
290   for (const auto& pass_options :
291        debug_options.hlo_reduce_precision_options()) {
292     bool add_pass;
293     switch (pass_options.location()) {
294       case HloReducePrecisionOptions::OP_INPUTS:
295       case HloReducePrecisionOptions::OP_OUTPUTS:
296         add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION;
297         break;
298       case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
299       case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
300       case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
301         add_pass = pass_timing == PassTiming::AFTER_FUSION;
302         break;
303       default:
304         add_pass = false;
305     }
306     if (add_pass) {
307       pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
308       passes_added = true;
309     }
310   }
311   return passes_added;
312 }
313 
314 }  // namespace xla
315