1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ 18 19 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 20 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 21 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 24 25 // Utilities related to emitting LLVM IR for various HLO ops. 26 27 namespace xla { 28 namespace llvm_ir { 29 30 using GeneratorForOperandIrArrays = 31 std::function<std::vector<llvm_ir::IrArray>()>; 32 33 // Determines whether the given instruction might be implemented as an 34 // in-place dynamic-update-slice after we have a buffer assignment. 35 // 36 // If this returns false, then CanUpdateDynamicSliceInPlace and 37 // CanEmitFusedDynamicUpdateSliceInPlace will also return false. 38 // 39 // This is useful if you want to check whether an instruction might be an 40 // in-place DUS during an HLO pass, at which point you don't have a buffer 41 // assignment. 42 // 43 // Note that simplifications to the HLO graph might change this function from 44 // returning false to returning true. Specifically, simplifying the contents of 45 // fusion nodes might cause a false->true transition. In general this isn't a 46 // problem by the time you're calling this function, but beware. 47 bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr); 48 49 // Checks if we can emit code for the given DynamicUpdateSlice node that updates 50 // its input in place. Returns true if the dynamic-update-slice's 51 // array-to-be-updated and output share the same BufferAllocation::Slice. 52 // 53 // dynamic_update_slice must be a DynamicUpdateSlice op. 54 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, 55 const BufferAssignment& assignment); 56 57 // Checks if the given fusion node is amenable to being implemented by 58 // EmitFusedDynamicUpdateSliceInPlace. 59 bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion, 60 const BufferAssignment& assignment); 61 62 // Emits IR for running the given dynamic-update-slice op in-place -- that is, 63 // where the input and output buffers share the same slice, so we can simply 64 // modify the input/output buffer without touching any of the other elements. 65 Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays, 66 const IrArray& output_array, 67 absl::string_view name, 68 llvm::IRBuilder<>* b); 69 70 // Given a loop-fusion node whose root is a dynamic-update-slice op whose 71 // array-to-be-updated and output share the same buffer slice, emits 72 // (sequential) code for a fusion node that does the dynamic-update-slice in 73 // place. 74 Status EmitFusedDynamicUpdateSliceInPlace( 75 HloInstruction* fusion, 76 GeneratorForOperandIrArrays operand_arrays_generator, 77 const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, 78 llvm::IRBuilder<>* b); 79 80 // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with 81 // the given launch dimensions. 82 Status EmitParallelFusedDynamicUpdateSliceInPlace( 83 HloInstruction* fusion, 84 GeneratorForOperandIrArrays operand_arrays_generator, 85 const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, 86 const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); 87 88 } // namespace llvm_ir 89 } // namespace xla 90 91 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ 92