1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 18 19 #include <type_traits> 20 #include <vector> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/literal.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/status.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 class HloComputation; 37 class HloInstruction; 38 39 // A postorder depth-first HloInstruction visitor. When Handle* is called on an 40 // instruction, all its operands were already visited. User code can subclass 41 // this to iterate over an HloInstruction DAG. The Handle* routines have 42 // operands / data unpacked for ease of use in the visitor subclass. 43 // 44 // No instruction will ever be visited twice; however, the root instruction will 45 // be reported again when the traversal is done via a call to FinishVisit. 46 // 47 // A subclass must override at least 48 // (either HandleElementwiseUnary or all the Handle methods for unary ops) and 49 // (either HandleElementwiseBinary or all the Handle methods for binary ops)). 50 // The default Handle methods for (unary, binary) ops call 51 // (HandleElementwiseUnary, HandleElementwiseBinary). 52 // The default (HandleElementwiseUnary, HandleElementwiseBinary) return an 53 // "unimplemented" error status. 54 // 55 // Note: this may change to an iterator in the future for flexibility purposes. 56 // 57 // Users should not use this class directly, but use the type-aliases 58 // DfsHloVisitor/ConstDfsHloVisitor instead. 59 template <typename HloInstructionPtr> 60 class DfsHloVisitorBase { 61 static_assert( 62 std::is_same<HloInstruction*, HloInstructionPtr>::value || 63 std::is_same<const HloInstruction*, HloInstructionPtr>::value, 64 "Template argument expected to be HloInstruction* or const " 65 "HloInstruction*"); 66 67 public: DfsHloVisitorBase()68 DfsHloVisitorBase() {} ~DfsHloVisitorBase()69 virtual ~DfsHloVisitorBase() {} 70 71 // These routines are self-descriptive, see class comment for usage 72 // information. 73 74 virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); 75 virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); 76 77 virtual Status HandleClamp(HloInstructionPtr hlo) = 0; 78 virtual Status HandleSelect(HloInstructionPtr hlo) = 0; 79 virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0; HandleMaximum(HloInstructionPtr hlo)80 virtual Status HandleMaximum(HloInstructionPtr hlo) { 81 return HandleElementwiseBinary(hlo); 82 } HandleMinimum(HloInstructionPtr hlo)83 virtual Status HandleMinimum(HloInstructionPtr hlo) { 84 return HandleElementwiseBinary(hlo); 85 } 86 virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; HandleConvert(HloInstructionPtr hlo)87 virtual Status HandleConvert(HloInstructionPtr hlo) { 88 return HandleElementwiseUnary(hlo); 89 } HandleBitcastConvert(HloInstructionPtr hlo)90 virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { 91 return HandleElementwiseUnary(hlo); 92 } HandleCopy(HloInstructionPtr hlo)93 virtual Status HandleCopy(HloInstructionPtr hlo) { 94 return HandleElementwiseUnary(hlo); 95 } HandleComplex(HloInstructionPtr hlo)96 virtual Status HandleComplex(HloInstructionPtr hlo) { 97 return HandleElementwiseBinary(hlo); 98 } HandleMultiply(HloInstructionPtr hlo)99 virtual Status HandleMultiply(HloInstructionPtr hlo) { 100 return HandleElementwiseBinary(hlo); 101 } 102 virtual Status HandleDot(HloInstructionPtr hlo) = 0; HandlePower(HloInstructionPtr hlo)103 virtual Status HandlePower(HloInstructionPtr hlo) { 104 return HandleElementwiseBinary(hlo); 105 } HandleSqrt(HloInstructionPtr hlo)106 virtual Status HandleSqrt(HloInstructionPtr hlo) { 107 return HandleElementwiseUnary(hlo); 108 } HandleRsqrt(HloInstructionPtr hlo)109 virtual Status HandleRsqrt(HloInstructionPtr hlo) { 110 return HandleElementwiseUnary(hlo); 111 } HandleCbrt(HloInstructionPtr hlo)112 virtual Status HandleCbrt(HloInstructionPtr hlo) { 113 return HandleElementwiseUnary(hlo); 114 } 115 virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; 116 virtual Status HandleFft(HloInstructionPtr fft) = 0; 117 virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; 118 virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; 119 virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; 120 virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0; 121 virtual Status HandleAllGatherDone(HloInstructionPtr hlo) = 0; 122 virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; 123 virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0; 124 virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0; 125 virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0; 126 virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; 127 virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; 128 virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; 129 virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; 130 virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; 131 virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; 132 virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; 133 virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0; HandleCompare(HloInstructionPtr hlo)134 virtual Status HandleCompare(HloInstructionPtr hlo) { 135 return HandleElementwiseBinary(hlo); 136 } HandleAdd(HloInstructionPtr hlo)137 virtual Status HandleAdd(HloInstructionPtr hlo) { 138 return HandleElementwiseBinary(hlo); 139 } HandleDivide(HloInstructionPtr hlo)140 virtual Status HandleDivide(HloInstructionPtr hlo) { 141 return HandleElementwiseBinary(hlo); 142 } HandleRemainder(HloInstructionPtr hlo)143 virtual Status HandleRemainder(HloInstructionPtr hlo) { 144 return HandleElementwiseBinary(hlo); 145 } HandleSubtract(HloInstructionPtr hlo)146 virtual Status HandleSubtract(HloInstructionPtr hlo) { 147 return HandleElementwiseBinary(hlo); 148 } HandleAbs(HloInstructionPtr hlo)149 virtual Status HandleAbs(HloInstructionPtr hlo) { 150 return HandleElementwiseUnary(hlo); 151 } HandleAtan2(HloInstructionPtr hlo)152 virtual Status HandleAtan2(HloInstructionPtr hlo) { 153 return HandleElementwiseBinary(hlo); 154 } HandleRound(HloInstructionPtr hlo)155 virtual Status HandleRound(HloInstructionPtr hlo) { 156 return HandleElementwiseUnary(hlo); 157 } HandleLogistic(HloInstructionPtr hlo)158 virtual Status HandleLogistic(HloInstructionPtr hlo) { 159 return HandleElementwiseUnary(hlo); 160 } HandleSign(HloInstructionPtr hlo)161 virtual Status HandleSign(HloInstructionPtr hlo) { 162 return HandleElementwiseUnary(hlo); 163 } HandleNegate(HloInstructionPtr hlo)164 virtual Status HandleNegate(HloInstructionPtr hlo) { 165 return HandleElementwiseUnary(hlo); 166 } HandleExp(HloInstructionPtr hlo)167 virtual Status HandleExp(HloInstructionPtr hlo) { 168 return HandleElementwiseUnary(hlo); 169 } HandleExpm1(HloInstructionPtr hlo)170 virtual Status HandleExpm1(HloInstructionPtr hlo) { 171 return HandleElementwiseUnary(hlo); 172 } HandleFloor(HloInstructionPtr hlo)173 virtual Status HandleFloor(HloInstructionPtr hlo) { 174 return HandleElementwiseUnary(hlo); 175 } HandleCeil(HloInstructionPtr hlo)176 virtual Status HandleCeil(HloInstructionPtr hlo) { 177 return HandleElementwiseUnary(hlo); 178 } HandleLog(HloInstructionPtr hlo)179 virtual Status HandleLog(HloInstructionPtr hlo) { 180 return HandleElementwiseUnary(hlo); 181 } HandleClz(HloInstructionPtr hlo)182 virtual Status HandleClz(HloInstructionPtr hlo) { 183 return HandleElementwiseUnary(hlo); 184 } HandleLog1p(HloInstructionPtr hlo)185 virtual Status HandleLog1p(HloInstructionPtr hlo) { 186 return HandleElementwiseUnary(hlo); 187 } HandleCos(HloInstructionPtr hlo)188 virtual Status HandleCos(HloInstructionPtr hlo) { 189 return HandleElementwiseUnary(hlo); 190 } HandleSin(HloInstructionPtr hlo)191 virtual Status HandleSin(HloInstructionPtr hlo) { 192 return HandleElementwiseUnary(hlo); 193 } HandleTanh(HloInstructionPtr hlo)194 virtual Status HandleTanh(HloInstructionPtr hlo) { 195 return HandleElementwiseUnary(hlo); 196 } HandleReal(HloInstructionPtr hlo)197 virtual Status HandleReal(HloInstructionPtr hlo) { 198 return HandleElementwiseUnary(hlo); 199 } HandleImag(HloInstructionPtr hlo)200 virtual Status HandleImag(HloInstructionPtr hlo) { 201 return HandleElementwiseUnary(hlo); 202 } HandleIsFinite(HloInstructionPtr hlo)203 virtual Status HandleIsFinite(HloInstructionPtr hlo) { 204 return HandleElementwiseUnary(hlo); 205 } HandleAnd(HloInstructionPtr hlo)206 virtual Status HandleAnd(HloInstructionPtr hlo) { 207 return HandleElementwiseBinary(hlo); 208 } HandleNot(HloInstructionPtr hlo)209 virtual Status HandleNot(HloInstructionPtr hlo) { 210 return HandleElementwiseUnary(hlo); 211 } HandleOr(HloInstructionPtr hlo)212 virtual Status HandleOr(HloInstructionPtr hlo) { 213 return HandleElementwiseBinary(hlo); 214 } HandleXor(HloInstructionPtr hlo)215 virtual Status HandleXor(HloInstructionPtr hlo) { 216 return HandleElementwiseBinary(hlo); 217 } HandlePopulationCount(HloInstructionPtr hlo)218 virtual Status HandlePopulationCount(HloInstructionPtr hlo) { 219 return HandleElementwiseUnary(hlo); 220 } HandleShiftLeft(HloInstructionPtr hlo)221 virtual Status HandleShiftLeft(HloInstructionPtr hlo) { 222 return HandleElementwiseBinary(hlo); 223 } HandleShiftRightArithmetic(HloInstructionPtr hlo)224 virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { 225 return HandleElementwiseBinary(hlo); 226 } HandleShiftRightLogical(HloInstructionPtr hlo)227 virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { 228 return HandleElementwiseBinary(hlo); 229 } 230 HandleReducePrecision(HloInstructionPtr hlo)231 virtual Status HandleReducePrecision(HloInstructionPtr hlo) { 232 return HandleElementwiseUnary(hlo); 233 } 234 HandleDomain(HloInstructionPtr hlo)235 virtual Status HandleDomain(HloInstructionPtr hlo) { 236 return HandleElementwiseUnary(hlo); 237 } 238 239 virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; 240 virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; 241 virtual Status HandleRng(HloInstructionPtr hlo) = 0; 242 virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0; 243 virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0; 244 virtual Status HandleReverse(HloInstructionPtr hlo) = 0; 245 virtual Status HandleSort(HloInstructionPtr hlo) = 0; 246 virtual Status HandleConstant(HloInstructionPtr hlo) = 0; 247 virtual Status HandleIota(HloInstructionPtr hlo) = 0; 248 virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; 249 virtual Status HandleReduce(HloInstructionPtr hlo) = 0; 250 virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; 251 virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; 252 virtual Status HandleReshape(HloInstructionPtr hlo) = 0; 253 virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; 254 virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; 255 virtual Status HandleParameter(HloInstructionPtr hlo) = 0; 256 virtual Status HandleFusion(HloInstructionPtr hlo) = 0; 257 virtual Status HandleCall(HloInstructionPtr hlo) = 0; 258 virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; 259 virtual Status HandleSlice(HloInstructionPtr hlo) = 0; 260 virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; 261 virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; 262 virtual Status HandleTuple(HloInstructionPtr hlo) = 0; 263 virtual Status HandleMap(HloInstructionPtr hlo) = 0; 264 virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; 265 virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; 266 virtual Status HandleWhile(HloInstructionPtr hlo) = 0; 267 virtual Status HandleConditional(HloInstructionPtr hlo) = 0; 268 virtual Status HandleGather(HloInstructionPtr hlo) = 0; 269 virtual Status HandleScatter(HloInstructionPtr hlo) = 0; 270 271 virtual Status HandlePad(HloInstructionPtr hlo) = 0; 272 273 virtual Status HandleCopyStart(HloInstructionPtr copy_start) = 0; 274 virtual Status HandleCopyDone(HloInstructionPtr copy_done) = 0; 275 276 virtual Status HandleSend(HloInstructionPtr send) = 0; 277 virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; 278 279 virtual Status HandleRecv(HloInstructionPtr recv) = 0; 280 virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; 281 282 virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; 283 284 virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; 285 286 virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; 287 288 virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; 289 virtual Status HandleAfterAll(HloInstructionPtr token) = 0; 290 291 // Invoked to inform the visitor that the traversal has completed, and that 292 // the root was "root". 293 virtual Status FinishVisit(HloInstructionPtr root) = 0; 294 295 // 3 possible visitation states of HLO instructions. Each instruction's 296 // state only flows one way: kNotVisited -> kVisiting -> kVisited. 297 enum VisitState { 298 kNotVisited = 0, 299 kVisiting = 1, 300 kVisited = 2, 301 }; 302 GetVisitState(int id)303 VisitState GetVisitState(int id) { 304 auto iter = visit_state_.find(id); 305 if (iter == visit_state_.end()) { 306 return VisitState::kNotVisited; 307 } 308 return iter->second; 309 } 310 VisitState GetVisitState(const HloInstruction& instruction); 311 312 // Resize internal state if necessary to hold state for ids <= num. 313 // This call is purely a performance hint and can be omitted without 314 // affecting correctness. ReserveVisitStates(int num)315 void ReserveVisitStates(int num) { visit_state_.reserve(num); } VisitStateCapacity()316 size_t VisitStateCapacity() const { return visit_state_.capacity(); } 317 318 // Useful when we want to visit the same computation more than once with the 319 // same visitor. ResetVisitStates()320 void ResetVisitStates() { 321 // Clear the map, but don't resize the capacity across uses -- Calculating 322 // and reserving space could be expensive, and we always use the same 323 // module->instruction_count() as the capacity. 324 visit_state_.erase(visit_state_.begin(), visit_state_.end()); 325 } 326 327 // Useful when we want to free up the memory used by the visit state without 328 // destroying the actual visitor subclass. DestroyVisitState()329 void DestroyVisitState() { 330 visit_state_ = absl::flat_hash_map<int, VisitState>{}; 331 } 332 SetVisitState(int id,VisitState state)333 void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } 334 335 // Sets the visitation state of the given instruction as kVisiting. 336 // 337 // Precondition: current state must be kNotVisited. 338 void SetVisiting(const HloInstruction& instruction); 339 340 // Sets the visitation state of the given instruction as kVisited. 341 // 342 // Precondition: current state must be either kNotVisited or kVisiting. 343 void SetVisited(const HloInstruction& instruction); 344 345 // Returns whether the state of the given instruction is kVisiting. IsVisiting(const HloInstruction & instruction)346 bool IsVisiting(const HloInstruction& instruction) { 347 return GetVisitState(instruction) == kVisiting; 348 } 349 350 // Returns whether the state of the given instruction is kVisited. DidVisit(const HloInstruction & instruction)351 bool DidVisit(const HloInstruction& instruction) { 352 return GetVisitState(instruction) == kVisited; 353 } 354 355 // Returns whether the state of the given instruction is kNotVisited. NotVisited(const HloInstruction & instruction)356 bool NotVisited(const HloInstruction& instruction) { 357 return GetVisitState(instruction) == kNotVisited; 358 } 359 360 // This method should be overridden by subclasses that wish to run some 361 // operation on an op before its Handle* visitor method is called. 362 // 363 // For any HLO op, the order of calls is: 364 // 365 // Preprocess(op); 366 // Handle/OpType/(op); 367 // Postprocess(op); 368 // 369 // Overriding methods should call DfsHloVisitor::Preprocess before doing their 370 // own preprocessing. 371 virtual Status Preprocess(HloInstructionPtr hlo); 372 373 // This method should be overridden by subclasses that wish to run some 374 // operation on an op after its Handle* visitor method is called. See 375 // Preprocess for more details. 376 // 377 // Overriding methods should call DfsHloVisitor::Postprocess after doing their 378 // own postprocessing. 379 virtual Status Postprocess(HloInstructionPtr hlo); 380 381 private: 382 absl::flat_hash_map<int, VisitState> visit_state_; 383 384 TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); 385 }; 386 387 // Explicit instantiations in dfs_hlo_visitor.cc. 388 extern template class DfsHloVisitorBase<HloInstruction*>; 389 extern template class DfsHloVisitorBase<const HloInstruction*>; 390 391 // Users should use one of these two type aliases, which are the only two valid 392 // instantiations of DfsHloVisitorBase. 393 using DfsHloVisitor = DfsHloVisitorBase<HloInstruction*>; 394 using ConstDfsHloVisitor = DfsHloVisitorBase<const HloInstruction*>; 395 396 } // namespace xla 397 398 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ 399