/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"

namespace xla {
namespace llvm_ir {

bool MayBeImplementedAsInPlaceDynamicUpdateSlice(const HloInstruction* instr) {
  // Today we can't emit a dynamic-update-slice if the DUS node is parallelized;
  // the emitter will not emit correct code.  It's possible to change this, but
  // then ParallelTaskAssigner would have to somehow know whether a node *will*
  // be emitted as an in-place DUS, and it can't, because it doesn't have a
  // buffer assignment when it runs.
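  // (The outer_dimension_partitions field checked below is how
  // ParallelTaskAssigner records its partitioning decision on an instruction.)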
  if (!instr->outer_dimension_partitions().empty()) {
    return false;
  }

  // Until we know the final buffer assignment, any unfused dynamic-update-slice
  // might be implementable as an in-place DUS.
  if (instr->opcode() == HloOpcode::kDynamicUpdateSlice) {
    return true;
  }

  // A fusion may be implementable as an in-place dynamic update slice if
  //  - it's a loop fusion,
  //  - dynamic-update-slice is the root of the fusion, and
  //  - operand 0 of the dynamic-update-slice is a parameter to the fusion
  //    (ignoring any get-tuple-element operations in the way).
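  //
  // For illustration, a fusion of roughly this form satisfies all three
  // conditions (schematic HLO, not taken from this file):
  //
  //   fused_computation {
  //     p0 = f32[10] parameter(0)
  //     p1 = f32[4] parameter(1)
  //     p2 = s32[] parameter(2)
  //     ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
  //   }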
  if (instr->IsLoopFusion()) {
    const HloInstruction* fused_root = instr->fused_expression_root();
    return fused_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
           fused_root->operand(0)->LatestNonGteAncestor()->opcode() ==
               HloOpcode::kParameter;
  }

  return false;
}

bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
                                  const BufferAssignment& assignment) {
  CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
  const HloInstruction* operand = dynamic_update_slice->operand(0);
  return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
         assignment.HasTopLevelAllocation(operand) &&
         assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
}

bool CanEmitFusedDynamicUpdateSliceInPlace(HloInstruction* fusion,
                                           const BufferAssignment& assignment) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  if (!MayBeImplementedAsInPlaceDynamicUpdateSlice(fusion)) {
    return false;
  }

  // Walk the dynamic-update-slice's operand(0) back through any
  // get-tuple-elements to the fusion parameter it comes from, find the fusion
  // operand feeding that parameter, and check whether it shares an allocation
  // with the fusion's output.
  HloInstruction* fused_root = fusion->fused_expression_root();
  HloInstruction* fusion_operand;
  ShapeIndex index;
  std::tie(fusion_operand, index) =
      fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex();
  // MayBeImplementedAsInPlaceDynamicUpdateSlice should have ensured that
  // fusion_operand is a parameter.
  CHECK_EQ(fusion_operand->opcode(), HloOpcode::kParameter);
  auto* operand = fusion->operand(fusion_operand->parameter_number());
  return assignment.HasAllocationAt(operand, index) &&
         assignment.HasAllocationAt(fusion, {}) &&
         assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}

// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
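//
// An IndexGenerator maps a dimension number i to the scalar start index of
// the update along that dimension.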
using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;

static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const IndexGenerator& start_indices_generator,
    bool is_signed, ElementGenerator update_array_generator,
    const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
    absl::string_view name, llvm::IRBuilder<>* b) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = output_shape.rank();
  std::vector<llvm::Value*> start_multi_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
    llvm::Value* output_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), output_shape.dimensions(i));
    llvm::Value* update_dim_size = llvm::ConstantInt::get(
        start_multi_index[i]->getType(), update_shape.dimensions(i));

    // Clamp the start index so that the update region fits in the operand.
    // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
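    //
    // For example (illustrative numbers): with output_dim_size = 10 and
    // update_dim_size = 4, a start index of 8 is clamped down to 6, and a
    // (signed) start index of -2 is clamped up to 0, so the update region
    // always stays within the output buffer.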
    llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
    llvm::Value* zero =
        llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
                                                : llvm::ICmpInst::ICMP_UGE,
                                      zero, start_multi_index[i]),
                        zero, start_multi_index[i]);

    start_multi_index[i] =
        b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
                                                : llvm::ICmpInst::ICMP_ULE,
                                      max_bound, start_multi_index[i]),
                        max_bound, start_multi_index[i]);
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = start_index[dim] + update_index[dim]
    //
    std::vector<llvm::Value*> output_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
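      // The start index may have a narrower integer type than the loop
      // induction variable, so extend (or bitcast) it before adding.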
      llvm::Value* start_index0 = b->CreateSExtOrBitCast(
          start_multi_index[i], update_index[i]->getType());
      output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
    }

    // Do output[output_index] = update[update_index].
    IrArray::Index output_index(output_multi_index, output_shape,
                                b->getInt64Ty());
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, b);
    return Status::OK();
  };

  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, b)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
}

Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
                                     const IrArray& output_array,
                                     absl::string_view name,
                                     llvm::IRBuilder<>* b) {
  VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;

  // No need to use operand_arrays[0], the input array of the
  // dynamic-update-slice, because we know it aliases the op's output.
  IrArray update_array = operand_arrays[1];
  IrArray start_indices_array = operand_arrays[2];
  Shape output_shape = output_array.GetShape();
  Shape update_shape = update_array.GetShape();

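  // Each start index is a scalar: operand 2 + i of the dynamic-update-slice
  // holds the start index for dimension i, hence the empty (scalar) read
  // index below.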
  IndexGenerator start_indices_generator = [&](int64 index) {
    return operand_arrays[2 + index].EmitReadArrayElement(
        IrArray::Index(b->getInt64Ty()), b);
  };
  ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    return update_array.EmitReadArrayElement(index, b);
  };

  bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      output_array, /*launch_dimensions=*/nullptr, name, b);
}

// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
// EmitParallelFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
          << fusion->ToShortString();

  auto* dynamic_update_slice = fusion->fused_expression_root();

  const auto* update = dynamic_update_slice->operand(1);
  const auto* start_indices = dynamic_update_slice->operand(2);
  Shape update_shape = update->shape();

  // Our in-place dynamic-update-slice implementation emits a loop over
  // update_shape.  To emit a cache-friendly loop, we need to know that shape's
  // layout.
  //
  // update_shape belongs to an instruction inside the fusion node, so it's
  // never materialized in memory and thus doesn't have a layout.  Instead we
  // use the layout of the fusion node itself for iteration, since that
  // corresponds to the order in memory of the buffer we'll be writing to.
  //
  // (This isn't necessarily optimal; in some cases it might be faster to peek
  // through the chain of ops that gives us the update operand and use the
  // layout of its source buffer(s).  But this is no worse than what we do with
  // fusion elsewhere.)
  TF_RETURN_IF_ERROR(
      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));

  // Create element generators for update and start_indices.
  FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
                               elemental_emitter);
  TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
  ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);

  IndexGenerator start_indices_generator = [&](int64 index) {
    ElementGenerator element_generator =
        fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
    return element_generator(IrArray::Index(b->getInt64Ty()));
  };
  bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, is_signed, update_array_generator,
      fusion_output_array, launch_dimensions, IrName(fusion), b);
}

Status EmitFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter,
      /*launch_dimensions=*/nullptr, b);
}

Status EmitParallelFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    GeneratorForOperandIrArrays operand_arrays_generator,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, std::move(operand_arrays_generator), fusion_output_array,
      elemental_emitter, &launch_dimensions, b);
}

}  // namespace llvm_ir
}  // namespace xla