/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_

#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

class AlgebraicSimplifierOptions {
 public:
  AlgebraicSimplifierOptions() {}
  // Platform dependent callback to determine if a reshape `from_shape` to
  // `to_shape` is a bitcast.
  using ReshapeIsBitcastCallback =
      std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
  explicit AlgebraicSimplifierOptions(
      ReshapeIsBitcastCallback reshape_is_bitcast_callback)
      : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {}

  // Use the platform specific callback if set. It is not sensible to return
  // true here if the options are not layout sensitive.
  bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const {
    if (!is_layout_sensitive_) {
      return false;
    }
    if (!reshape_is_bitcast_callback_) {
      return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
    }
    return reshape_is_bitcast_callback_(from_shape, to_shape);
  }

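  // Example (illustrative sketch, not part of the upstream interface): a
  // backend can supply its own bitcast test when constructing the options.
  // Note that ReshapeIsBitcast only consults the callback when the options
  // are layout sensitive. `BackendReshapeIsBitcast` below is a hypothetical
  // helper.
  //
  //   AlgebraicSimplifierOptions options(
  //       [](const Shape& from_shape, const Shape& to_shape) {
  //         return BackendReshapeIsBitcast(from_shape, to_shape);
  //       });
  //   options.set_is_layout_sensitive(true);
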
  // If is_layout_sensitive is true, then the simplifier preserves layout
  // during transformation. Otherwise, layout is ignored.
  void set_is_layout_sensitive(bool is_layout_sensitive) {
    is_layout_sensitive_ = is_layout_sensitive;
  }

  bool is_layout_sensitive() const { return is_layout_sensitive_; }

  // Enable dot simplification on platforms where it is profitable.
  void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) {
    enable_dot_strength_reduction_ = enable_dot_strength_reduction;
  }

  bool enable_dot_strength_reduction() const {
    return enable_dot_strength_reduction_;
  }

  // Enable the dot->multiply rewrite for a dot that is an outer product.
  void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) {
    enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite;
  }

  bool enable_dot_to_multiply_rewrite() const {
    return enable_dot_to_multiply_rewrite_;
  }

  // Enable convolution simplification on platforms where it is profitable.
  void set_enable_conv_simplification(bool enable_conv_simplification) {
    enable_conv_simplification_ = enable_conv_simplification;
  }
  bool enable_conv_simplification() const {
    return enable_conv_simplification_;
  }

  // Enable convolution operand swapping on platforms where it is supported.
  void set_enable_conv_operand_swap(bool enable_conv_operand_swap) {
    enable_conv_operand_swap_ = enable_conv_operand_swap;
  }
  bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }

  // Move a constant scalar multiply to one operand or the output of a
  // convolution, whichever has the smallest tensor size, to reduce the number
  // of scalar multiplies.
  void set_enable_scalar_multiply_reduction(
      bool enable_scalar_multiply_reduction) {
    enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
  }

  bool enable_scalar_multiply_reduction() const {
    return enable_scalar_multiply_reduction_;
  }

  // Allow the algebraic simplifier to treat floating point values like real
  // numbers.
  void set_enable_floats_are_real(bool enable_floats_are_real) {
    enable_floats_are_real_ = enable_floats_are_real;
  }

  bool enable_floats_are_real() const { return enable_floats_are_real_; }

  // If enable_window_reduce_to_reduce_replacement is true, the kReduceWindow
  // instruction can be optimized by replacement with simpler operations.
  void set_enable_window_reduce_to_reduce_replacement(
      bool enable_window_reduce_to_reduce_replacement) {
    enable_window_reduce_to_reduce_replacement_ =
        enable_window_reduce_to_reduce_replacement;
  }

  bool enable_window_reduce_to_reduce_replacement() const {
    return enable_window_reduce_to_reduce_replacement_;
  }

  // Sets the size of a gather operand that can be unrolled into many selects.
  void set_very_small_gather_size(int64 size) {
    very_small_gather_size_ = size;
  }

  int64 very_small_gather_size() const { return very_small_gather_size_; }

  void set_cudnn_batchnorm_forward_training_metadata(const string& c) {
    metadata_.cudnn_batchnorm_forward_training_metadata = c;
  }

  const string& get_cudnn_batchnorm_forward_training_metadata() const {
    return metadata_.cudnn_batchnorm_forward_training_metadata;
  }

  void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) {
    enable_reduce_of_reshape_ = enable_reduce_of_reshape;
  }

  bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }

  void set_enable_negative_padding_replacement(
      bool enable_negative_padding_replacement) {
    enable_negative_padding_replacement_ = enable_negative_padding_replacement;
  }

  bool enable_negative_padding_replacement() const {
    return enable_negative_padding_replacement_;
  }

  void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) {
    replace_transpose_with_bitcast_ = replace_transpose_with_bitcast;
  }

  bool replace_transpose_with_bitcast() const {
    return replace_transpose_with_bitcast_;
  }

 private:
  // The Metadata struct can be used to store any metadata information
  // encapsulated with the AlgebraicSimplifierOptions that can later be used
  // in an AlgebraicSimplifier pass. For example,
  // cudnn_batchnorm_forward_training_metadata can be used to store the name of
  // a custom call. If the custom call is
  // __cudnn$batchNormalizationForwardTraining, the output with index 2 is
  // guaranteed to be positive. This property has been used to recursively
  // determine if the operand of an instruction is always positive.
  struct Metadata {
    string cudnn_batchnorm_forward_training_metadata{""};
    Metadata() {}
  };
  ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
  bool is_layout_sensitive_{false};
  bool enable_dot_strength_reduction_{true};
  bool enable_dot_to_multiply_rewrite_{true};
  bool enable_conv_simplification_{true};
  bool enable_conv_operand_swap_{true};
  bool enable_scalar_multiply_reduction_{false};
  bool enable_floats_are_real_{false};
  bool enable_window_reduce_to_reduce_replacement_{true};
  bool enable_reduce_of_reshape_{true};
  bool enable_negative_padding_replacement_{true};
  bool replace_transpose_with_bitcast_{true};
  int64 very_small_gather_size_{4};
  Metadata metadata_;
};

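// Example (illustrative sketch, not part of the upstream header): typical
// configuration of the options before constructing the pass. The specific
// flag values below are arbitrary.
//
//   AlgebraicSimplifierOptions options;
//   options.set_is_layout_sensitive(false);
//   options.set_enable_dot_strength_reduction(true);
//   options.set_enable_conv_simplification(false);
//   options.set_very_small_gather_size(8);
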
// A pass which performs algebraic simplifications.
class AlgebraicSimplifier : public HloModulePass {
 public:
  // If is_layout_sensitive is true, then the simplifier preserves layout
  // during transformation. Otherwise, layout is ignored.
  explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
      : options_(options) {}
  ~AlgebraicSimplifier() override = default;
  absl::string_view name() const override { return "algsimp"; }

  // Run algebraic simplification on the given module. Returns whether the
  // module was changed.
  StatusOr<bool> Run(HloModule* module) override;

  // Create a constant from a literal, with tiles and element size updated in
  // the constant's layout.
  std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
      Literal literal) {
    auto constant = HloInstruction::CreateConstant(std::move(literal));
    UpdateLayout(constant->mutable_shape());
    return constant;
  }

 private:
  AlgebraicSimplifierOptions options_;
};

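// Example (illustrative sketch, not part of the upstream header): running the
// pass on an HloModule. `module` is assumed to be a previously built
// std::unique_ptr<HloModule>; error handling is elided.
//
//   AlgebraicSimplifierOptions options;
//   options.set_is_layout_sensitive(false);
//   AlgebraicSimplifier simplifier(options);
//   StatusOr<bool> changed = simplifier.Run(module.get());
//   if (changed.ok() && changed.ValueOrDie()) {
//     // At least one instruction in the module was simplified.
//   }
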
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_