1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/inlined_vector.h" 21 #include "tensorflow/compiler/xla/service/call_inliner.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 24 namespace xla { 25 class WhileUtil { 26 public: 27 // Holds a return value from MakeInstructionsLiveIn. 28 struct MakeInstructionsLiveInResult { 29 // The new while operation that has the requested values live in. 30 HloInstruction* new_while_instr; 31 32 // The new tuple instruction that replaced the original while instruction 33 // with the same shape. 34 HloInstruction* replacement_instr; 35 36 // The i'th element of `while_body_live_in_values` is an instruction in the 37 // while body that holds the i'th *newly added* live in value at runtime. 38 std::vector<HloInstruction*> while_body_live_in_values; 39 40 // `while_body_instruction_map` maps instructions in the original while body 41 // to the corresponding instructions in the body for the newly created while 42 // operation. 43 CallInliner::InlinedInstructionMap while_body_instruction_map; 44 }; 45 46 // Replaces `while_instr` with a new while instruction that is equivalent to 47 // `while_instr` except that it has all of the HLO instructions in 48 // `instructions` as live-in, loop invariant values. These new live in values 49 // are represented as new elements appended to the parameter of the while 50 // loop, which must be of tuple shape. GetTupleElement instructions computing 51 // each new live in value is returned in the `while_body_live_in_values` 52 // vector. 53 // 54 // Deletes `while_instr` after replacing it. 55 // 56 // Preconditions: 57 // 58 // `while_instr` must have a tuple shaped state. 59 // 60 // Every instruction in `instructions` must be contained in the computation 61 // that contains `while_instr`. 62 static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn( 63 HloInstruction* while_instr, 64 absl::Span<HloInstruction* const> instructions); 65 66 using LoopStateTy = std::vector<HloInstruction*>; 67 using LoopBodyGeneratorTy = std::function<StatusOr<LoopStateTy>( 68 HloInstruction* /*induction_var*/, 69 const LoopStateTy& /*current_values*/)>; 70 71 // Creates a while loop in `computation` that runs for `trip_count` 72 // iterations. The structure of the while loop is as follows, in pseudocode: 73 // 74 // loop_state while_loop() { 75 // indvar = 0; 76 // loop_state = init_values 77 // while (indvar < trip_count) { 78 // loop_state = loop_body_generator(loop_state) 79 // indvar++; 80 // } 81 // return loop_state; 82 // } 83 static StatusOr<LoopStateTy> MakeCountedLoop( 84 HloComputation* computation, int32_t trip_count, 85 const LoopStateTy& init_values, 86 const LoopBodyGeneratorTy& loop_body_generator, 87 const OpMetadata& metadata); 88 89 struct OwningLoopStateTy { 90 std::vector<std::unique_ptr<HloInstruction>> instructions_to_add; 91 WhileUtil::LoopStateTy while_results; 92 }; 93 // As above but does not add the while loop or other instructions created 94 // around it in any particular computation. The caller can instead add it to a 95 // computation of their choosing. 96 static StatusOr<OwningLoopStateTy> MakeCountedLoop( 97 HloModule* module, int32_t trip_count, 98 const WhileUtil::LoopStateTy& init_values, 99 const WhileUtil::LoopBodyGeneratorTy& loop_body_generator, 100 const OpMetadata& metadata); 101 102 // Returns the GetTupleElement instructions in `while_body` that access 103 // elements in the parameter tuple that don't change across iterations. 104 // Assumes `while_body` is the body computation of the while loop in question. 105 static std::vector<HloInstruction*> GetInvariantGTEsForWhileBody( 106 const HloComputation& while_body); 107 108 // Returns a map of index to GetTupleElement instructions in 109 // `while_conditional` that access elements in the parameter tuple. Assumes 110 // `while_conditional` is the conditional computation of the while loop in 111 // question. 112 static absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>> 113 GetGTEsMapForWhileConditional(const HloComputation& while_conditional); 114 }; 115 } // namespace xla 116 117 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ 118