• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
17 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
18 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
19 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
20 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
22 
23 namespace xla {
24 namespace llvm_ir {
25 
CanUpdateDynamicSliceInPlace(HloInstruction * dynamic_update_slice,const BufferAssignment & assignment)26 bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
27                                   const BufferAssignment& assignment) {
28   CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
29   const HloInstruction* operand = dynamic_update_slice->operand(0);
30   return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
31          assignment.HasTopLevelAllocation(operand) &&
32          assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
33 }
34 
35 // Shared implementation of EmitDynamicUpdateSliceInPlace and
36 // EmitFusedDynamicUpdateSliceInPlace.
37 //
38 // Emits a sequential loop if launch_dimensions is null.
39 using IndexGenerator = std::function<StatusOr<llvm::Value*>(int64)>;
40 
EmitDynamicUpdateSliceInPlaceImpl(const Shape & update_shape,const IndexGenerator & start_indices_generator,bool is_signed,ElementGenerator update_array_generator,const IrArray & output_array,const gpu::LaunchDimensions * launch_dimensions,absl::string_view name,llvm::IRBuilder<> * b)41 static Status EmitDynamicUpdateSliceInPlaceImpl(
42     const Shape& update_shape, const IndexGenerator& start_indices_generator,
43     bool is_signed, ElementGenerator update_array_generator,
44     const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions,
45     absl::string_view name, llvm::IRBuilder<>* b) {
46   const Shape& output_shape = output_array.GetShape();
47 
48   // Read start indices from start_indices_generator.
49   const int64 rank = output_shape.rank();
50   std::vector<llvm::Value*> start_multi_index(rank);
51   for (int64 i = 0; i < rank; ++i) {
52     TF_ASSIGN_OR_RETURN(start_multi_index[i], start_indices_generator(i));
53     llvm::Value* output_dim_size = llvm::ConstantInt::get(
54         start_multi_index[i]->getType(), output_shape.dimensions(i));
55     llvm::Value* update_dim_size = llvm::ConstantInt::get(
56         start_multi_index[i]->getType(), update_shape.dimensions(i));
57 
58     // Clamp the start index so that the update region fits in the operand.
59     // start_index = clamp(start_index, 0, output_dim_size - update_dim_size)
60     llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size);
61     llvm::Value* zero =
62         llvm::ConstantInt::get(start_multi_index[i]->getType(), 0);
63     start_multi_index[i] =
64         b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
65                                                 : llvm::ICmpInst::ICMP_UGE,
66                                       zero, start_multi_index[i]),
67                         zero, start_multi_index[i]);
68 
69     start_multi_index[i] =
70         b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
71                                                 : llvm::ICmpInst::ICMP_ULE,
72                                       max_bound, start_multi_index[i]),
73                         max_bound, start_multi_index[i]);
74   }
75 
76   auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
77     // Calculate output_index, where we'll write the value from update.  For
78     // each dimension,
79     //
80     //   output_index[dim] = start_index[dim] + update_index[dim]
81     //
82     std::vector<llvm::Value*> output_multi_index(rank);
83     for (int64 i = 0; i < rank; ++i) {
84       llvm::Value* start_index0 = b->CreateSExtOrBitCast(
85           start_multi_index[i], update_index[i]->getType());
86       output_multi_index[i] = b->CreateAdd(start_index0, update_index[i]);
87     }
88 
89     // Do output[output_index] = update[update_index].
90     IrArray::Index output_index(output_multi_index, output_shape,
91                                 b->getInt64Ty());
92     TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
93                         update_array_generator(update_index));
94     output_array.EmitWriteArrayElement(output_index, update_data, b);
95     return Status::OK();
96   };
97 
98   if (launch_dimensions != nullptr) {
99     return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
100                                     *launch_dimensions, b)
101         .EmitLoop(name);
102   }
103   return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name);
104 }
105 
EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,const IrArray & output_array,absl::string_view name,llvm::IRBuilder<> * b)106 Status EmitDynamicUpdateSliceInPlace(absl::Span<const IrArray> operand_arrays,
107                                      const IrArray& output_array,
108                                      absl::string_view name,
109                                      llvm::IRBuilder<>* b) {
110   VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;
111 
112   // No need to use operand_arrays[0], the input array of the
113   // dynamic-update-slice, because we know it aliases the op's output.
114   IrArray update_array = operand_arrays[1];
115   IrArray start_indices_array = operand_arrays[2];
116   Shape output_shape = output_array.GetShape();
117   Shape update_shape = update_array.GetShape();
118 
119   IndexGenerator start_indices_generator = [&](int64 index) {
120     return operand_arrays[2 + index].EmitReadArrayElement(
121         IrArray::Index(b->getInt64Ty()), b);
122   };
123   ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
124     return update_array.EmitReadArrayElement(index, b);
125   };
126 
127   bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape());
128   return EmitDynamicUpdateSliceInPlaceImpl(
129       update_shape, start_indices_generator, is_signed, update_array_generator,
130       output_array, /*launch_dimensions=*/nullptr, name, b);
131 }
132 
133 // Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
134 // EmitParallelFusedDynamicUpdateSliceInPlace.
135 //
136 // Emits a sequential loop if launch_dimensions is null.
EmitFusedDynamicUpdateSliceInPlaceImpl(HloInstruction * fusion,GeneratorForOperandIrArrays operand_arrays_generator,const IrArray & fusion_output_array,ElementalIrEmitter * elemental_emitter,const gpu::LaunchDimensions * launch_dimensions,llvm::IRBuilder<> * b)137 static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
138     HloInstruction* fusion,
139     GeneratorForOperandIrArrays operand_arrays_generator,
140     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
141     const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) {
142   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
143   VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
144           << fusion->ToShortString();
145 
146   auto* dynamic_update_slice = fusion->fused_expression_root();
147 
148   const auto* update = dynamic_update_slice->operand(1);
149   const auto* start_indices = dynamic_update_slice->operand(2);
150   Shape update_shape = update->shape();
151 
152   // Our in-place dynamic-update-slice implementation emits a loop over
153   // update_shape.  To emit a cache-friendly loop, we need to know that shape's
154   // layout.
155   //
156   // update_shape is inside a fusion node -- it's never materialized in memory
157   // and thus doesn't have a layout.  In this case we use the layout of the
158   // fusion node for iteration, since that corresponds to the order in memory of
159   // the buffer we'll be writing to.
160   //
161   // (This isn't necessarily optimal; in some cases it might be faster to peek
162   // through the chain of ops that gives us the update operand and use the
163   // layout of its source buffer(s).  But this is no worse than we do with
164   // fusion elsewhere.)
165   TF_RETURN_IF_ERROR(
166       LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));
167 
168   // Create element generators for update and start_indices.
169   FusedIrEmitter fused_emitter(std::move(operand_arrays_generator),
170                                elemental_emitter);
171   TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
172   ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
173 
174   IndexGenerator start_indices_generator = [&](int64 index) {
175     ElementGenerator element_generator =
176         fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
177     return element_generator(IrArray::Index(b->getInt64Ty()));
178   };
179   bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
180   return EmitDynamicUpdateSliceInPlaceImpl(
181       update_shape, start_indices_generator, is_signed, update_array_generator,
182       fusion_output_array, launch_dimensions, IrName(fusion), b);
183 }
184 
EmitFusedDynamicUpdateSliceInPlace(HloInstruction * fusion,GeneratorForOperandIrArrays operand_arrays_generator,const IrArray & fusion_output_array,ElementalIrEmitter * elemental_emitter,llvm::IRBuilder<> * b)185 Status EmitFusedDynamicUpdateSliceInPlace(
186     HloInstruction* fusion,
187     GeneratorForOperandIrArrays operand_arrays_generator,
188     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
189     llvm::IRBuilder<>* b) {
190   return EmitFusedDynamicUpdateSliceInPlaceImpl(
191       fusion, std::move(operand_arrays_generator), fusion_output_array,
192       elemental_emitter,
193       /*launch_dimensions=*/nullptr, b);
194 }
195 
EmitParallelFusedDynamicUpdateSliceInPlace(HloInstruction * fusion,GeneratorForOperandIrArrays operand_arrays_generator,const IrArray & fusion_output_array,ElementalIrEmitter * elemental_emitter,const gpu::LaunchDimensions & launch_dimensions,llvm::IRBuilder<> * b)196 Status EmitParallelFusedDynamicUpdateSliceInPlace(
197     HloInstruction* fusion,
198     GeneratorForOperandIrArrays operand_arrays_generator,
199     const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
200     const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) {
201   return EmitFusedDynamicUpdateSliceInPlaceImpl(
202       fusion, std::move(operand_arrays_generator), fusion_output_array,
203       elemental_emitter, &launch_dimensions, b);
204 }
205 
206 }  // namespace llvm_ir
207 }  // namespace xla
208