1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 18 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/service/hlo_module.h" 22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 23 24 namespace xla { 25 26 class AlgebraicSimplifierOptions { 27 public: AlgebraicSimplifierOptions()28 AlgebraicSimplifierOptions() {} 29 // Platform dependent callback to determine if a reshape `from_shape` to 30 // `to_shape` is a bitcast. 31 using ReshapeIsBitcastCallback = 32 std::function<bool(const Shape& from_shape, const Shape& to_shape)>; AlgebraicSimplifierOptions(ReshapeIsBitcastCallback reshape_is_bitcast_callback)33 explicit AlgebraicSimplifierOptions( 34 ReshapeIsBitcastCallback reshape_is_bitcast_callback) 35 : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {} 36 37 // Use the platform specific callback if set. It is not sensible to return 38 // true here if the options are not layout sensitive. ReshapeIsBitcast(const Shape & from_shape,const Shape & to_shape)39 bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { 40 if (!is_layout_sensitive_) { 41 return false; 42 } 43 if (!reshape_is_bitcast_callback_) { 44 return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); 45 } 46 return reshape_is_bitcast_callback_(from_shape, to_shape); 47 } 48 49 // If is_layout_sensitive is true, then the simplifier preserves layout during 50 // transformation. Otherwise, layout is ignored. set_is_layout_sensitive(bool is_layout_sensitive)51 void set_is_layout_sensitive(bool is_layout_sensitive) { 52 is_layout_sensitive_ = is_layout_sensitive; 53 } 54 is_layout_sensitive()55 bool is_layout_sensitive() const { return is_layout_sensitive_; } 56 57 // Enable dot simplification on platforms where it is profitable. set_enable_dot_strength_reduction(bool enable_dot_strength_reduction)58 void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { 59 enable_dot_strength_reduction_ = enable_dot_strength_reduction; 60 } 61 enable_dot_strength_reduction()62 bool enable_dot_strength_reduction() const { 63 return enable_dot_strength_reduction_; 64 } 65 66 // Enable convolution simplification on platforms where it is profitable. set_enable_conv_simplification(bool enable_conv_simplification)67 void set_enable_conv_simplification(bool enable_conv_simplification) { 68 enable_conv_simplification_ = enable_conv_simplification; 69 } enable_conv_simplification()70 bool enable_conv_simplification() const { 71 return enable_conv_simplification_; 72 } 73 74 // If enable_window_reduce_replacement is true, the kReduceWindow instruction 75 // can be optimized by replacement with simpler operations. set_enable_window_reduce_to_reduce_replacement(bool enable_window_reduce_to_reduce_replacement)76 void set_enable_window_reduce_to_reduce_replacement( 77 bool enable_window_reduce_to_reduce_replacement) { 78 enable_window_reduce_to_reduce_replacement_ = 79 enable_window_reduce_to_reduce_replacement; 80 } 81 enable_window_reduce_to_reduce_replacement()82 bool enable_window_reduce_to_reduce_replacement() const { 83 return enable_window_reduce_to_reduce_replacement_; 84 } 85 86 private: 87 ReshapeIsBitcastCallback reshape_is_bitcast_callback_; 88 bool is_layout_sensitive_{false}; 89 bool enable_dot_strength_reduction_{true}; 90 bool enable_conv_simplification_{true}; 91 bool enable_window_reduce_to_reduce_replacement_{true}; 92 }; 93 94 // A pass which performs algebraic simplifications. 95 class AlgebraicSimplifier : public HloModulePass { 96 public: 97 // If is_layout_sensitive is true, then the simplifier preserves layout during 98 // transformation. Otherwise, layout is ignored. AlgebraicSimplifier(const AlgebraicSimplifierOptions & options)99 explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) 100 : options_(options) {} 101 ~AlgebraicSimplifier() override = default; name()102 absl::string_view name() const override { return "algsimp"; } 103 104 // Run algebraic simplification on the given computation. Returns whether the 105 // computation was changed. 106 StatusOr<bool> Run(HloModule* module) override; 107 108 private: 109 AlgebraicSimplifierOptions options_; 110 }; 111 112 } // namespace xla 113 114 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ 115