/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

namespace conditional_opt {
// At the conceptual level, a boundary can be thought of as representing a
// single virtual operation, except that this virtual operation is
// conditionally instantiated into different concrete operations at each
// conditional branch. So a boundary is mapped to a single concrete operation
// if it is outside of conditional branches, and to a list of instructions if
// it is inside the branches. This data structure therefore provides a common
// representation of the instructions to be moved, whether they are inside or
// outside of the branches, and subsequently a common implementation basis for
// both moving instructions out of branches and moving them into branches.
class Boundary {
 public:
  enum class Position { kInsideBranch, kOutsideBranch, kUndefined };
  Boundary() : position_(Position::kUndefined) {}
  explicit Boundary(Position p) : position_(p) {}
  std::vector<HloInstruction*>& mutable_operands() { return operands_; }
  const std::vector<HloInstruction*>& operands() const { return operands_; }
  bool IsInsideBranch() const { return position_ == Position::kInsideBranch; }
  bool IsOutsideBranch() const {
    return position_ == Position::kOutsideBranch;
  }
  Position GetPosition() const { return position_; }
  bool IsEmpty() const { return operands_.empty(); }
  std::string ToString() const {
    std::string res;
    for (HloInstruction* op : operands_) {
      res += op->ToString() + ";";
    }
    return res;
  }
  bool operator==(const Boundary& that) const {
    return ContainersEqual(operands_, that.operands_);
  }

 private:
  // Boundary instructions in the conditional branches, one from each branch
  // of the conditional; or a single operand from outside the conditional.
  std::vector<HloInstruction*> operands_;
  Position position_;
};

// HLO pass that moves identical ops into or out of conditionals.
// - Ops are considered identical if the shapes of their operands are
//   identical and their other properties are identical.
// - Only identical ops that do not share operands with other ops will be
//   moved out of a conditional.
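// For example (an illustrative sketch, not part of the original comment): if
// every branch computation of a conditional ends by converting its result to
// the same element type, those converts are identical by the criterion above,
// so they can be hoisted out of the branches and replaced by a single convert
// applied to the conditional's output.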
// The cost model of the code motion optimization includes two components,
// represented by the move_config_ and reuse_config_ arrays of the
// optimization. The move_config_ array uses 1 vs 0 to dictate whether each
// HLO opcode, when used with its first operand being another given HLO
// opcode, is allowed to move across any conditional boundary; the
// reuse_config_ array uses an integer to represent, for each pair of HLO
// opcodes, how attractive it is to place these instructions together (both
// inside or outside of a conditional). Both arrays are driven by HLO opcode
// alone, regardless of where the operations are located in the module.
class ConditionalCodeMotion : public HloModulePass {
 public:
  // If is_layout_sensitive is true, then the hoist process preserves layout
  // during the identical comparison. Otherwise, layout is ignored.
  // The search configuration is a single integer but is split into four
  // parts: (sign, n, m, p), where n, m, p each occupy 8 bits and together
  // make up the 24 bits at the end of the int32. For the sign part, if
  // search_config is < 0, the reuse_config_ cost model is modified (tuned);
  // if search_config is > 0, the move_config_ cost model is modified (tuned);
  // if search_config == 0, the default cost model is used with no tuning.
  // When tuning, the entries in the designated configuration array
  // (move_config_ or reuse_config_) are flipped between 0 and another default
  // integer, starting from the pth entry queried by the optimization and
  // repeating every nth time a new entry is visited, until a maximum of m
  // entries have been changed. The tuning starts over when optimizing a new
  // model.
  explicit ConditionalCodeMotion(bool is_layout_sensitive,
                                 bool pursue_full_conditional_code_motion,
                                 int64_t search_config = 0)
      : is_layout_sensitive_(is_layout_sensitive),
        pursue_full_conditional_code_motion_(
            /*turn off special case if tuning*/
            pursue_full_conditional_code_motion && search_config == 0),
        search_config_index_(0) {
    search_config_.push_back(search_config);
    if (search_config != 0) {
      search_config_map_[0] = search_config_;
    }
  }
  explicit ConditionalCodeMotion(bool is_layout_sensitive,
                                 bool pursue_full_conditional_code_motion,
                                 std::string search_config)
      : is_layout_sensitive_(is_layout_sensitive),
        pursue_full_conditional_code_motion_(
            /*turn off special case if tuning*/
            pursue_full_conditional_code_motion && search_config.empty()),
        search_config_index_(-1) {
    ParseSearchConfiguration(search_config);
  }
  // Parses a given string in the format of a sequence of i,s,m,t into a list
  // of transformation search configurations, each configuration generated by
  // invoking MakeSearchConfig(s,m,t) and used for the ith conditional
  // encountered when optimizing a given module.
  void ParseSearchConfiguration(const std::string& search_config);
  // Makes a single search configuration for changing transformation
  // decisions: flip the decisions at positions n = flip_start +
  // flip_stride * m, for m = 0..max_flip.
  // The following defines how the int64 search configuration is composed, as
  // flip_start + (flip_max << kMaxPos) + (flip_stride << kStridePos).
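  // An illustrative example (not part of the original comment), using the
  // constants kStartPos = 0, kMaxPos = 16 and kStridePos = 32 defined below:
  //   MakeSearchConfig(/*start=*/1, /*max=*/2, /*stride=*/3)
  //     == 1 + (2 << 16) + (3 << 32) == 0x300020001,
  // from which flip_start() returns 1, flip_stride() returns 3, and
  // DecrementMaxFlip() counts the remaining number of flips down from 2.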
  // Bit position for the maximum number of flips.
  static constexpr int kMaxPos = 16;
  // Bit position for the count-down to the first flip.
  static constexpr int kStartPos = 0;
  // Bit position for the count-down to the next flip.
  static constexpr int kStridePos = 32;
  // Bit mask for extracting the low 16 bits of a value.
  static constexpr int kValueMask = 0xffff;

  static int64 MakeSearchConfig(int64_t start, int64_t max, int64_t stride) {
    const int64_t config =
        (max << kMaxPos) + (start << kStartPos) + (stride << kStridePos);
    VLOG(2) << "flip stride = " << flip_stride(config) << "\n";
    VLOG(2) << "flip config = " << config << "\n";
    return config;
  }

  static int16 flip_start(int64_t search_config) {
    return (search_config >> kStartPos) & kValueMask;
  }

  static int16 flip_stride(int64_t search_config) {
    return (search_config >> kStridePos) & kValueMask;
  }

  static int16 DecrementMaxFlip(int64* search_config) {
    const int16_t max_flip = ((*search_config) >> kMaxPos) & kValueMask;
    // Decrement the flip count so we can stop when it reaches 0.
    if (max_flip > 0) {
      *search_config -= (1 << kMaxPos);
    }
    return max_flip;
  }

  absl::string_view name() const override { return "conditional-code-motion"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Optimization decision for each boundary of the conditional instruction.
  class Decision {
   public:
    enum class Direction : uint8 {
      kMoveOutOfBranch,
      kMoveIntoBranch,
      kNoChange
    };

    Decision(Direction direction, int benefit)
        : direction_(direction), benefit_(benefit) {}
    Direction GetDirection() const { return direction_; }
    int GetBenefit() const { return benefit_; }

   private:
    Direction direction_;
    int benefit_;
  };
  // If the optimization decision is kNoChange, new_boundary is set to
  // nullptr; otherwise, it is set to the new boundary after the proposed
  // optimization.
  virtual Decision ConsiderCodeMotion(
      HloInstruction* conditional, const Boundary& cur_boundary,
      std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
      absl::flat_hash_map<HloInstruction*, int>& visited_count);

 private:
  const bool is_layout_sensitive_;
  const bool pursue_full_conditional_code_motion_;
  // The following parameterizes the transformation decisions and cost model.
  std::vector<int64> search_config_;
  int64 search_config_index_;
  // Maps each conditional to a vector of its search configurations. The key
  // of the map is the index number of the conditional in a module when
  // traversed in post order, and the value of the map is the sequence of
  // search configurations specified with the same index number for that
  // conditional.
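  // For example (an illustrative sketch, not part of the original comment):
  // if two configurations are specified for conditional index 0 and one for
  // conditional index 1, then search_config_map_[0] holds two encoded int64
  // values, applied in order while optimizing the first conditional visited
  // in post order, and search_config_map_[1] holds one value for the second.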
  absl::flat_hash_map<int64, std::vector<int64>> search_config_map_;
  std::vector<std::vector<int64>> move_config_, reuse_config_;

  StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
                                    std::vector<Boundary>& to_move_out,
                                    std::vector<Boundary>& new_boundaries);
  StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
                                   std::vector<Boundary>& to_move_in,
                                   std::vector<Boundary>& new_boundaries);
  void SetDefaultMoveConfig();
};
}  // namespace conditional_opt

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_