1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 18 19 #include <type_traits> 20 #include <vector> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/literal.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/status.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 32 namespace xla { 33 34 class HloComputation; 35 class HloInstruction; 36 37 // A postorder depth-first HloInstruction visitor. When Handle* is called on an 38 // instruction, all its operands were already visited. User code can subclass 39 // this to iterate over an HloInstruction DAG. The Handle* routines have 40 // operands / data unpacked for ease of use in the visitor subclass. 41 // 42 // No instruction will ever be visited twice; however, the root instruction will 43 // be reported again when the traversal is done via a call to FinishVisit. 44 // 45 // A subclass must override at least 46 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and 47 // (either HandleElementwiseBinary or all the Handle methods for binary ops)). 48 // The default Handle methods for (unary, binary) ops call 49 // (HandleElementwiseUnary, HandleElementwiseBinary). 50 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an 51 // "unimplemented" error status. 52 // 53 // Note: this may change to an iterator in the future for flexibility purposes. 54 // 55 // Users should not use this class directly, but use the type-aliases 56 // DfsHloVisitor/ConstDfsHloVisitor instead. 57 template <typename HloInstructionPtr> 58 class DfsHloVisitorBase { 59 static_assert( 60 std::is_same<HloInstruction*, HloInstructionPtr>::value || 61 std::is_same<const HloInstruction*, HloInstructionPtr>::value, 62 "Template argument expected to be HloInstruction* or const " 63 "HloInstruction*"); 64 65 public: DfsHloVisitorBase()66 DfsHloVisitorBase() {} ~DfsHloVisitorBase()67 virtual ~DfsHloVisitorBase() {} 68 69 // These routines are self-descriptive, see class comment for usage 70 // information. 71 72 virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); 73 virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); 74 75 virtual Status HandleClamp(HloInstructionPtr hlo) = 0; 76 virtual Status HandleSelect(HloInstructionPtr hlo) = 0; HandleMaximum(HloInstructionPtr hlo)77 virtual Status HandleMaximum(HloInstructionPtr hlo) { 78 return HandleElementwiseBinary(hlo); 79 } HandleMinimum(HloInstructionPtr hlo)80 virtual Status HandleMinimum(HloInstructionPtr hlo) { 81 return HandleElementwiseBinary(hlo); 82 } 83 virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; HandleConvert(HloInstructionPtr hlo)84 virtual Status HandleConvert(HloInstructionPtr hlo) { 85 return HandleElementwiseUnary(hlo); 86 } HandleBitcastConvert(HloInstructionPtr hlo)87 virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { 88 return HandleElementwiseUnary(hlo); 89 } HandleCopy(HloInstructionPtr hlo)90 virtual Status HandleCopy(HloInstructionPtr hlo) { 91 return HandleElementwiseUnary(hlo); 92 } HandleComplex(HloInstructionPtr hlo)93 virtual Status HandleComplex(HloInstructionPtr hlo) { 94 return HandleElementwiseBinary(hlo); 95 } HandleMultiply(HloInstructionPtr hlo)96 virtual Status HandleMultiply(HloInstructionPtr hlo) { 97 return HandleElementwiseBinary(hlo); 98 } 99 virtual Status HandleDot(HloInstructionPtr hlo) = 0; HandlePower(HloInstructionPtr hlo)100 virtual Status HandlePower(HloInstructionPtr hlo) { 101 return HandleElementwiseBinary(hlo); 102 } HandleSqrt(HloInstructionPtr hlo)103 virtual Status HandleSqrt(HloInstructionPtr hlo) { 104 return HandleElementwiseUnary(hlo); 105 } HandleRsqrt(HloInstructionPtr hlo)106 virtual Status HandleRsqrt(HloInstructionPtr hlo) { 107 return HandleElementwiseUnary(hlo); 108 } HandleCbrt(HloInstructionPtr hlo)109 virtual Status HandleCbrt(HloInstructionPtr hlo) { 110 return HandleElementwiseUnary(hlo); 111 } 112 virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; 113 virtual Status HandleFft(HloInstructionPtr fft) = 0; 114 virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; 115 virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; 116 virtual Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0; 117 virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; 118 virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0; 119 virtual Status HandleAllGatherDone(HloInstructionPtr hlo) = 0; 120 virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; 121 virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0; 122 virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0; 123 virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0; 124 virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; 125 virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; 126 virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; 127 virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; 128 virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; 129 virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; 130 virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; 131 virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0; HandleCompare(HloInstructionPtr hlo)132 virtual Status HandleCompare(HloInstructionPtr hlo) { 133 return HandleElementwiseBinary(hlo); 134 } HandleAdd(HloInstructionPtr hlo)135 virtual Status HandleAdd(HloInstructionPtr hlo) { 136 return HandleElementwiseBinary(hlo); 137 } HandleDivide(HloInstructionPtr hlo)138 virtual Status HandleDivide(HloInstructionPtr hlo) { 139 return HandleElementwiseBinary(hlo); 140 } HandleRemainder(HloInstructionPtr hlo)141 virtual Status HandleRemainder(HloInstructionPtr hlo) { 142 return HandleElementwiseBinary(hlo); 143 } HandleSubtract(HloInstructionPtr hlo)144 virtual Status HandleSubtract(HloInstructionPtr hlo) { 145 return HandleElementwiseBinary(hlo); 146 } HandleAbs(HloInstructionPtr hlo)147 virtual Status HandleAbs(HloInstructionPtr hlo) { 148 return HandleElementwiseUnary(hlo); 149 } HandleAtan2(HloInstructionPtr hlo)150 virtual Status HandleAtan2(HloInstructionPtr hlo) { 151 return HandleElementwiseBinary(hlo); 152 } HandleRound(HloInstructionPtr hlo)153 virtual Status HandleRound(HloInstructionPtr hlo) { 154 return HandleElementwiseUnary(hlo); 155 } HandleRoundNearestEven(HloInstructionPtr hlo)156 virtual Status HandleRoundNearestEven(HloInstructionPtr hlo) { 157 return HandleElementwiseUnary(hlo); 158 } HandleLogistic(HloInstructionPtr hlo)159 virtual Status HandleLogistic(HloInstructionPtr hlo) { 160 return HandleElementwiseUnary(hlo); 161 } HandleSign(HloInstructionPtr hlo)162 virtual Status HandleSign(HloInstructionPtr hlo) { 163 return HandleElementwiseUnary(hlo); 164 } HandleNegate(HloInstructionPtr hlo)165 virtual Status HandleNegate(HloInstructionPtr hlo) { 166 return HandleElementwiseUnary(hlo); 167 } HandleExp(HloInstructionPtr hlo)168 virtual Status HandleExp(HloInstructionPtr hlo) { 169 return HandleElementwiseUnary(hlo); 170 } HandleExpm1(HloInstructionPtr hlo)171 virtual Status HandleExpm1(HloInstructionPtr hlo) { 172 return HandleElementwiseUnary(hlo); 173 } HandleFloor(HloInstructionPtr hlo)174 virtual Status HandleFloor(HloInstructionPtr hlo) { 175 return HandleElementwiseUnary(hlo); 176 } HandleCeil(HloInstructionPtr hlo)177 virtual Status HandleCeil(HloInstructionPtr hlo) { 178 return HandleElementwiseUnary(hlo); 179 } HandleLog(HloInstructionPtr hlo)180 virtual Status HandleLog(HloInstructionPtr hlo) { 181 return HandleElementwiseUnary(hlo); 182 } HandleClz(HloInstructionPtr hlo)183 virtual Status HandleClz(HloInstructionPtr hlo) { 184 return HandleElementwiseUnary(hlo); 185 } HandleLog1p(HloInstructionPtr hlo)186 virtual Status HandleLog1p(HloInstructionPtr hlo) { 187 return HandleElementwiseUnary(hlo); 188 } HandleCos(HloInstructionPtr hlo)189 virtual Status HandleCos(HloInstructionPtr hlo) { 190 return HandleElementwiseUnary(hlo); 191 } HandleSin(HloInstructionPtr hlo)192 virtual Status HandleSin(HloInstructionPtr hlo) { 193 return HandleElementwiseUnary(hlo); 194 } HandleTanh(HloInstructionPtr hlo)195 virtual Status HandleTanh(HloInstructionPtr hlo) { 196 return HandleElementwiseUnary(hlo); 197 } HandleReal(HloInstructionPtr hlo)198 virtual Status HandleReal(HloInstructionPtr hlo) { 199 return HandleElementwiseUnary(hlo); 200 } HandleImag(HloInstructionPtr hlo)201 virtual Status HandleImag(HloInstructionPtr hlo) { 202 return HandleElementwiseUnary(hlo); 203 } HandleIsFinite(HloInstructionPtr hlo)204 virtual Status HandleIsFinite(HloInstructionPtr hlo) { 205 return HandleElementwiseUnary(hlo); 206 } HandleAnd(HloInstructionPtr hlo)207 virtual Status HandleAnd(HloInstructionPtr hlo) { 208 return HandleElementwiseBinary(hlo); 209 } HandleNot(HloInstructionPtr hlo)210 virtual Status HandleNot(HloInstructionPtr hlo) { 211 return HandleElementwiseUnary(hlo); 212 } HandleOr(HloInstructionPtr hlo)213 virtual Status HandleOr(HloInstructionPtr hlo) { 214 return HandleElementwiseBinary(hlo); 215 } HandleXor(HloInstructionPtr hlo)216 virtual Status HandleXor(HloInstructionPtr hlo) { 217 return HandleElementwiseBinary(hlo); 218 } HandlePopulationCount(HloInstructionPtr hlo)219 virtual Status HandlePopulationCount(HloInstructionPtr hlo) { 220 return HandleElementwiseUnary(hlo); 221 } HandleShiftLeft(HloInstructionPtr hlo)222 virtual Status HandleShiftLeft(HloInstructionPtr hlo) { 223 return HandleElementwiseBinary(hlo); 224 } HandleShiftRightArithmetic(HloInstructionPtr hlo)225 virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { 226 return HandleElementwiseBinary(hlo); 227 } HandleShiftRightLogical(HloInstructionPtr hlo)228 virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { 229 return HandleElementwiseBinary(hlo); 230 } 231 HandleReducePrecision(HloInstructionPtr hlo)232 virtual Status HandleReducePrecision(HloInstructionPtr hlo) { 233 return HandleElementwiseUnary(hlo); 234 } 235 HandleDomain(HloInstructionPtr hlo)236 virtual Status HandleDomain(HloInstructionPtr hlo) { 237 return HandleElementwiseUnary(hlo); 238 } 239 240 virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; 241 virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; 242 virtual Status HandleRng(HloInstructionPtr hlo) = 0; 243 virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0; 244 virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0; 245 virtual Status HandleReverse(HloInstructionPtr hlo) = 0; 246 virtual Status HandleSort(HloInstructionPtr hlo) = 0; 247 virtual Status HandleConstant(HloInstructionPtr hlo) = 0; 248 virtual Status HandleIota(HloInstructionPtr hlo) = 0; 249 virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; 250 virtual Status HandleReduce(HloInstructionPtr hlo) = 0; 251 virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; 252 virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; 253 virtual Status HandleReshape(HloInstructionPtr hlo) = 0; 254 virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; 255 virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; 256 virtual Status HandleParameter(HloInstructionPtr hlo) = 0; 257 virtual Status HandleFusion(HloInstructionPtr hlo) = 0; 258 virtual Status HandleCall(HloInstructionPtr hlo) = 0; 259 virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; 260 virtual Status HandleSlice(HloInstructionPtr hlo) = 0; 261 virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; 262 virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; 263 virtual Status HandleTuple(HloInstructionPtr hlo) = 0; 264 virtual Status HandleMap(HloInstructionPtr hlo) = 0; 265 virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; 266 virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; 267 virtual Status HandleWhile(HloInstructionPtr hlo) = 0; 268 virtual Status HandleConditional(HloInstructionPtr hlo) = 0; 269 virtual Status HandleGather(HloInstructionPtr hlo) = 0; 270 virtual Status HandleScatter(HloInstructionPtr hlo) = 0; 271 272 virtual Status HandlePad(HloInstructionPtr hlo) = 0; 273 274 virtual Status HandleAsyncStart(HloInstructionPtr hlo) = 0; 275 virtual Status HandleAsyncUpdate(HloInstructionPtr hlo) = 0; 276 virtual Status HandleAsyncDone(HloInstructionPtr hlo) = 0; 277 278 virtual Status HandleCopyStart(HloInstructionPtr copy_start) = 0; 279 virtual Status HandleCopyDone(HloInstructionPtr copy_done) = 0; 280 281 virtual Status HandleSend(HloInstructionPtr send) = 0; 282 virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; 283 284 virtual Status HandleRecv(HloInstructionPtr recv) = 0; 285 virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; 286 287 virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; 288 289 virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; 290 291 virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; 292 293 virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; 294 virtual Status HandleAfterAll(HloInstructionPtr token) = 0; 295 296 // Invoked to inform the visitor that the traversal has completed, and that 297 // the root was "root". 298 virtual Status FinishVisit(HloInstructionPtr root) = 0; 299 300 // 3 possible visitation states of HLO instructions. Each instruction's 301 // state only flows one way: kNotVisited -> kVisiting -> kVisited. 302 enum VisitState { 303 kNotVisited = 0, 304 kVisiting = 1, 305 kVisited = 2, 306 }; 307 GetVisitState(int id)308 VisitState GetVisitState(int id) { 309 auto iter = visit_state_.find(id); 310 if (iter == visit_state_.end()) { 311 return VisitState::kNotVisited; 312 } 313 return iter->second; 314 } 315 VisitState GetVisitState(const HloInstruction& instruction); 316 317 // Resize internal state if necessary to hold state for ids <= num. 318 // This call is purely a performance hint and can be omitted without 319 // affecting correctness. ReserveVisitStates(int num)320 void ReserveVisitStates(int num) { visit_state_.reserve(num); } VisitStateCapacity()321 size_t VisitStateCapacity() const { return visit_state_.capacity(); } 322 323 // Useful when we want to visit the same computation more than once with the 324 // same visitor. ResetVisitStates()325 void ResetVisitStates() { 326 // Clear the map, but don't resize the capacity across uses -- Calculating 327 // and reserving space could be expensive, and we always use the same 328 // module->instruction_count() as the capacity. 329 visit_state_.erase(visit_state_.begin(), visit_state_.end()); 330 } 331 332 // Useful when we want to free up the memory used by the visit state without 333 // destroying the actual visitor subclass. DestroyVisitState()334 void DestroyVisitState() { 335 visit_state_ = absl::flat_hash_map<int, VisitState>{}; 336 } 337 SetVisitState(int id,VisitState state)338 void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } 339 340 // Sets the visitation state of the given instruction as kVisiting. 341 // 342 // Precondition: current state must be kNotVisited. 343 void SetVisiting(const HloInstruction& instruction); 344 345 // Sets the visitation state of the given instruction as kVisited. 346 // 347 // Precondition: current state must be either kNotVisited or kVisiting. 348 void SetVisited(const HloInstruction& instruction); 349 350 // Returns whether the state of the given instruction is kVisiting. IsVisiting(const HloInstruction & instruction)351 bool IsVisiting(const HloInstruction& instruction) { 352 return GetVisitState(instruction) == kVisiting; 353 } 354 355 // Returns whether the state of the given instruction is kVisited. DidVisit(const HloInstruction & instruction)356 bool DidVisit(const HloInstruction& instruction) { 357 return GetVisitState(instruction) == kVisited; 358 } 359 360 // Returns whether the state of the given instruction is kNotVisited. NotVisited(const HloInstruction & instruction)361 bool NotVisited(const HloInstruction& instruction) { 362 return GetVisitState(instruction) == kNotVisited; 363 } 364 365 // This method should be overridden by subclasses that wish to run some 366 // operation on an op before its Handle* visitor method is called. 367 // 368 // For any HLO op, the order of calls is: 369 // 370 // Preprocess(op); 371 // Handle/OpType/(op); 372 // Postprocess(op); 373 // 374 // Overriding methods should call DfsHloVisitor::Preprocess before doing their 375 // own preprocessing. 376 virtual Status Preprocess(HloInstructionPtr hlo); 377 378 // This method should be overridden by subclasses that wish to run some 379 // operation on an op after its Handle* visitor method is called. See 380 // Preprocess for more details. 381 // 382 // Overriding methods should call DfsHloVisitor::Postprocess after doing their 383 // own postprocessing. 384 virtual Status Postprocess(HloInstructionPtr hlo); 385 386 private: 387 absl::flat_hash_map<int, VisitState> visit_state_; 388 389 DfsHloVisitorBase(const DfsHloVisitorBase&) = delete; 390 DfsHloVisitorBase& operator=(const DfsHloVisitorBase&) = delete; 391 }; 392 393 // Explicit instantiations in dfs_hlo_visitor.cc. 394 extern template class DfsHloVisitorBase<HloInstruction*>; 395 extern template class DfsHloVisitorBase<const HloInstruction*>; 396 397 // Users should use one of these two type aliases, which are the only two valid 398 // instantiations of DfsHloVisitorBase. 399 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>; 400 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>; 401 402 } // namespace xla 403 404 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 405