1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 18 19 #include <type_traits> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 24 #include "tensorflow/compiler/xla/status.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/core/stringpiece.h" 29 #include "tensorflow/core/lib/gtl/array_slice.h" 30 #include "tensorflow/core/lib/gtl/flatmap.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 class HloComputation; 37 class HloInstruction; 38 39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an 40 // instruction, all its operands were already visited. User code can subclass 41 // this to iterate over an HloInstruction DAG. The Handle* routines have 42 // operands / data unpacked for ease of use in the visitor subclass. 43 // 44 // No instruction will ever be visited twice; however, the root instruction will 45 // be reported again when the traversal is done via a call to FinishVisit. 46 // 47 // A subclass must override at least 48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and 49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)). 50 // The default Handle methods for (unary, binary) ops call 51 // (HandleElementwiseUnary, HandleElementwiseBinary). 52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an 53 // "unimplemented" error status. 54 // 55 // Note: this may change to an iterator in the future for flexibility purposes. 56 // 57 // Users should not use this class directly, but use the type-aliases 58 // DfsHloVisitor/ConstDfsHloVisitor instead. 59 template <typename HloInstructionPtr> 60 class DfsHloVisitorBase { 61 static_assert( 62 std::is_same<HloInstruction*, HloInstructionPtr>::value || 63 std::is_same<const HloInstruction*, HloInstructionPtr>::value, 64 "Template argument expected to be HloInstruction* or const " 65 "HloInstruction*"); 66 67 public: DfsHloVisitorBase()68 DfsHloVisitorBase() {} ~DfsHloVisitorBase()69 virtual ~DfsHloVisitorBase() {} 70 71 // These routines are self-descriptive, see class comment for usage 72 // information. 73 74 virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); 75 virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); 76 77 virtual Status HandleClamp(HloInstructionPtr hlo) = 0; 78 virtual Status HandleSelect(HloInstructionPtr hlo) = 0; HandleMaximum(HloInstructionPtr hlo)79 virtual Status HandleMaximum(HloInstructionPtr hlo) { 80 return HandleElementwiseBinary(hlo); 81 } HandleMinimum(HloInstructionPtr hlo)82 virtual Status HandleMinimum(HloInstructionPtr hlo) { 83 return HandleElementwiseBinary(hlo); 84 } 85 virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; HandleConvert(HloInstructionPtr hlo)86 virtual Status HandleConvert(HloInstructionPtr hlo) { 87 return HandleElementwiseUnary(hlo); 88 } HandleBitcastConvert(HloInstructionPtr hlo)89 virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { 90 return HandleElementwiseUnary(hlo); 91 } HandleCopy(HloInstructionPtr hlo)92 virtual Status HandleCopy(HloInstructionPtr hlo) { 93 return HandleElementwiseUnary(hlo); 94 } HandleComplex(HloInstructionPtr hlo)95 virtual Status HandleComplex(HloInstructionPtr hlo) { 96 return HandleElementwiseBinary(hlo); 97 } HandleMultiply(HloInstructionPtr hlo)98 virtual Status HandleMultiply(HloInstructionPtr hlo) { 99 return HandleElementwiseBinary(hlo); 100 } 101 virtual Status HandleDot(HloInstructionPtr hlo) = 0; HandlePower(HloInstructionPtr hlo)102 virtual Status HandlePower(HloInstructionPtr hlo) { 103 return HandleElementwiseBinary(hlo); 104 } 105 virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; 106 virtual Status HandleFft(HloInstructionPtr fft) = 0; 107 virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; HandleCompare(HloInstructionPtr hlo)108 virtual Status HandleCompare(HloInstructionPtr hlo) { 109 return HandleElementwiseBinary(hlo); 110 } HandleAdd(HloInstructionPtr hlo)111 virtual Status HandleAdd(HloInstructionPtr hlo) { 112 return HandleElementwiseBinary(hlo); 113 } HandleDivide(HloInstructionPtr hlo)114 virtual Status HandleDivide(HloInstructionPtr hlo) { 115 return HandleElementwiseBinary(hlo); 116 } HandleRemainder(HloInstructionPtr hlo)117 virtual Status HandleRemainder(HloInstructionPtr hlo) { 118 return HandleElementwiseBinary(hlo); 119 } HandleSubtract(HloInstructionPtr hlo)120 virtual Status HandleSubtract(HloInstructionPtr hlo) { 121 return HandleElementwiseBinary(hlo); 122 } HandleAbs(HloInstructionPtr hlo)123 virtual Status HandleAbs(HloInstructionPtr hlo) { 124 return HandleElementwiseUnary(hlo); 125 } HandleAtan2(HloInstructionPtr hlo)126 virtual Status HandleAtan2(HloInstructionPtr hlo) { 127 return HandleElementwiseBinary(hlo); 128 } HandleRound(HloInstructionPtr hlo)129 virtual Status HandleRound(HloInstructionPtr hlo) { 130 return HandleElementwiseUnary(hlo); 131 } HandleSign(HloInstructionPtr hlo)132 virtual Status HandleSign(HloInstructionPtr hlo) { 133 return HandleElementwiseUnary(hlo); 134 } HandleNegate(HloInstructionPtr hlo)135 virtual Status HandleNegate(HloInstructionPtr hlo) { 136 return HandleElementwiseUnary(hlo); 137 } HandleExp(HloInstructionPtr hlo)138 virtual Status HandleExp(HloInstructionPtr hlo) { 139 return HandleElementwiseUnary(hlo); 140 } HandleFloor(HloInstructionPtr hlo)141 virtual Status HandleFloor(HloInstructionPtr hlo) { 142 return HandleElementwiseUnary(hlo); 143 } HandleCeil(HloInstructionPtr hlo)144 virtual Status HandleCeil(HloInstructionPtr hlo) { 145 return HandleElementwiseUnary(hlo); 146 } HandleLog(HloInstructionPtr hlo)147 virtual Status HandleLog(HloInstructionPtr hlo) { 148 return HandleElementwiseUnary(hlo); 149 } HandleCos(HloInstructionPtr hlo)150 virtual Status HandleCos(HloInstructionPtr hlo) { 151 return HandleElementwiseUnary(hlo); 152 } HandleSin(HloInstructionPtr hlo)153 virtual Status HandleSin(HloInstructionPtr hlo) { 154 return HandleElementwiseUnary(hlo); 155 } HandleTanh(HloInstructionPtr hlo)156 virtual Status HandleTanh(HloInstructionPtr hlo) { 157 return HandleElementwiseUnary(hlo); 158 } HandleReal(HloInstructionPtr hlo)159 virtual Status HandleReal(HloInstructionPtr hlo) { 160 return HandleElementwiseUnary(hlo); 161 } HandleImag(HloInstructionPtr hlo)162 virtual Status HandleImag(HloInstructionPtr hlo) { 163 return HandleElementwiseUnary(hlo); 164 } HandleIsFinite(HloInstructionPtr hlo)165 virtual Status HandleIsFinite(HloInstructionPtr hlo) { 166 return HandleElementwiseUnary(hlo); 167 } HandleAnd(HloInstructionPtr hlo)168 virtual Status HandleAnd(HloInstructionPtr hlo) { 169 return HandleElementwiseBinary(hlo); 170 } HandleNot(HloInstructionPtr hlo)171 virtual Status HandleNot(HloInstructionPtr hlo) { 172 return HandleElementwiseUnary(hlo); 173 } HandleOr(HloInstructionPtr hlo)174 virtual Status HandleOr(HloInstructionPtr hlo) { 175 return HandleElementwiseBinary(hlo); 176 } HandleShiftLeft(HloInstructionPtr hlo)177 virtual Status HandleShiftLeft(HloInstructionPtr hlo) { 178 return HandleElementwiseBinary(hlo); 179 } HandleShiftRightArithmetic(HloInstructionPtr hlo)180 virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { 181 return HandleElementwiseBinary(hlo); 182 } HandleShiftRightLogical(HloInstructionPtr hlo)183 virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { 184 return HandleElementwiseBinary(hlo); 185 } 186 HandleReducePrecision(HloInstructionPtr hlo)187 virtual Status HandleReducePrecision(HloInstructionPtr hlo) { 188 return HandleElementwiseUnary(hlo); 189 } 190 191 virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; 192 virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; 193 virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; 194 virtual Status HandleRng(HloInstructionPtr hlo) = 0; 195 virtual Status HandleReverse(HloInstructionPtr hlo) = 0; 196 virtual Status HandleSort(HloInstructionPtr hlo) = 0; 197 virtual Status HandleConstant(HloInstructionPtr hlo) = 0; 198 virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; 199 virtual Status HandleReduce(HloInstructionPtr hlo) = 0; 200 virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; 201 virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; 202 virtual Status HandleReshape(HloInstructionPtr hlo) = 0; 203 virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; 204 virtual Status HandleParameter(HloInstructionPtr hlo) = 0; 205 virtual Status HandleFusion(HloInstructionPtr hlo) = 0; 206 virtual Status HandleCall(HloInstructionPtr hlo) = 0; 207 virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; 208 virtual Status HandleSlice(HloInstructionPtr hlo) = 0; 209 virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; 210 virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; 211 virtual Status HandleTuple(HloInstructionPtr hlo) = 0; 212 virtual Status HandleMap(HloInstructionPtr hlo) = 0; 213 virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; 214 virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; 215 virtual Status HandleWhile(HloInstructionPtr hlo) = 0; 216 virtual Status HandleConditional(HloInstructionPtr hlo) = 0; 217 virtual Status HandleGather(HloInstructionPtr hlo) = 0; 218 219 virtual Status HandlePad(HloInstructionPtr hlo) = 0; 220 221 virtual Status HandleSend(HloInstructionPtr send) = 0; 222 virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; 223 224 virtual Status HandleRecv(HloInstructionPtr recv) = 0; 225 virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; 226 227 virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; 228 229 virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; 230 231 virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; 232 233 // Invoked to inform the visitor that the traversal has completed, and that 234 // the root was "root". 235 virtual Status FinishVisit(HloInstructionPtr root) = 0; 236 237 // 3 possible visitation states of HLO instructions. Each instruction's 238 // state only flows one way: kNotVisited -> kVisiting -> kVisited. 239 enum VisitState { 240 kNotVisited = 0, 241 kVisiting = 1, 242 kVisited = 2, 243 }; 244 GetVisitState(int id)245 VisitState GetVisitState(int id) { return visit_state_.GetState(id); } 246 VisitState GetVisitState(const HloInstruction& instruction); 247 248 // Resize internal state if necessary to hold state for ids <= num. 249 // This call is purely a performance hint and can be omitted without 250 // affecting correctness. ReserveVisitStates(int num)251 void ReserveVisitStates(int num) { visit_state_.Reserve(num); } 252 253 // Useful when we want to visit the same computation more than once with the 254 // same visitor. ResetVisitStates()255 void ResetVisitStates() { visit_state_.Reset(); } 256 SetVisitState(int id,VisitState state)257 void SetVisitState(int id, VisitState state) { 258 visit_state_.SetState(id, state); 259 } 260 261 // Sets the visitation state of the given instruction as kVisiting. 262 // 263 // Precondition: current state must be kNotVisited. 264 void SetVisiting(const HloInstruction& instruction); 265 266 // Sets the visitation state of the given instruction as kVisited. 267 // 268 // Precondition: current state must be either kNotVisited or kVisiting. 269 void SetVisited(const HloInstruction& instruction); 270 271 // Returns whether the state of the given instruction is kVisiting. IsVisiting(const HloInstruction & instruction)272 bool IsVisiting(const HloInstruction& instruction) { 273 return GetVisitState(instruction) == kVisiting; 274 } 275 276 // Returns whether the state of the given instruction is kVisited. DidVisit(const HloInstruction & instruction)277 bool DidVisit(const HloInstruction& instruction) { 278 return GetVisitState(instruction) == kVisited; 279 } 280 281 // Returns whether the state of the given instruction is kNotVisited. NotVisited(const HloInstruction & instruction)282 bool NotVisited(const HloInstruction& instruction) { 283 return GetVisitState(instruction) == kNotVisited; 284 } 285 286 // This method should be overridden by subclasses that wish to run some 287 // operation on an op before its Handle* visitor method is called. 288 // 289 // For any HLO op, the order of calls is: 290 // 291 // Preprocess(op); 292 // Handle/OpType/(op); 293 // Postprocess(op); 294 // 295 // Overriding methods should call DfsHloVisitor::Preprocess before doing their 296 // own preprocessing. 297 virtual Status Preprocess(HloInstructionPtr hlo); 298 299 // This method should be overridden by subclasses that wish to run some 300 // operation on an op after its Handle* visitor method is called. See 301 // Preprocess for more details. 302 // 303 // Overriding methods should call DfsHloVisitor::Postprocess after doing their 304 // own postprocessing. 305 virtual Status Postprocess(HloInstructionPtr hlo); 306 307 private: 308 class DFSVisitStates { 309 public: DFSVisitStates()310 DFSVisitStates() {} Reserve(uint64 num)311 void Reserve(uint64 num) { 312 states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord); 313 } GetState(uint64 id)314 VisitState GetState(uint64 id) { 315 uint64 word_index = id / kStatesPerWord; 316 if (word_index >= states_.size()) { 317 return VisitState::kNotVisited; 318 } 319 static_assert(static_cast<int>(VisitState::kVisited) < 3, 320 "VisitState must fit in two bits"); 321 uint64 w = states_[word_index]; 322 uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state 323 return static_cast<VisitState>((w >> shift) & 0x3); 324 } SetState(uint64 id,VisitState state)325 void SetState(uint64 id, VisitState state) { 326 uint64 word_index = id / kStatesPerWord; 327 if (word_index >= states_.size()) { 328 states_.resize(word_index + 1, 0); 329 } 330 uint64* w = &states_[word_index]; 331 uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state 332 uint64 mask = 0x3ull << shift; 333 *w = (*w & ~mask) | (static_cast<uint64>(state) << shift); 334 DCHECK_EQ(GetState(id), state); 335 } Reset()336 void Reset() { states_.clear(); } 337 338 private: 339 static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; 340 // Map from id to two-bit states. We store 32 such states per 64-bit 341 // value 342 std::vector<uint64> states_; 343 }; 344 345 DFSVisitStates visit_state_; 346 347 TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); 348 }; 349 350 // Users should use one of these two type aliases, which are the only two valid 351 // instantiations of DfsHloVisitorBase. 352 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>; 353 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>; 354 355 } // namespace xla 356 357 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 358