/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {

// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
class ReducePrecisionInsertion : public HloModulePass {
  using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;

 public:
  // The exponent_bits and mantissa_bits arguments specify the parameters of
  // the reduce-precision instructions to insert.  An instruction is inserted
  // after each instruction for which instruction_filter_function returns true
  // and whose output type is F32.
  explicit ReducePrecisionInsertion(
      const int exponent_bits, const int mantissa_bits,
      const HloReducePrecisionOptions::Location location,
      const InstructionFilterFunction& instruction_filter_function)
      : exponent_bits_(exponent_bits),
        mantissa_bits_(mantissa_bits),
        location_(location),
        instruction_filter_function_(instruction_filter_function) {}
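
  // A minimal usage sketch (illustrative only, not part of the original
  // interface docs): the bit widths and the kAdd filter below are assumptions
  // chosen for the example, and OP_OUTPUTS is assumed to be one of the
  // Location enum values.
  //
  //   ReducePrecisionInsertion pass(
  //       /*exponent_bits=*/5, /*mantissa_bits=*/10,
  //       HloReducePrecisionOptions::OP_OUTPUTS,
  //       [](const HloInstruction* instruction) {
  //         return instruction->opcode() == HloOpcode::kAdd;
  //       });
  //
  // This would round the F32 output of every kAdd instruction to a 5-bit
  // exponent and 10-bit mantissa (F16 numerics stored in F32).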

  // Version of the constructor that takes an HloReducePrecisionOptions proto
  // rather than explicitly-enumerated parameters, for convenience when
  // creating passes based on DebugOptions.
  explicit ReducePrecisionInsertion(
      const HloReducePrecisionOptions& reduce_precision_options)
      : exponent_bits_(reduce_precision_options.exponent_bits()),
        mantissa_bits_(reduce_precision_options.mantissa_bits()),
        location_(reduce_precision_options.location()),
        instruction_filter_function_(
            make_filter_function(reduce_precision_options)) {}
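
  // Equivalent proto-based construction (a sketch; the setters follow
  // standard proto3 conventions for the fields read above, and OP_OUTPUTS is
  // again an assumption about the Location enum):
  //
  //   HloReducePrecisionOptions options;
  //   options.set_exponent_bits(5);
  //   options.set_mantissa_bits(10);
  //   options.set_location(HloReducePrecisionOptions::OP_OUTPUTS);
  //   ReducePrecisionInsertion pass(options);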

  ~ReducePrecisionInsertion() override {}

  absl::string_view name() const override {
    return "reduce-precision-insertion";
  }

  // Run the pass on the given module. Returns whether the module was changed
  // (reduce-precision instructions were inserted).
  StatusOr<bool> Run(HloModule* module) override;

  // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
  // representation and InstructionFilterFunction functions.
  static InstructionFilterFunction make_filter_function(
      const HloReducePrecisionOptions& reduce_precision_options);
  static HloReducePrecisionOptions make_options_proto(
      const HloReducePrecisionOptions::Location location,
      const int exponent_bits, const int mantissa_bits,
      const std::function<bool(HloOpcode)>& opcode_filter_function,
      const std::vector<string>& opname_substring_list = {});
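
  // For example (a sketch; the opcode predicate, bit widths, and substring
  // list are hypothetical, and OP_OUTPUTS is an assumed Location value):
  //
  //   HloReducePrecisionOptions options =
  //       ReducePrecisionInsertion::make_options_proto(
  //           HloReducePrecisionOptions::OP_OUTPUTS,
  //           /*exponent_bits=*/5, /*mantissa_bits=*/10,
  //           [](HloOpcode opcode) { return opcode == HloOpcode::kAdd; },
  //           {"fusion"});
  //   auto filter = ReducePrecisionInsertion::make_filter_function(options);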

  // Enumeration to control which passes should be added.
  enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };

  // Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list
  // of HloReducePrecisionOptions in a DebugOptions proto.  Returns true if any
  // passes were added.
  static bool AddPasses(HloPassPipeline* pipeline,
                        const DebugOptions& debug_options,
                        const PassTiming pass_timing);
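
  // Typical wiring (a sketch based only on the signature above; the pipeline
  // name and the way the DebugOptions are obtained are assumptions):
  //
  //   HloPassPipeline pipeline("optimization");
  //   bool added = ReducePrecisionInsertion::AddPasses(
  //       &pipeline, module->config().debug_options(),
  //       ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);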

 private:
  // Select the instructions that should have reduce-precision operations
  // attached to them.
  std::vector<HloInstruction*> instructions_to_modify(
      const HloComputation* computation);

  // Insert a reduce-precision operation into the graph on the output of the
  // given instruction.
  StatusOr<bool> insert_after(HloInstruction* instruction);

  // Insert reduce-precision operations into the graph on the inputs of the
  // given instructions.  (For fusion instructions, the operations will be
  // inserted inside the fusion computation, on the outputs of the relevant
  // input parameters.)
  StatusOr<bool> insert_on_inputs(
      const std::vector<HloInstruction*>& instructions);

  // Insert reduce-precision operations into the graph on the outputs of the
  // given instructions.  (For fusion instructions, the operations will be
  // inserted inside the fusion computation as a new root.)
  StatusOr<bool> insert_on_outputs(
      const std::vector<HloInstruction*>& instructions);

  // Is this shape valid for inserting a reduce-precision operation?
  bool is_valid_shape(const Shape& shape) {
    // For now, ReducePrecision is only implemented for F32 arrays, so this
    // ignores instructions that produce other data.  In particular, this
    // currently ignores instructions producing tuples, even if those tuples
    // contain F32 arrays inside them.  The assumption is that in most cases
    // equivalent behavior can be obtained by adding ReducePrecision
    // instructions after the instructions that pull the F32 arrays out of
    // the tuples.
    return shape.element_type() == PrimitiveType::F32;
  }

  // Would inserting a new reduce-precision operation immediately before or
  // after this instruction be redundant?
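  // For example, with exponent_bits_ = 5 and mantissa_bits_ = 10, an existing
  // kReducePrecision with 4 exponent and 8 mantissa bits already rounds at
  // least as hard in both fields, so adding another reduce-precision next to
  // it would leave every value unchanged.  (The bit widths here are
  // illustrative, not from the original comment.)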
  bool is_redundant(const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kReducePrecision &&
           instruction->exponent_bits() <= exponent_bits_ &&
           instruction->mantissa_bits() <= mantissa_bits_;
  }

  // Parameters for the precision reduction to be added.
  const int exponent_bits_;
  const int mantissa_bits_;

  // Pass "timing" parameter.  This also controls aspects of how the pass
  // selects locations to insert instructions.
  const HloReducePrecisionOptions::Location location_;

  // User-provided function to determine whether a given instruction should
  // have a reduce-precision instruction inserted in its output stream.
  const InstructionFilterFunction instruction_filter_function_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_