1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 18 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/service/hlo_module.h" 22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 23 24 namespace xla { 25 26 class AlgebraicSimplifierOptions { 27 public: AlgebraicSimplifierOptions()28 AlgebraicSimplifierOptions() {} 29 // Platform dependent callback to determine if a reshape `from_shape` to 30 // `to_shape` is a bitcast. 31 using ReshapeIsBitcastCallback = 32 std::function<bool(const Shape& from_shape, const Shape& to_shape)>; AlgebraicSimplifierOptions(ReshapeIsBitcastCallback reshape_is_bitcast_callback)33 explicit AlgebraicSimplifierOptions( 34 ReshapeIsBitcastCallback reshape_is_bitcast_callback) 35 : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {} 36 37 // Use the platform specific callback if set. It is not sensible to return 38 // true here if the options are not layout sensitive. ReshapeIsBitcast(const Shape & from_shape,const Shape & to_shape)39 bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { 40 if (!is_layout_sensitive_) { 41 return false; 42 } 43 if (!reshape_is_bitcast_callback_) { 44 return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); 45 } 46 return reshape_is_bitcast_callback_(from_shape, to_shape); 47 } 48 49 // If is_layout_sensitive is true, then the simplifier preserves layout during 50 // transformation. Otherwise, layout is ignored. set_is_layout_sensitive(bool is_layout_sensitive)51 void set_is_layout_sensitive(bool is_layout_sensitive) { 52 is_layout_sensitive_ = is_layout_sensitive; 53 } 54 is_layout_sensitive()55 bool is_layout_sensitive() const { return is_layout_sensitive_; } 56 57 // Enable dot simplification on platforms where it is profitable. set_enable_dot_strength_reduction(bool enable_dot_strength_reduction)58 void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { 59 enable_dot_strength_reduction_ = enable_dot_strength_reduction; 60 } 61 enable_dot_strength_reduction()62 bool enable_dot_strength_reduction() const { 63 return enable_dot_strength_reduction_; 64 } 65 66 // Enable dot->multiple rewrite for dot as an outer-product set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite)67 void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) { 68 enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite; 69 } 70 enable_dot_to_multiply_rewrite()71 bool enable_dot_to_multiply_rewrite() const { 72 return enable_dot_to_multiply_rewrite_; 73 } 74 75 // Enable convolution simplification on platforms where it is profitable. set_enable_conv_simplification(bool enable_conv_simplification)76 void set_enable_conv_simplification(bool enable_conv_simplification) { 77 enable_conv_simplification_ = enable_conv_simplification; 78 } enable_conv_simplification()79 bool enable_conv_simplification() const { 80 return enable_conv_simplification_; 81 } 82 83 // Enable convolution operand swapping on platforms where it is supported. set_enable_conv_operand_swap(bool enable_conv_operand_swap)84 void set_enable_conv_operand_swap(bool enable_conv_operand_swap) { 85 enable_conv_operand_swap_ = enable_conv_operand_swap; 86 } enable_conv_operand_swap()87 bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; } 88 89 // Move constant scalar multiply to one operand or output of convolutions with 90 // the smallest tensor size, to reduce the number of scalar multiply. set_enable_scalar_multiply_reduction(bool enable_scalar_multiply_reduction)91 void set_enable_scalar_multiply_reduction( 92 bool enable_scalar_multiply_reduction) { 93 enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction; 94 } 95 enable_scalar_multiply_reduction()96 bool enable_scalar_multiply_reduction() const { 97 return enable_scalar_multiply_reduction_; 98 } 99 100 // Also the algebraic simplifer to treat floating point values like real 101 // numbers. set_enable_floats_are_real(bool enable_floats_are_real)102 void set_enable_floats_are_real(bool enable_floats_are_real) { 103 enable_floats_are_real_ = enable_floats_are_real; 104 } 105 enable_floats_are_real()106 bool enable_floats_are_real() const { return enable_floats_are_real_; } 107 108 // If enable_window_reduce_replacement is true, the kReduceWindow instruction 109 // can be optimized by replacement with simpler operations. set_enable_window_reduce_to_reduce_replacement(bool enable_window_reduce_to_reduce_replacement)110 void set_enable_window_reduce_to_reduce_replacement( 111 bool enable_window_reduce_to_reduce_replacement) { 112 enable_window_reduce_to_reduce_replacement_ = 113 enable_window_reduce_to_reduce_replacement; 114 } 115 enable_window_reduce_to_reduce_replacement()116 bool enable_window_reduce_to_reduce_replacement() const { 117 return enable_window_reduce_to_reduce_replacement_; 118 } 119 120 // Sets the size of a gather operand that can be unrolled into many selects. set_very_small_gather_size(int64 size)121 void set_very_small_gather_size(int64 size) { 122 very_small_gather_size_ = size; 123 } 124 very_small_gather_size()125 int64 very_small_gather_size() const { return very_small_gather_size_; } 126 set_cudnn_batchnorm_forward_training_metadata(const string & c)127 void set_cudnn_batchnorm_forward_training_metadata(const string& c) { 128 metadata_.cudnn_batchnorm_forward_training_metadata = c; 129 } 130 get_cudnn_batchnorm_forward_training_metadata()131 const string& get_cudnn_batchnorm_forward_training_metadata() const { 132 return metadata_.cudnn_batchnorm_forward_training_metadata; 133 } 134 set_enable_reduce_of_reshape(bool enable_reduce_of_reshape)135 void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) { 136 enable_reduce_of_reshape_ = enable_reduce_of_reshape; 137 } 138 enable_reduce_of_reshape()139 bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } 140 set_enable_negative_padding_replacement(bool enable_negative_padding_replacement)141 void set_enable_negative_padding_replacement( 142 bool enable_negative_padding_replacement) { 143 enable_negative_padding_replacement_ = enable_negative_padding_replacement; 144 } 145 enable_negative_padding_replacement()146 bool enable_negative_padding_replacement() const { 147 return enable_negative_padding_replacement_; 148 } 149 set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast)150 void set_replace_transpose_with_bitcast(bool replace_transpose_with_bitcast) { 151 replace_transpose_with_bitcast_ = replace_transpose_with_bitcast; 152 } 153 replace_transpose_with_bitcast()154 bool replace_transpose_with_bitcast() const { 155 return replace_transpose_with_bitcast_; 156 } 157 158 private: 159 // Metadata struct can be used to store any metadata information encapsulated 160 // with the AlgebraicSimplierOptions that can be later used in an 161 // AlgebraicSimplifier pass. For example, 162 // cudnn_batchnorm_forward_training_metadata can be used to store the name of 163 // a custom call. If the custom call is 164 // __cudnn$batchNormalizationForwardTraining, the output with index 2 is 165 // guaranteed to be postive. This property has been used to recursively 166 // determine if the operand of an instruction is always positive. 167 struct Metadata { 168 string cudnn_batchnorm_forward_training_metadata{""}; MetadataMetadata169 Metadata() {} 170 }; 171 ReshapeIsBitcastCallback reshape_is_bitcast_callback_; 172 bool is_layout_sensitive_{false}; 173 bool enable_dot_strength_reduction_{true}; 174 bool enable_dot_to_multiply_rewrite_{true}; 175 bool enable_conv_simplification_{true}; 176 bool enable_conv_operand_swap_{true}; 177 bool enable_scalar_multiply_reduction_{false}; 178 bool enable_floats_are_real_{false}; 179 bool enable_window_reduce_to_reduce_replacement_{true}; 180 bool enable_reduce_of_reshape_{true}; 181 bool enable_negative_padding_replacement_{true}; 182 bool replace_transpose_with_bitcast_{true}; 183 int64 very_small_gather_size_{4}; 184 Metadata metadata_; 185 }; 186 187 // A pass which performs algebraic simplifications. 188 class AlgebraicSimplifier : public HloModulePass { 189 public: 190 // If is_layout_sensitive is true, then the simplifier preserves layout during 191 // transformation. Otherwise, layout is ignored. AlgebraicSimplifier(const AlgebraicSimplifierOptions & options)192 explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) 193 : options_(options) {} 194 ~AlgebraicSimplifier() override = default; name()195 absl::string_view name() const override { return "algsimp"; } 196 197 // Run algebraic simplification on the given computation. Returns whether the 198 // computation was changed. 199 StatusOr<bool> Run(HloModule* module) override; 200 201 // Create constant from literal with tiles and element size updated in the 202 // constant's layout. CreateConstantWithLayoutUpdated(Literal literal)203 std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated( 204 Literal literal) { 205 auto constant = HloInstruction::CreateConstant(std::move(literal)); 206 UpdateLayout(constant->mutable_shape()); 207 return constant; 208 } 209 210 private: 211 AlgebraicSimplifierOptions options_; 212 }; 213 214 } // namespace xla 215 216 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 217