/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"

#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"

namespace xla {
namespace llvm_ir {

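// Returns true if the dynamic-update-slice can be emitted in place, i.e. the
// buffer assignment gives the op and its first operand the same top-level
// buffer slice.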
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
                                  const BufferAssignment& assignment) {
  CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
  const HloInstruction* operand = dynamic_update_slice->operand(0);
  return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
         assignment.HasTopLevelAllocation(operand) &&
         assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
}

using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;

// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const IndexGenerator& start_indices_generator,
    bool is_signed, ElementGenerator update_array_generator,
    const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
    absl::string_view name, llvm::IRBuilder<>* b) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = output_shape.rank();
  std::vector<llvm::Value*> start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
    llvm::Value* output_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), output_shape.dimensions(i));
    llvm::Value* update_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), update_shape.dimensions(i));

    // Clamp the start index so that the update region fits in the operand.
    //   start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
    llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
    llvm::Value* zero =
        llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
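    // start_index = max(start_index, 0)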
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                                : llvm::ICmpInst::ICMP_UGE,
                                      zero, start_multi_index[i]),
                        zero, start_multi_index[i]);

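    // start_index = min(start_index, output_dim_size - update_dim_size)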
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                                : llvm::ICmpInst::ICMP_ULE,
                                      max_bound, start_multi_index[i]),
                        max_bound, start_multi_index[i]);
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = start_index[dim] + update_index[dim]
    //
    std::vector<llvm::Value*> output_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      llvm::Value* start_index0 = b->CreateSExtOrBitCast(
          start_multi_index[i], update_index[i]->getType());
      output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
    }

    // Do output[output_index] = update[update_index].
    IrArray::Index output_index(output_multi_index, output_shape,
                                b->getInt64Ty());
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, b);
    return Status::OK();
  };

  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, b)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}

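// Operand layout for the unfused case: operand_arrays[0] is the input array
// (which aliases output_array), operand_arrays[1] is the update, and
// operand_arrays[2 + dim] is the scalar start index for dimension `dim`.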
Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
                                     const IrArray& output_array,
                                     absl::string_view name,
                                     llvm::IRBuilder<>* b) {
  VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;

  // No need to use operand_arrays[0], the input array of the
  // dynamic-update-slice, because we know it aliases the op's output.
  IrArray update_array = operand_arrays[1];
  IrArray start_indices_array = operand_arrays[2];
  Shape output_shape = output_array.GetShape();
  Shape update_shape = update_array.GetShape();

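  // Each start index is a scalar held in operand 2 + dim, so it is read with
  // an empty (rank-0) index.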
  IndexGenerator start_indices_generator = [&](int64 index) {
    return operand_arrays[2 + index].EmitReadArrayElement(
        IrArray::Index(b->getInt64Ty()), b);
  };
  ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    return update_array.EmitReadArrayElement(index, b);
  };

  bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      output_array, /*launch_dimensions=*/nullptr, name, b);
}

// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
// EmitParallelFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
          << fusion->ToShortString();

  auto* dynamic_update_slice = fusion->fused_expression_root();

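  // The fusion root is the dynamic-update-slice; grab its update operand and
  // its first start-index operand (the latter only to determine signedness).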
  const auto* update = dynamic_update_slice->operand(1);
  const auto* start_indices = dynamic_update_slice->operand(2);
  Shape update_shape = update->shape();

  // Our in-place dynamic-update-slice implementation emits a loop over
  // update_shape.  To emit a cache-friendly loop, we need to know that shape's
  // layout.
  //
  // update_shape is inside a fusion node -- it's never materialized in memory
  // and thus doesn't have a layout.  In this case we use the layout of the
  // fusion node for iteration, since that corresponds to the order in memory
  // of the buffer we'll be writing to.
  //
  // (This isn't necessarily optimal; in some cases it might be faster to peek
  // through the chain of ops that gives us the update operand and use the
  // layout of its source buffer(s).  But this is no worse than we do with
  // fusion elsewhere.)
  TF_RETURN_IF_ERROR(
      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));

  // Create element generators for update and start_indices.
  FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
                               elemental_emitter);
  TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
  ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);

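  // As in the unfused path, each start index is a scalar generated from
  // operand 2 + dim of the fused dynamic-update-slice.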
  IndexGenerator start_indices_generator = [&](int64 index) {
    ElementGenerator element_generator =
        fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
    return element_generator(IrArray::Index(b->getInt64Ty()));
  };
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      fusion_output_array, launch_dimensions, IrName(fusion), b);
}

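// Emits an in-place dynamic-update-slice for a fusion node, using a
// sequential loop.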
Status EmitFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter,
      /*launch_dimensions=*/nullptr, b);
}

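// Emits an in-place dynamic-update-slice for a fusion node, parallelized
// across the given GPU launch dimensions.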
Status EmitParallelFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter, &launch_dimensions, b);
}

}  // namespace llvm_ir
}  // namespace xla