1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
17
18 #include <stddef.h>
19 #include <stdint.h>
20
21 #include <algorithm>
22 #include <iterator>
23 #include <limits>
24 #include <memory>
25 #include <utility>
26 #include <vector>
27
28 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "llvm/CodeGen/TargetRegisterInfo.h"
36 #include "llvm/CodeGen/TargetSubtargetInfo.h"
37 #include "llvm/IR/BasicBlock.h"
38 #include "llvm/IR/Constants.h"
39 #include "llvm/IR/GlobalVariable.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/Intrinsics.h"
42 #include "llvm/IR/IntrinsicsX86.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/Value.h"
45 #include "tensorflow/compiler/xla/layout_util.h"
46 #include "tensorflow/compiler/xla/map_util.h"
47 #include "tensorflow/compiler/xla/primitive_util.h"
48 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
49 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
50 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
51 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
52 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
53 #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
54 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
55 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
56 #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
57 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
58 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
59 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
60 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
61 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
62 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
63 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
64 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
65 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
66 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
67 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
68 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
69 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
70 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
71 #include "tensorflow/compiler/xla/shape_util.h"
72 #include "tensorflow/compiler/xla/status_macros.h"
73 #include "tensorflow/compiler/xla/types.h"
74 #include "tensorflow/compiler/xla/util.h"
75 #include "tensorflow/compiler/xla/window_util.h"
76 #include "tensorflow/compiler/xla/xla_data.pb.h"
77 #include "tensorflow/core/lib/core/bits.h"
78 #include "tensorflow/core/lib/core/errors.h"
79 #include "tensorflow/core/lib/math/math_util.h"
80 #include "tensorflow/core/platform/logging.h"
81
82 namespace xla {
83
84 namespace {
85 using llvm_ir::IrName;
86 using llvm_ir::SetToFirstInsertPoint;
87 } // namespace
88
89 namespace cpu {
90
91 IrEmitter::IrEmitter(
92 mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
93 const BufferAssignment& assignment, llvm::Module* llvm_module,
94 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
95 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
96 const TargetMachineFeatures* target_machine_features,
97 bool emit_code_for_msan)
98 : assignment_(assignment),
99 module_(llvm_module),
100 arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
101 b_(llvm_module->getContext()),
102 mlir_context_(mlir_context),
103 instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
104 computation_to_profile_idx_(std::move(computation_to_profile_idx)),
105 alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
106 hlo_module_config_(hlo_module.config()),
107 is_top_level_computation_(false),
108 target_machine_features_(*target_machine_features),
109 emit_code_for_msan_(emit_code_for_msan) {
110 b_.setFastMathFlags(llvm_ir::GetCpuFastMathFlags(hlo_module_config_));
111 Status s = GatherComputationsByAllocationType(
112 &hlo_module, &thread_local_computations_, &global_computations_);
113 absl::c_sort(thread_local_computations_);
114 absl::c_sort(global_computations_);
115 TF_CHECK_OK(s) << "Should have failed buffer assignment.";
116 }
117
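// Copies the root value of a thread-local computation into the caller-provided
// result argument: a direct store for scalar results, or an element-by-element
// copy of the (non-nested) tuple elements for tuple results.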
118 void IrEmitter::EmitThreadLocalFunctionEpilogue(HloComputation* computation) {
119 llvm::Argument* out_parameter = compute_function_->result_arg();
120 llvm_ir::IrArray root_value = GetIrArrayFor(computation->root_instruction());
121 const Shape& return_shape = computation->root_instruction()->shape();
122
123 if (ShapeUtil::IsScalar(return_shape)) {
124 llvm::Value* ret_value =
125 Load(root_value.GetBasePointer(), "load_ret_value");
126 Store(ret_value,
127 BitCast(out_parameter, root_value.GetBasePointer()->getType()));
128 } else {
129 CHECK(return_shape.IsTuple());
130
131 llvm::Type* tuple_type = llvm_ir::ShapeToIrType(return_shape, module_);
132 llvm::Type* tuple_type_lvalue = tuple_type->getPointerTo();
133 llvm::Value* tuple_lvalue = BitCast(out_parameter, tuple_type_lvalue);
134
135 for (int i = 0; i < return_shape.tuple_shapes_size(); i++) {
136 const Shape& element_shape = return_shape.tuple_shapes(i);
137 llvm::Value* destination = llvm_ir::EmitGetTupleElement(
138 element_shape,
139 /*index=*/i,
140 /*alignment=*/MinimumAlignmentForShape(element_shape), tuple_lvalue,
141 &b_);
142
143 llvm::Value* source = llvm_ir::EmitGetTupleElement(
144 element_shape,
145 /*index=*/i,
146 /*alignment=*/MinimumAlignmentForShape(element_shape),
147 root_value.GetBasePointer(), &b_);
148
149 Store(Load(source), destination);
150 }
151 }
152 }
153
154 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
155 HloComputation* computation, const string& function_name_prefix,
156 bool is_top_level_computation,
157 absl::Span<HloInstruction* const> instruction_order) {
158 string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
159 VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
160 is_top_level_computation_ = is_top_level_computation;
161 num_dynamic_loop_bounds_ = 0;
162 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
163 num_dynamic_loop_bounds_ =
164 computation->root_instruction()->outer_dimension_partitions().size();
165 }
166
167 if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
168 TF_ASSIGN_OR_RETURN(
169 computation_root_allocation_,
170 assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
171 }
172
173 for (const HloInstruction* param : computation->parameter_instructions()) {
174 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
175 assignment_.GetUniqueTopLevelSlice(param));
176 computation_parameter_allocations_[param_slice.allocation()->index()] =
177 param->parameter_number();
178 }
179
180 InitializeIrFunction(function_name);
181 // The rdtscp instruction is x86-specific. We fall back to LLVM's generic
182 // readcyclecounter intrinsic if it is unavailable.
183 bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
184 arch_type_ == llvm::Triple::ArchType::x86_64;
185 profiling_state_ = ProfilingState(use_rdtscp);
186
187 tracing_state_.set_enabled(
188 computation->parent()->config().cpu_traceme_enabled());
189
190 TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order));
191 llvm::Function* ir_function = compute_function_->function();
192 InsertOrDie(&emitted_functions_, computation, ir_function);
193 // Deleting 'compute_function_' (below) finalizes 'ir_function' and restores
194 // the caller's IR insert point.
195
196 // Function epilogue: copy the result value either into the return register
197 // or into the memory that the return register points to.
198 const BufferAllocation* root_allocation =
199 computation_root_allocation_.allocation();
200 if (root_allocation && root_allocation->is_thread_local()) {
201 EmitThreadLocalFunctionEpilogue(computation);
202 }
203
204 // Destructor for compute_function_ emits the "ret void" instruction.
205 compute_function_.reset();
206 computation_root_allocation_ = BufferAllocation::Slice();
207 computation_parameter_allocations_.clear();
208 return ir_function;
209 }
210
211 void IrEmitter::InitializeIrFunction(const string& function_name) {
212 // Functions with local linkage get an inlining bonus. Because we know
213 // a priori that embedded (non-entry) functions will never have their names
214 // resolved externally, give them local linkage.
215 llvm::Function::LinkageTypes linkage =
216 is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
217 : llvm::GlobalValue::InternalLinkage;
218 // Create and initialize new IrFunction.
219 compute_function_.reset(new IrFunction(function_name, linkage,
220 hlo_module_config_, module_, &b_,
221 num_dynamic_loop_bounds_));
222 }
223
224 IrEmitter::~IrEmitter() {}
225
226 Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
227 VLOG(2) << "HandleBitcast: " << bitcast->ToString();
228 emitted_value_[bitcast] =
229 BitCast(GetEmittedValueFor(bitcast->operand(0)),
230 IrShapeType(bitcast->shape())->getPointerTo(), IrName(bitcast));
231 return Status::OK();
232 }
233
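// Emits a private constant global holding the literal's data, aligned to the
// minimum alignment for its shape, and returns it cast to a pointer to the
// literal's IR shape type.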
234 llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
235 llvm::Constant* initializer =
236 llvm_ir::ConvertLiteralToIrConstant(literal, module_);
237 llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
238 /*Module=*/*module_,
239 /*Type=*/initializer->getType(),
240 /*isConstant=*/true,
241 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
242 /*Initializer=*/initializer,
243 /*Name=*/"");
244 result_global->setAlignment(
245 llvm::Align(MinimumAlignmentForShape(literal.shape())));
246 result_global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
247 return llvm::ConstantExpr::getBitCast(
248 result_global, IrShapeType(literal.shape())->getPointerTo());
249 }
250
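// Emits one LLVM global per constant allocation in the buffer assignment,
// reusing an already-emitted global when several allocations share the same
// literal.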
251 Status IrEmitter::EmitConstantGlobals() {
252 for (const BufferAllocation& allocation : assignment_.Allocations()) {
253 if (!allocation.is_constant()) {
254 continue;
255 }
256
257 const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
258 llvm::Constant* global_for_const;
259 auto it = emitted_literals_.find(&literal);
260 if (it != emitted_literals_.end()) {
261 global_for_const = it->second;
262 } else {
263 global_for_const = EmitGlobalForLiteral(literal);
264 InsertOrDie(&emitted_literals_, &literal, global_for_const);
265 }
266
267 InsertOrDie(&constant_buffer_to_global_, allocation.index(),
268 global_for_const);
269 }
270
271 return Status::OK();
272 }
273
274 Status IrEmitter::HandleConstant(HloInstruction* constant) {
275 VLOG(2) << "HandleConstant: " << constant->ToString();
276 // IrEmitter::EmitConstantGlobals has already taken care of emitting the body
277 // of the constant.
278 return EmitTargetAddressForOp(constant);
279 }
280
281 Status IrEmitter::HandleCopy(HloInstruction* copy) {
282 if (copy->shape().IsTuple() ||
283 (copy->shape().IsArray() &&
284 LayoutUtil::Equal(copy->operand(0)->shape().layout(),
285 copy->shape().layout()))) {
286 // If the layouts are equal this is just a memcpy. kCopy shallow copies a
287 // tuple so just memcpy the top-level buffer for tuples.
288 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
289 return EmitMemcpy(*(copy->operand(0)), *copy);
290 } else if (copy->shape().IsArray()) {
291 // Use the elemental emitter for array shapes.
292 return DefaultAction(copy);
293 }
294 return Unimplemented("unsupported operand type %s for copy instruction",
295 PrimitiveType_Name(copy->shape().element_type()));
296 }
297
298 // Calculate the alignment of a buffer allocated for a given primitive type.
299 int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
300 int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
301 DCHECK_GE(byte_size, 0);
302 // Largest scalar is a complex128 so we don't need to worry about the
303 // int64->int truncation here.
304 DCHECK_LE(byte_size, 16);
305
306 // Allocations may be 8-byte aligned if part of a small block.
307 return std::min(int64{8}, byte_size);
308 }
309
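// Returns the size in bytes of a buffer holding the given shape, as computed
// from the LLVM module's data layout.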
310 int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
311 return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
312 }
313
314 // Calculate the alignment of a buffer allocated for a given shape.
315 int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
316 if (ShapeUtil::IsScalar(shape)) {
317 return MinimumAlignmentForPrimitiveType(shape.element_type());
318 }
319
320 int64 buffer_size = ByteSizeOf(shape);
321 DCHECK_GE(buffer_size, 0);
322 DCHECK_LE(buffer_size, SIZE_MAX);
323
324 return target_machine_features_.minimum_alignment_for_allocation(buffer_size);
325 }
326
327 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
328 const Shape& shape) {
329 int alignment = MinimumAlignmentForShape(shape);
330 if (alignment > 1) {
331 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
332 }
333 }
334
335 void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
336 int64 buffer_size) {
337 int alignment =
338 target_machine_features_.minimum_alignment_for_allocation(buffer_size);
339 if (alignment > 1) {
340 llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
341 }
342 }
343
344 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
345 const Shape& shape) {
346 AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
347 }
348
349 void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
350 int64 buffer_size) {
351 if (buffer_size > 0) {
352 llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
353 }
354 }
355
356 Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
357 // A tuple is an array of pointers, one for each operand. Each pointer points
358 // to the output buffer of its corresponding operand. A GetTupleElement
359 // instruction forwards a pointer to the tuple element buffer at the given
360 // index.
361 const HloInstruction* operand = get_tuple_element->operand(0);
362 const Shape& shape = get_tuple_element->shape();
363 emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
364 shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
365 GetEmittedValueFor(operand), &b_);
366 return Status::OK();
367 }
368
369 Status IrEmitter::HandleSelect(HloInstruction* select) {
370 auto pred = select->operand(0);
371 TF_RET_CHECK(pred->shape().element_type() == PRED);
372 return DefaultAction(select);
373 }
374
375 Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
376 auto pred = tuple_select->operand(0);
377 auto on_true = tuple_select->operand(1);
378 auto on_false = tuple_select->operand(2);
379 TF_RET_CHECK(pred->shape().element_type() == PRED);
380 TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
381 TF_RET_CHECK(tuple_select->shape().IsTuple());
382 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
383 llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
384 GetEmittedValueFor(on_true),
385 GetEmittedValueFor(on_false), &b_);
386 return Status::OK();
387 }
388
389 Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
390 HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
391 VLOG(2) << "HandleInfeed: " << infeed->ToString();
392
393 // The infeed operation produces a two-element tuple containing data and a
394 // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
395 const Shape& data_shape = infeed->infeed_shape();
396 DCHECK(ShapeUtil::Equal(data_shape,
397 ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
398 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
399
400 // Write the tuple index table.
401 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
402 assignment_.GetUniqueSlice(infeed, {0}));
403 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
404 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
405 assignment_.GetUniqueSlice(infeed, {1}));
406 llvm::Value* token_address = EmitBufferPointer(
407 token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
408 llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_);
409
410 if (data_shape.IsTuple()) {
411 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
412
413 // For a tuple, we first copy each of the internal elements to
414 // their corresponding target locations. We then construct the
415 // tuple outer buffer containing pointers to the internal
416 // elements.
417 std::vector<llvm::Value*> tuple_element_addresses;
418 for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
419 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
420 assignment_.GetUniqueSlice(infeed, {0, i}));
421
422 const Shape& tuple_element_shape =
423 ShapeUtil::GetTupleElementShape(data_shape, i);
424
425 // Only the outer tuple buffer's target address is obtained from
426 // GetEmittedValueFor, to handle the case when Infeed is the root
427 // instruction. Target addresses for internal elements can be obtained
428 // from EmitBufferPointer.
429 llvm::Value* tuple_element_address =
430 EmitBufferPointer(buffer, tuple_element_shape);
431
432 TF_RETURN_IF_ERROR(EmitXfeedTransfer(
433 XfeedKind::kInfeed, tuple_element_shape, tuple_element_address));
434
435 tuple_element_addresses.push_back(tuple_element_address);
436 }
437
438 llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
439 tuple_element_addresses, &b_);
440 } else {
441 TF_RETURN_IF_ERROR(
442 EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
443 }
444
445 return Status::OK();
446 }
447
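// Emits a single infeed or outfeed transfer: acquires a runtime buffer of
// exactly the expected size, copies in the appropriate direction, and then
// releases the buffer back to the runtime.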
448 Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
449 llvm::Value* program_buffer_address) {
450 int64 length = ByteSizeOf(shape);
451 if (length < 0 || length > std::numeric_limits<int32>::max()) {
452 return InvalidArgument(
453 "xfeed (infeed or outfeed) buffer length %d is outside the valid "
454 "size range",
455 length);
456 }
457 int32 length_32 = static_cast<int32>(length);
458
459 int32 shape_length;
460 TF_ASSIGN_OR_RETURN(
461 llvm::Value * shape_ptr,
462 llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
463
464 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
465
466 const char* acquire_func_name =
467 kind == XfeedKind::kInfeed
468 ? runtime::kAcquireInfeedBufferForDequeueSymbolName
469 : runtime::kAcquireOutfeedBufferForPopulationSymbolName;
470
471 // Implementation note: this call tells the runtime that the generated code
472 // expects a buffer of size exactly 'length_32'; the runtime is responsible
473 // for check-failing the process on a mismatch rather than handing back a
474 // buffer that we might overrun.
475 llvm::Value* acquired_pointer =
476 EmitCallToFunc(acquire_func_name,
477 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
478 shape_ptr, b_.getInt32(shape_length)},
479 i8_ptr_type);
480 if (kind == XfeedKind::kInfeed) {
481 // Copy to the program buffer address from the acquired buffer.
482 MemCpy(program_buffer_address, /*DstAlign=*/llvm::Align(1),
483 acquired_pointer,
484 /*SrcAlign=*/llvm::Align(1), length_32);
485 } else {
486 // Outfeed -- copy from the in-program address to the acquired buffer.
487 MemCpy(acquired_pointer, /*DstAlign=*/llvm::Align(1),
488 program_buffer_address,
489 /*SrcAlign=*/llvm::Align(1), length_32);
490 if (emit_code_for_msan_) {
491 // Mark the outfed data as initialized for msan. The buffer gets read by
492 // the host code, which might be msan-instrumented.
493 // TODO(b/66051036): Run the msan instrumentation pass instead.
494 const llvm::DataLayout& dl = module_->getDataLayout();
495 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
496 EmitCallToFunc(
497 "__msan_unpoison",
498 {acquired_pointer, llvm::ConstantInt::get(intptr_type, length)},
499 b_.getVoidTy());
500 }
501 }
502
503 const char* release_func_name =
504 kind == XfeedKind::kInfeed
505 ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
506 : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;
507 EmitCallToFunc(release_func_name,
508 {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
509 acquired_pointer, shape_ptr, b_.getInt32(shape_length)},
510 b_.getVoidTy());
511
512 return Status::OK();
513 }
514
515 Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
516 // Outfeed produces no useful result, but it does return a token that can be
517 // threaded through to other side-effecting operations to ensure ordering. In
518 // the IR emitter we treat this token as a normal u8[] and thus need to insert
519 // an entry for it in emitted_value_.
520 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed));
521
522 HloInstruction* operand = outfeed->operands()[0];
523 const Shape& operand_shape = operand->shape();
524
525 llvm::Value* value = GetEmittedValueFor(operand);
526 if (!operand_shape.IsTuple()) {
527 return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value);
528 }
529
530 TF_RET_CHECK(!ShapeUtil::IsNestedTuple(operand_shape));
531
532 for (int64 i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
533 const Shape& tuple_element_shape =
534 ShapeUtil::GetTupleElementShape(operand_shape, i);
535 llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
536 tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
537 value, &b_);
538 TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
539 tuple_element_shape, tuple_element));
540 }
541
542 return Status::OK();
543 }
544
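// Sort is lowered to a call into the CPU runtime's key/value sort routine.
// The sort runs in place on the output buffers, so operands are first copied
// to their destinations, and the comparator computation is passed as an
// emitted function pointer.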
545 Status IrEmitter::HandleSort(HloInstruction* hlo) {
546 const HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
547 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
548 Shape keys_shape = sort->keys()->shape();
549 PrimitiveType keys_type = keys_shape.element_type();
550 if (!primitive_util::IsArrayType(keys_type)) {
551 return Unimplemented("Element type %s not supported in the Sort op on CPU.",
552 PrimitiveType_Name(keys_type));
553 }
554 std::vector<llvm::Value*> destination_addresses(sort->operand_count());
555 for (int64 i = 0; i < sort->operand_count(); ++i) {
556 ShapeIndex shape_index =
557 sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({});
558 const HloInstruction* operand = sort->operand(i);
559 // We assume that the layout of all involved operands and outputs is the
560 // same.
561 TF_RET_CHECK(
562 LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape()));
563 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
564 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
565
566 // The sort is implemented in-place, therefore we first copy the operand
567 // buffer to the output buffer if they are not the same.
568 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
569 destination_addresses[i] =
570 EmitBufferPointer(destination_buffer, operand->shape());
571 auto source_address = GetAllocationSlice(*operand);
572 if (destination_buffer != source_address) {
573 int64 primitive_type_size =
574 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
575 auto source_buffer = GetEmittedValueFor(operand);
576 int64 size = ByteSizeOf(operand->shape());
577 MemCpy(destination_addresses[i],
578 /*DstAlign=*/llvm::Align(primitive_type_size), source_buffer,
579 /*SrcAlign=*/llvm::Align(primitive_type_size), size);
580 }
581 }
582
583 // Normalize the shape and the dimension to sort.
584 Shape normalized_keys_shape =
585 ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape);
586 int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
587 keys_shape.layout())[sort->sort_dimension()];
588
589 int64 sort_dimension_elements =
590 normalized_keys_shape.dimensions(physical_dimension_to_sort);
591 int64 higher_dimensions = 1;
592 for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
593 higher_dimensions *= normalized_keys_shape.dimensions(i);
594 }
595 int64 lower_dimensions = 1;
596 for (int64 i = normalized_keys_shape.rank() - 1;
597 i > physical_dimension_to_sort; --i) {
598 lower_dimensions *= normalized_keys_shape.dimensions(i);
599 }
600
601 CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply()));
602 llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
603 b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca",
604 &b_);
605 llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
606 b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca",
607 &b_);
608 for (int64 i = 0; i < sort->operand_count(); ++i) {
609 llvm::Value* value_as_i8ptr =
610 PointerCast(destination_addresses[i], b_.getInt8PtrTy());
611 llvm::Value* slot_in_values_alloca =
612 ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i);
613 Store(value_as_i8ptr, slot_in_values_alloca);
614 llvm::Value* slot_in_sizes_alloca =
615 ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i);
616 llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType(
617 sort->operand(i)->shape().element_type()));
618 Store(size, slot_in_sizes_alloca);
619 }
620
621 auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply());
622 EmitCallToFunc(
623 runtime::kKeyValueSortSymbolName,
624 {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
625 b_.getInt64(lower_dimensions), values,
626 b_.getInt32(sort->operand_count()), sizes, b_.getInt1(sort->is_stable()),
627 GetExecutableRunOptionsArgument(), GetProfileCountersArgument(),
628 less_than_function},
629 b_.getVoidTy());
630
631 if (sort->values_count() > 0) {
632 llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_);
633 }
634 return Status::OK();
635 }
636
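// A tuple instruction only needs to emit its index table: an array of
// pointers to the operands' already-emitted buffers.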
637 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
638 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple));
639 std::vector<llvm::Value*> base_ptrs;
640 for (auto operand : tuple->operands()) {
641 base_ptrs.push_back(GetEmittedValueFor(operand));
642 }
643 llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_);
644 return Status::OK();
645 }
646
647 Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
648 // Pseudo code for reduce window:
649 //
650 // for (coordinates O in the output)
651 // value = init_value;
652 // for (coordinates W in the window)
653 // for each index i:
654 // input coordinates I_i = O_i * stride_i + W_i - pad_low_i
655 // if I within bounds of input:
656 // value = function(value, input(I));
657 // output(O) = value;
658 //
659 // This is completely un-optimized and just here to have something
660 // that works.
661 return DefaultAction(reduce_window);
662 }
663
664 Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
665 CHECK_EQ(select_and_scatter->operand_count(), 3);
666 const auto operand = select_and_scatter->operand(0);
667 const auto source = select_and_scatter->operand(1);
668 const auto init_value = select_and_scatter->operand(2);
669 const Window& window = select_and_scatter->window();
670 PrimitiveType operand_element_type = operand->shape().element_type();
671 const int64 rank = operand->shape().rank();
672 CHECK_EQ(rank, source->shape().rank());
673 CHECK_EQ(rank, window.dimensions_size());
674
675 // TODO(b/31410564): Implement dilation for select-and-scatter.
676 if (window_util::HasDilation(window)) {
677 return Unimplemented(
678 "Dilation for SelectAndScatter is not implemented on CPU. ");
679 }
680
681 // Pseudo code for select-and-scatter:
682 //
683 // initialized_flag is initially off for every window, and is turned on after
684 // the first iteration is completed and the first operand value is selected.
685 //
686 // output(*) = init_value
687 // for (coordinates S in the source) {
688 // initialized_flag = false
689 // for (coordinates W in the window) {
690 // I = S * stride + W - pad_low
691 // if I within bounds of operand:
692 // if !initialized_flag or select(selected_value, operand(I)) == false:
693 // selected_value = operand(I)
694 // selected_index = I
695 // initialized_flag = true
696 // }
697 // output(selected_index) = scatter(output(selected_index), source(S))
698 // }
699 //
700
701 // Initialize the output array with the given init_value.
702 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
703 select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"),
704 [this, init_value](const llvm_ir::IrArray::Index& target_index) {
705 llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
706 return Load(init_value_addr);
707 }));
708
709 // Create a loop to iterate over the source array to scatter to the output.
710 llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_);
711 const llvm_ir::IrArray::Index source_index =
712 source_loops.AddLoopsForShape(source->shape(), "source");
713 SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_);
714
715 // Allocate space to keep the currently selected value, its index, and
716 // the boolean initialized_flag, which is initially set to false.
717 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
718 llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
719 "selected_value_address", &b_,
720 MinimumAlignmentForPrimitiveType(operand_element_type));
721 llvm::Value* selected_index_address =
722 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
723 b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_);
724 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
725 b_.getInt1Ty(), "initialized_flag_address", &b_);
726 Store(b_.getInt1(false), initialized_flag_address);
727
728 // Create the inner loop to iterate over the window.
729 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_);
730 std::vector<int64> window_size;
731 for (const auto& dim : window.dimensions()) {
732 window_size.push_back(dim.size());
733 }
734 const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
735 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
736 SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_);
737
738 // Compute the operand index to visit and check whether it lies within the
739 // operand's bounds. Because the comparison is unsigned, it also catches
740 // indices below zero, which wrap around to large unsigned values.
741 std::vector<llvm::Value*> operand_multi_index(source_index.size());
742 llvm::Value* in_bounds_condition = b_.getTrue();
743 for (int64 i = 0; i < rank; ++i) {
744 llvm::Value* strided_index =
745 NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride()));
746 operand_multi_index[i] =
747 NSWSub(NSWAdd(strided_index, window_index[i]),
748 b_.getInt64(window.dimensions(i).padding_low()));
749 llvm::Value* index_condition =
750 ICmpULT(operand_multi_index[i],
751 b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
752 in_bounds_condition = And(in_bounds_condition, index_condition);
753 }
754 CHECK(in_bounds_condition != nullptr);
755
756 // Only need to do something if the operand index is within the bounds. First
757 // check if the initialized_flag is set.
758 llvm_ir::LlvmIfData if_in_bounds =
759 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
760 SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
761 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
762 Load(initialized_flag_address), "initialized", &b_);
763
764 // If the initialized_flag is false, initialize the selected value and index
765 // with the currently visiting operand.
766 SetToFirstInsertPoint(if_initialized.false_block, &b_);
767 const auto save_operand_index =
768 [&](const llvm_ir::IrArray::Index& operand_index) {
769 for (int64 i = 0; i < rank; ++i) {
770 llvm::Value* selected_index_address_slot =
771 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
772 Store(operand_index[i], selected_index_address_slot);
773 }
774 };
775 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
776 llvm_ir::IrArray::Index operand_index(
777 operand_multi_index, operand_array.GetShape(), b_.getInt64Ty());
778 llvm::Value* operand_data =
779 operand_array.EmitReadArrayElement(operand_index, &b_);
780 Store(operand_data, selected_value_address);
781 save_operand_index(operand_index);
782 Store(b_.getInt1(true), initialized_flag_address);
783
784 // If the initialized_flag is true, call the `select` function to potentially
785 // update the selected value and index with the currently visiting operand.
786 SetToFirstInsertPoint(if_initialized.true_block, &b_);
787 llvm::Value* operand_address =
788 operand_array.EmitArrayElementAddress(operand_index, &b_);
789 llvm::Value* operand_element = Load(operand_address);
790 llvm::Value* result = EmitScalarReturningThreadLocalCall(
791 *select_and_scatter->select(),
792 {Load(selected_value_address), operand_element}, "select_function");
793
794 // If the 'select' function returns false, update the selected value and the
795 // index to the currently visiting operand.
796 llvm::Value* cond = ICmpNE(
797 result,
798 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
799 "boolean_predicate");
800 llvm_ir::LlvmIfData if_select_lhs =
801 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
802 SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
803 Store(Load(operand_address), selected_value_address);
804 save_operand_index(operand_index);
805
806 // After iterating over the window elements, scatter the source element to
807 // the selected index of the output. The value we store at the output
808 // location is computed by calling the `scatter` function with the source
809 // value and the current output value.
810 SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_);
811 std::vector<llvm::Value*> selected_multi_index;
812 for (int64 i = 0; i < rank; ++i) {
813 llvm::Value* selected_index_address_slot =
814 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
815 selected_multi_index.push_back(Load(selected_index_address_slot));
816 }
817 llvm_ir::IrArray source_array(GetIrArrayFor(source));
818 llvm::Value* source_value =
819 source_array.EmitReadArrayElement(source_index, &b_);
820 llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
821 llvm_ir::IrArray::Index selected_index(
822 selected_multi_index, output_array.GetShape(), source_index.GetType());
823 llvm::Value* output_value =
824 output_array.EmitReadArrayElement(selected_index, &b_);
825 llvm::Value* scatter_value = EmitScalarReturningThreadLocalCall(
826 *select_and_scatter->scatter(), {output_value, source_value},
827 "scatter_function");
828 output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
829
830 SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
831 return Status::OK();
832 }
833
834 Status IrEmitter::HandleDot(HloInstruction* dot) {
835 auto lhs = dot->operand(0);
836 auto rhs = dot->operand(1);
837 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
838 /*instruction=*/*dot, /*operands=*/{lhs, rhs},
839 /*supported_types=*/
840 {PRED, S8, U8, S16, U16, S32, U32, S64, U64, F16, F32, F64, C64, C128}));
841 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
842
843 if (dnums.lhs_contracting_dimensions_size() != 1) {
844 // This is disallowed by ShapeInference today.
845 return Unimplemented(
846 "Dot with multiple contracting dimensions not implemented.");
847 }
848
849 llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
850 llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
851
852 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot));
853 llvm_ir::IrArray target_array = GetIrArrayFor(dot);
854
855 VLOG(2) << "HandleDot: ";
856 VLOG(2) << " lhs operand: "
857 << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
858 VLOG(2) << " rhs operand: "
859 << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
860 VLOG(2) << " target: "
861 << llvm_ir::DumpToString(*target_array.GetBasePointer());
862
863 // Dot operation is complicated so we delegate to a helper class.
864 return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
865 /*addend_array=*/nullptr,
866 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
867 hlo_module_config_, target_machine_features_);
868 }
869
870 Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
871 auto lhs = convolution->operand(0);
872 auto rhs = convolution->operand(1);
873 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
874 /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
875 /*supported_types=*/{F16, F32, F64, C64, C128}));
876
877 // TODO(tonywy): Add PotentiallyImplementedAsMKLConvolution to support
878 // different data layouts.
879 if (PotentiallyImplementedAsEigenConvolution(*convolution,
880 target_machine_features_)) {
881 const Shape& lhs_shape = lhs->shape();
882 const Shape& rhs_shape = rhs->shape();
883 const Shape& convolution_shape = convolution->shape();
884 // The input, kernel and output agree with respect to layout.
885 if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
886 LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
887 LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
888 // We lower 1D convolutions into calls to the same Eigen function as 2D
889 // convolutions, except that we pretend that the 1D convolution is really
890 // a 2D convolution with the missing dimension set to 1. We also adjust
891 // the padding, dilation parameters as needed.
892 bool one_dim_convolution = lhs_shape.dimensions_size() == 3;
893 llvm::Value* lhs_address = GetEmittedValueFor(lhs);
894 llvm::Value* rhs_address = GetEmittedValueFor(rhs);
895 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution));
896
897 const ConvolutionDimensionNumbers& dnums =
898 convolution->convolution_dimension_numbers();
899
900 // Input tensor.
901 const Shape& input_shape = convolution->operand(0)->shape();
902 int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension());
903 int64 input_rows =
904 input_shape.dimensions(dnums.input_spatial_dimensions(0));
905 int64 input_cols =
906 one_dim_convolution
907 ? 1
908 : input_shape.dimensions(dnums.input_spatial_dimensions(1));
909 int64 input_channels =
910 input_shape.dimensions(dnums.input_feature_dimension());
911
912 // Kernel tensor.
913 const Shape& kernel_shape = convolution->operand(1)->shape();
914 int64 kernel_rows =
915 kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
916 int64 kernel_cols =
917 one_dim_convolution
918 ? 1
919 : kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
920 int64 kernel_channels =
921 kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
922 int64 kernel_filters =
923 kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
924
925 // Output tensor.
926 const Shape& convolution_shape = convolution->shape();
927 int64 output_rows =
928 convolution_shape.dimensions(dnums.output_spatial_dimensions(0));
929 int64 output_cols = one_dim_convolution
930 ? 1
931 : convolution_shape.dimensions(
932 dnums.output_spatial_dimensions(1));
933
934 // Extract the window stride for the convolution.
935 const Window& window = convolution->window();
936 int64 row_stride = window.dimensions(0).stride();
937 int64 col_stride =
938 one_dim_convolution ? 1 : window.dimensions(1).stride();
939
940 int64 padding_top = window.dimensions(0).padding_low();
941 int64 padding_bottom = window.dimensions(0).padding_high();
942 int64 padding_left =
943 one_dim_convolution ? 0 : window.dimensions(1).padding_low();
944 int64 padding_right =
945 one_dim_convolution ? 0 : window.dimensions(1).padding_high();
946
947 int64 lhs_row_dilation = window.dimensions(0).base_dilation();
948 int64 lhs_col_dilation =
949 one_dim_convolution ? 1 : window.dimensions(1).base_dilation();
950 int64 rhs_row_dilation = window.dimensions(0).window_dilation();
951 int64 rhs_col_dilation =
952 one_dim_convolution ? 1 : window.dimensions(1).window_dilation();
953
954 PrimitiveType primitive_type = lhs->shape().element_type();
955 llvm::Type* ir_ptr_type = primitive_type == F16
956 ? b_.getHalfTy()->getPointerTo()
957 : b_.getFloatTy()->getPointerTo();
958 bool multi_threaded =
959 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
960 bool use_mkl_dnn =
961 hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
962
963 // TODO(b/78639006) Single-threaded MKL conv2d is not implemented due to a
964 // potential race condition caused by setting omp_num_threads.
965 const char* fn_name =
966 primitive_type == F16
967 ? (multi_threaded
968 ? runtime::kEigenConvF16SymbolName
969 : runtime::kEigenSingleThreadedConvF16SymbolName)
970 : (multi_threaded
971 ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName
972 : runtime::kEigenConvF32SymbolName)
973 : runtime::kEigenSingleThreadedConvF32SymbolName);
974 if (!multi_threaded && use_mkl_dnn) {
975 LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
976 "conv2d function.";
977 }
978 EmitCallToFunc(fn_name,
979 {
980 GetExecutableRunOptionsArgument(),
981 BitCast(GetEmittedValueFor(convolution), ir_ptr_type),
982 BitCast(lhs_address, ir_ptr_type),
983 BitCast(rhs_address, ir_ptr_type),
984 b_.getInt64(input_batch),
985 b_.getInt64(input_rows),
986 b_.getInt64(input_cols),
987 b_.getInt64(input_channels),
988 b_.getInt64(kernel_rows),
989 b_.getInt64(kernel_cols),
990 b_.getInt64(kernel_channels),
991 b_.getInt64(kernel_filters),
992 b_.getInt64(output_rows),
993 b_.getInt64(output_cols),
994 b_.getInt64(row_stride),
995 b_.getInt64(col_stride),
996 b_.getInt64(padding_top),
997 b_.getInt64(padding_bottom),
998 b_.getInt64(padding_left),
999 b_.getInt64(padding_right),
1000 b_.getInt64(lhs_row_dilation),
1001 b_.getInt64(lhs_col_dilation),
1002 b_.getInt64(rhs_row_dilation),
1003 b_.getInt64(rhs_col_dilation),
1004 },
1005 b_.getVoidTy(), /*does_not_throw=*/true,
1006 /*only_accesses_arg_memory=*/true);
1007
1008 return Status::OK();
1009 }
1010 }
1011
1012 // This is a completely un-optimized version of convolution just to
1013 // have an early version that works. E.g. the input index and
1014 // padding calculations are not hoisted out of the inner loop.
1015 //
1016 // See the description of convolution in the XLA documentation for the pseudo
1017 // code for convolution.
1018 return DefaultAction(convolution);
1019 }
1020
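// FFT is lowered to a call into the Eigen-backed runtime FFT routine
// (multi-threaded or single-threaded depending on the debug options), with
// all leading dimensions folded into a single batch dimension.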
1021 Status IrEmitter::HandleFft(HloInstruction* fft) {
1022 auto operand = fft->operand(0);
1023 TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
1024 /*instruction=*/*fft, /*operands=*/{operand},
1025 /*supported_types=*/{F32, F64, C64, C128}));
1026 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
1027 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
1028 VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
1029 VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape());
1030
1031 llvm::Value* operand_address = GetEmittedValueFor(operand);
1032 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft));
1033
1034 const std::vector<int64>& fft_length = fft->fft_length();
1035 int64 input_batch = 1;
1036 for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) {
1037 input_batch *= fft->shape().dimensions(i);
1038 }
1039
1040 // Args have been computed, make the call.
1041 llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo();
1042 bool multi_threaded_eigen =
1043 hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
1044 const char* fn_name = multi_threaded_eigen
1045 ? runtime::kEigenFftSymbolName
1046 : runtime::kEigenSingleThreadedFftSymbolName;
1047 const int fft_rank = fft_length.size();
1048 EmitCallToFunc(
1049 fn_name,
1050 {GetExecutableRunOptionsArgument(),
1051 BitCast(GetEmittedValueFor(fft), int8_ptr_type),
1052 BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
1053 b_.getInt32(operand->shape().element_type() == F64 ||
1054 operand->shape().element_type() == C128),
1055 b_.getInt32(fft_rank), b_.getInt64(input_batch),
1056 b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
1057 b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
1058 b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)},
1059 b_.getVoidTy(), /*does_not_throw=*/true,
1060 /*only_accesses_arg_memory=*/false,
1061 /*only_accesses_inaccessible_mem_or_arg_mem=*/true);
1062
1063 return Status::OK();
1064 }
1065
1066 Status IrEmitter::HandleAllReduceSingleReplica(HloInstruction* crs) {
1067 // When there is a single replica, a cross-replica sum is the identity
1068 // function, and the buffer assignment expects a copy.
1069 //
1070 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1071 // in algebraic-simplifier, but currently on some platforms
1072 // HloModuleConfig::num_replicas changes between when the module is compiled
1073 // and when it's run.
1074 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1075
1076 // CRS with one operand and one replica is simply the identity function.
1077 if (crs->operand_count() == 1) {
1078 return EmitMemcpy(*crs->operand(0), *crs);
1079 }
1080
1081 // CRS with multiple operands and one replica produces a (one-deep) tuple.
1082 std::vector<llvm::Value*> operand_ptrs;
1083 for (int64 i = 0; i < crs->operand_count(); ++i) {
1084 llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i));
1085 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1086 assignment_.GetUniqueSlice(crs, {i}));
1087
1088 const Shape& operand_shape = crs->operand(i)->shape();
1089 CHECK(operand_shape.IsArray())
1090 << "Operands to all-reduce must be arrays: " << crs->ToString();
1091 operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1092
1093 // TODO(b/63762267): Be more aggressive about specifying alignment.
1094 MemCpy(operand_ptrs.back(), /*DstAlign=*/llvm::Align(1), in_ptr,
1095 /*SrcAlign=*/llvm::Align(1), ShapeUtil::ByteSizeOf(operand_shape));
1096 }
1097 llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_);
1098 return Status::OK();
1099 }
1100
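// A multi-replica all-reduce is lowered to a call into the CPU runtime,
// passing the serialized replica groups, the matched reduction kind, and the
// input/output buffer pointers.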
1101 Status IrEmitter::HandleAllReduceMultipleReplica(HloInstruction* crs) {
1102 CHECK_GE(crs->operand_count(), 1);
1103 PrimitiveType datatype = crs->operand(0)->shape().element_type();
1104 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs));
1105
1106 bool is_datatype_supported = [&] {
1107 // TODO(cheshire): Fix duplication wrt. cpu_runtime
1108 switch (datatype) {
1109 case PRED:
1110 case S8:
1111 case U8:
1112 case S32:
1113 case U32:
1114 case S64:
1115 case U64:
1116 case F16:
1117 case F32:
1118 case F64:
1119 return true;
1120 default:
1121 return false;
1122 }
1123 }();
1124
1125 if (!is_datatype_supported) {
1126 return Unimplemented("AllReduce for datatype '%s' is not supported",
1127 primitive_util::LowercasePrimitiveTypeName(datatype));
1128 }
1129
1130 if (!MatchReductionComputation(crs->to_apply()).has_value()) {
1131 return Unimplemented("AllReduce for computation '%s' is not supported",
1132 crs->to_apply()->ToString());
1133 }
1134
1135 std::string replica_groups = ReplicaGroupsToString(crs->replica_groups());
1136 int32 replica_groups_size = replica_groups.size();
1137 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1138
1139 bool is_tuple = crs->operand_count() > 1;
1140 std::vector<llvm::Value*> input_buffer_ptrs;
1141 std::vector<llvm::Value*> output_buffer_ptrs;
1142 if (is_tuple) {
1143 CHECK(crs->shape().IsTuple());
1144
1145 for (int64 i = 0; i < crs->operand_count(); i++) {
1146 const HloInstruction* op = crs->operand(i);
1147 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1148 assignment_.GetUniqueSlice(crs, {i}));
1149 const Shape& operand_shape = crs->operand(i)->shape();
1150 CHECK(operand_shape.IsArray())
1151 << "Operands to all-reduce must be arrays: " << crs->ToString();
1152 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1153 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1154 }
1155 } else {
1156 Shape shape = crs->operand(0)->shape();
1157 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1158 assignment_.GetUniqueSlice(crs->operand(0), {}));
1159 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1160 assignment_.GetUniqueSlice(crs, {}));
1161 input_buffer_ptrs.push_back(EmitBufferPointer(input_slice, shape));
1162 output_buffer_ptrs.push_back(EmitBufferPointer(output_slice, shape));
1163 }
1164
1165 llvm::Value* input_buffers =
1166 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1167 llvm::Value* output_buffers =
1168 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1169
1170 int32 shape_length;
1171 TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
1172 llvm_ir::EncodeSelfDescribingShapeConstant(
1173 crs->shape(), &shape_length, &b_));
1174
1175 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1176 EmitCallToFunc(
1177 runtime::kAllReduceSymbolName,
1178 {/*run_options=*/GetExecutableRunOptionsArgument(),
1179 /*replica_groups=*/replica_groups_v,
1180 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1181
1182 /*channel_id_present=*/
1183 b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1184 /*op_id=*/
1185 b_.getInt64(crs->channel_id().has_value()
1186 ? *crs->channel_id()
1187 : crs->GetModule()->unique_id()),
1188 /*reduction_kind=*/
1189 b_.getInt32(
1190 static_cast<int32>(*MatchReductionComputation(crs->to_apply()))),
1191 /*shape_ptr=*/shape_ptr,
1192 /*shape_length=*/b_.getInt32(shape_length),
1193 /*num_buffers=*/b_.getInt32(crs->operand_count()),
1194 /*input_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1195 /*output_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1196 b_.getVoidTy());
1197
1198 return Status::OK();
1199 }
1200
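// Dispatches to the single-replica fast path (a plain copy) or to the
// runtime-backed multi-replica implementation.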
1201 Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
1202 if (hlo_module_config_.replica_count() == 1) {
1203 return HandleAllReduceSingleReplica(crs);
1204 }
1205 return HandleAllReduceMultipleReplica(crs);
1206 }
1207
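// All-to-all is lowered to a call into the CPU runtime. Only the tuple form
// is supported, and all operand buffers must have the same byte size.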
1208 Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
1209 auto* instr = Cast<HloAllToAllInstruction>(instruction);
1210 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
1211 CHECK(!instr->split_dimension() && instr->shape().IsTuple())
1212 << "Only tuple AllToAll is supported";
1213
1214 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1215 std::string replica_groups =
1216 ReplicaGroupsToString(instruction->replica_groups());
1217 int32 replica_groups_size = replica_groups.size();
1218 llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
1219
1220 int64 buffer_size = -1;
1221 std::vector<llvm::Value*> input_buffer_ptrs;
1222 std::vector<llvm::Value*> output_buffer_ptrs;
1223
1224 for (int64 i = 0; i < instruction->operand_count(); i++) {
1225 const HloInstruction* op = instruction->operand(i);
1226 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
1227 assignment_.GetUniqueSlice(instruction, {i}));
1228 const Shape& operand_shape = instruction->operand(i)->shape();
1229 CHECK(operand_shape.IsArray())
1230 << "Operands to all-to-all must be arrays: " << instruction->ToString();
1231 output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
1232 input_buffer_ptrs.push_back(GetEmittedValueFor(op));
1233 CHECK(buffer_size == -1 || buffer_size == out_slice.size());
1234 buffer_size = out_slice.size();
1235 }
1236
1237 llvm::Value* input_buffers =
1238 EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
1239 llvm::Value* output_buffers =
1240 EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
1241
1242 EmitCallToFunc(
1243 runtime::kAllToAllSymbolName,
1244 {/*run_options=*/GetExecutableRunOptionsArgument(),
1245 /*channel_id_present=*/
1246 b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
1247 /*op_id=*/
1248 b_.getInt64(instruction->channel_id().has_value()
1249 ? *instruction->channel_id()
1250 : instruction->GetModule()->unique_id()),
1251 /*replica_groups=*/replica_groups_v,
1252 /*replica_groups_size=*/b_.getInt32(replica_groups_size),
1253 /*num_buffers=*/b_.getInt32(instruction->operand_count()),
1254 /*buffer_size=*/b_.getInt64(buffer_size),
1255 /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
1256 /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)},
1257 b_.getVoidTy());
1258
1259 llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
1260 return Status::OK();
1261 }
1262
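// Collective-permute is lowered to a call into the CPU runtime; the
// source/target pairs are passed as a single serialized string of
// "source=target" entries.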
1263 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
1264 auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
1265 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr));
1266 std::string source_target_pairs = absl::StrJoin(
1267 instr->source_target_pairs(), ",", absl::PairFormatter("="));
1268 llvm::Value* source_target_pairs_v =
1269 b_.CreateGlobalStringPtr(source_target_pairs);
1270
1271 Shape shape = crs->operand(0)->shape();
1272
1273 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
1274 assignment_.GetUniqueSlice(crs->operand(0), {}));
1275 llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
1276
1277 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1278 assignment_.GetUniqueSlice(crs, {}));
1279 llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
1280
1281 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1282 EmitCallToFunc(
1283 runtime::kCollectivePermuteSymbolName,
1284 {/*run_options=*/GetExecutableRunOptionsArgument(),
1285 /*channel_id_present=*/
1286 b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
1287 /*op_id=*/
1288 b_.getInt64(crs->channel_id().has_value()
1289 ? *crs->channel_id()
1290 : crs->GetModule()->unique_id()),
1291 /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
1292 /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
1293 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
1294 /*source_target_pairs=*/source_target_pairs_v,
1295 /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())},
1296 b_.getVoidTy());
1297
1298 return Status::OK();
1299 }
1300
Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
1302 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
1303 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1304 assignment_.GetUniqueSlice(hlo, {}));
1305 llvm::Value* output_buffer = EmitBufferPointer(output_slice, hlo->shape());
1306 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
1307 EmitCallToFunc(
1308 runtime::kReplicaIdSymbolName,
1309 {/*run_options=*/GetExecutableRunOptionsArgument(),
1310 /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type)},
1311 b_.getVoidTy());
1312 return Status::OK();
1313 }
1314
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
1316 VLOG(2) << "HandleParameter: " << parameter->ToString();
1317 return EmitTargetAddressForOp(parameter);
1318 }
1319
1320 // Returns true if the relative order of the unreduced dimensions stays the same
1321 // through the reduce operation.
static bool ReductionPreservesLayout(const HloInstruction& reduce) {
1323 DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
1324
1325 // Maps dimensions that were not reduced from their dimension numbers in the
  // source shape to their dimension numbers in the destination shape.
1327 //
1328 // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
1329 // [0->0, 3->1].
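  //
  // Illustrative walk-through (added example, not part of the original
  // comment): with the f32[A,B,C,D] reduction above and operand layout
  // {3,2,1,0}, the minor-to-major scan below visits the unreduced dimensions
  // in the order 3, 0, which the map sends to 1, 0. The reduction therefore
  // preserves layout iff the f32[A,D] result has minor-to-major order {1,0}.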
1330 absl::flat_hash_map<int64, int64> unreduced_dim_map;
1331
1332 absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
1333 reduce.dimensions().end());
1334
1335 const Shape& operand_shape = reduce.operand(0)->shape();
1336 const Shape& result_shape = reduce.shape();
1337
1338 int64 delta = 0;
1339 for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
1340 if (reduced_dims.contains(i)) {
1341 delta++;
1342 } else {
1343 InsertOrDie(&unreduced_dim_map, i, i - delta);
1344 }
1345 }
1346
1347 // Iterate dimensions minor to major and check that the corresponding
1348 // dimensions in the source and target shapes are equivalent.
1349 int64 result_dim_idx = 0;
1350 for (int64 operand_dim_idx = 0;
1351 operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
1352 int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
1353 if (!reduced_dims.contains(operand_dim)) {
1354 if (FindOrDie(unreduced_dim_map, operand_dim) !=
1355 result_shape.layout().minor_to_major(result_dim_idx++)) {
1356 return false;
1357 }
1358 }
1359 }
1360
1361 CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
1362
1363 return true;
1364 }
1365
IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
1367 HloComputation* function, string* failure_reason) const {
1368 CHECK_EQ(function->num_parameters(), 2);
1369
1370 auto root_instruction = function->root_instruction();
1371 CHECK(ShapeUtil::IsScalar(root_instruction->shape()));
1372
1373 if (root_instruction->operand_count() != 2) {
1374 *failure_reason = "root instruction is not a binary operation";
1375 return nullptr;
1376 }
1377
1378 const Shape& root_shape = root_instruction->shape();
1379 if (ShapeUtil::ElementIsComplex(root_shape)) {
    // TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>.
    // Complex multiply would be more challenging. We could perhaps use a
    // strided load to get all the real parts in one vector and all the
    // imaginary parts in another, or use CreateShuffleVector on a bitcast to
    // float x [2N].
1384 *failure_reason = "complex values not supported";
1385 return nullptr;
1386 }
1387 bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
1388 bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
1389 bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
1390
1391 auto lhs = root_instruction->operand(0);
1392 auto rhs = root_instruction->operand(1);
1393
1394 auto param_0 = function->parameter_instruction(0);
1395 auto param_1 = function->parameter_instruction(1);
1396 if (!(lhs == param_0 && rhs == param_1) &&
1397 !(rhs == param_0 && lhs == param_1)) {
1398 *failure_reason =
1399 "root instruction is not a binary operation on the incoming arguments";
1400 return nullptr;
1401 }
1402
1403 CHECK(ShapeUtil::IsScalar(lhs->shape()) && ShapeUtil::IsScalar(rhs->shape()));
1404
1405 // This is visually similar to ElementalIrEmitter, though conceptually we're
1406 // doing something different here. ElementalIrEmitter emits scalar operations
1407 // while these emit scalar or vector operations depending on the type of the
1408 // operands. See CreateShardedVectorType for the actual types in use here.
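  //
  // Minimal illustration (added example, assuming an F32 add reduction): the
  // kAdd case below returns a lambda that behaves like
  //
  //   [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
  //     return b->CreateFAdd(lhs, rhs);
  //   }
  //
  // which emits a scalar fadd when handed scalar shards and a vector fadd
  // (e.g. over <8 x float>) when handed vector shards.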
1409 switch (root_instruction->opcode()) {
1410 default:
1411 *failure_reason = "did not recognize root instruction opcode";
1412 return nullptr;
1413
1414 case HloOpcode::kAdd:
1415 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1416 llvm::Value* rhs) {
1417 return root_is_integral ? b->CreateAdd(lhs, rhs)
1418 : b->CreateFAdd(lhs, rhs);
1419 };
1420
1421 case HloOpcode::kMultiply:
1422 return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs,
1423 llvm::Value* rhs) {
1424 return root_is_integral ? b->CreateMul(lhs, rhs)
1425 : b->CreateFMul(lhs, rhs);
1426 };
1427
1428 case HloOpcode::kAnd:
1429 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1430 return b->CreateAnd(lhs, rhs);
1431 };
1432
1433 case HloOpcode::kOr:
1434 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1435 return b->CreateOr(lhs, rhs);
1436 };
1437
1438 case HloOpcode::kXor:
1439 return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) {
1440 return b->CreateXor(lhs, rhs);
1441 };
1442
1443 case HloOpcode::kMaximum:
1444 return [root_is_floating_point, root_is_signed](
1445 llvm::IRBuilder<>* b, llvm::Value* lhs,
1446 llvm::Value* rhs) -> llvm::Value* {
1447 if (root_is_floating_point) {
1448 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
1449 {lhs, rhs}, {lhs->getType()}, b);
1450 }
1451
1452 return b->CreateSelect(
1453 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE
1454 : llvm::ICmpInst::ICMP_UGE,
1455 lhs, rhs),
1456 lhs, rhs);
1457 };
1458
1459 case HloOpcode::kMinimum:
1460 return [root_is_floating_point, root_is_signed](
1461 llvm::IRBuilder<>* b, llvm::Value* lhs,
1462 llvm::Value* rhs) -> llvm::Value* {
1463 if (root_is_floating_point) {
1464 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
1465 {lhs, rhs}, {lhs->getType()}, b);
1466 }
1467
1468 return b->CreateSelect(
1469 b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE
1470 : llvm::ICmpInst::ICMP_ULE,
1471 lhs, rhs),
1472 lhs, rhs);
1473 };
1474 }
1475 }
1476
IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
1478 PrimitiveType element_type, unsigned element_count) {
1479 int vector_register_size_in_elements =
1480 target_machine_features_.vector_register_byte_size(
1481 *compute_function_->function()) /
1482 ShapeUtil::ByteSizeOfPrimitiveType(element_type);
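  // Worked example (illustrative, assuming a 32-byte vector register and F32,
  // i.e. vector_register_size_in_elements == 8): for element_count == 11
  // (binary 1011) the loop below emits one shard per set bit, for fragment
  // sizes 1, 2 and 8, producing {float, <2 x float>, <8 x float>}.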
1483
1484 ShardedVectorType sharded_vector_type;
1485 llvm::Type* element_ir_type =
1486 llvm_ir::PrimitiveTypeToIrType(element_type, module_);
1487
1488 for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
1489 // For every power of two present in element_count, we generate one or more
1490 // vector or scalar types.
1491 const unsigned current_size_fragment = 1u << i;
1492 if (!(element_count & current_size_fragment)) {
1493 // Power of two not present in element_count.
1494 continue;
1495 }
1496
1497 if (current_size_fragment == 1) {
1498 // Single element, use a scalar type.
1499 sharded_vector_type.push_back(element_ir_type);
1500 continue;
1501 }
1502
1503 // Lower "current_size_fragment" number of elements using (as few as
1504 // possible) vector registers.
1505
1506 if (current_size_fragment >= vector_register_size_in_elements) {
1507 auto vector_type = llvm::VectorType::get(
1508 element_ir_type, vector_register_size_in_elements, false);
1509 sharded_vector_type.insert(
1510 sharded_vector_type.end(),
1511 current_size_fragment / vector_register_size_in_elements,
1512 vector_type);
1513
1514 // Both current_size_fragment and vector_register_size_in_elements are
1515 // powers of two.
1516 CHECK_EQ(current_size_fragment % vector_register_size_in_elements, 0);
1517 continue;
1518 }
1519
1520 // For now we assume that vector_register_size_in_elements and lower powers
1521 // of two are all legal vector sizes (or at least can be lowered easily by
1522 // LLVM).
1523 sharded_vector_type.push_back(
1524 llvm::VectorType::get(element_ir_type, current_size_fragment, false));
1525 }
1526 return sharded_vector_type;
1527 }
1528
1529 StatusOr<IrEmitter::ShardedVector>
IrEmitter::EmitInnerLoopForVectorizedReduction(
1531 const ReductionGenerator& reduction_generator,
1532 const llvm_ir::IrArray::Index& output_index,
1533 const ShardedVectorType& accumulator_type, HloInstruction* init_value,
1534 HloInstruction* arg, absl::Span<const int64> dimensions,
1535 llvm::Align element_alignment) {
1536 ShardedVector accumulator;
1537 accumulator.reserve(accumulator_type.size());
1538 for (auto accumulator_shard_type : accumulator_type) {
1539 accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
1540 accumulator_shard_type, "accumulator", &b_, 0));
1541 }
1542
1543 llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value));
1544
1545 for (llvm::Value* accumulator_shard : accumulator) {
1546 llvm::Value* initial_value;
1547 auto shard_type = accumulator_shard->getType()->getPointerElementType();
1548 if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) {
1549 initial_value =
1550 VectorSplat(vector_type->getElementCount(), init_value_ssa);
1551 } else {
1552 initial_value = init_value_ssa;
1553 }
1554
1555 AlignedStore(initial_value, accumulator_shard, element_alignment);
1556 }
1557
1558 llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"),
1559 &b_);
1560 std::vector<llvm::Value*> input_multi_index =
1561 reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
1562 "reduction_dim");
1563
1564 SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_);
1565
1566 llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
1567 llvm_ir::IrArray::Index::const_iterator it = output_index.begin();
1568
1569 for (auto& i : input_multi_index) {
1570 if (i == nullptr) {
1571 i = *it++;
1572 }
1573 }
1574 CHECK(output_index.end() == it);
1575 llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
1576 b_.getInt64Ty());
1577
1578 llvm::Value* input_address = BitCast(
1579 arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy());
1580
1581 for (int i = 0; i < accumulator.size(); i++) {
1582 auto input_address_typed =
1583 BitCast(input_address, accumulator[i]->getType());
1584 auto current_accumulator_value =
1585 AlignedLoad(accumulator[i], element_alignment);
1586 auto addend = AlignedLoad(input_address_typed, element_alignment);
1587 arg_array.AnnotateLoadStoreInstructionWithMetadata(addend);
1588
1589 auto reduced_result =
1590 reduction_generator(&b_, current_accumulator_value, addend);
1591 AlignedStore(reduced_result, accumulator[i], element_alignment);
1592
1593 if (i != (accumulator.size() - 1)) {
1594 input_address = ConstInBoundsGEP1_32(reduced_result->getType(),
1595 input_address_typed, 1);
1596 }
1597 }
1598
1599 SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_);
1600
1601 ShardedVector result_ssa;
1602 result_ssa.reserve(accumulator.size());
1603 for (auto accumulator_shard : accumulator) {
1604 result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment));
1605 }
1606 return result_ssa;
1607 }
1608
void IrEmitter::EmitShardedVectorStore(
1610 llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store,
1611 llvm::Align alignment, const llvm_ir::IrArray& containing_array) {
1612 for (int i = 0; i < value_to_store.size(); i++) {
1613 auto store_address_typed =
1614 BitCast(store_address,
1615 llvm::PointerType::getUnqual(value_to_store[i]->getType()));
1616
1617 auto store_instruction =
1618 AlignedStore(value_to_store[i], store_address_typed, alignment);
1619 containing_array.AnnotateLoadStoreInstructionWithMetadata(
1620 store_instruction);
1621
1622 if (i != (value_to_store.size() - 1)) {
1623 store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(),
1624 store_address_typed, 1);
1625 }
1626 }
1627 }
1628
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
1630 HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
1631 absl::Span<const int64> dimensions, HloComputation* function,
1632 string* failure_reason) {
1633 if (!reduce->shape().IsArray()) {
1634 *failure_reason = "vectorization of variadic reduce not implemented";
1635 return false;
1636 }
1637
1638 if (!ReductionPreservesLayout(*reduce)) {
1639 return false;
1640 }
1641
1642 ReductionGenerator reduction_generator =
1643 MatchReductionGenerator(function, failure_reason);
1644 if (!reduction_generator) {
1645 return false;
1646 }
1647
1648 int vector_register_size_in_elements =
1649 target_machine_features_.vector_register_byte_size(
1650 *compute_function_->function()) /
1651 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1652 if (vector_register_size_in_elements == 0) {
1653 // Either we don't know the vector register width for the target or the
1654 // vector register is smaller than the size of the primitive type.
1655 return false;
1656 }
1657
1658 int vectorization_factor_in_bytes =
1659 target_machine_features_.vectorization_factor_in_bytes();
1660
1661 // We try to process vectorization_factor elements at the same time.
1662 const int vectorization_factor =
1663 vectorization_factor_in_bytes /
1664 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
1665
1666 bool is_reduction_over_minor_dimension = absl::c_linear_search(
1667 dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
1668
1669 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
1670 ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
1671 MinimumAlignmentForPrimitiveType(reduce->shape().element_type())));
1672
1673 if (is_reduction_over_minor_dimension) {
1674 // TODO(sanjoy): Implement vectorized reduction over the minor dimension.
1675 *failure_reason = "reduction over minor dimension not implemented";
1676 return false;
1677 }
1678
1679 CHECK(!reduce->shape().IsTuple());
1680 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce));
1681
1682 // We know we're not reducing over the most minor dimension, which means we
1683 // can lower the reduction loop as:
1684 //
1685 // 1. We're reducing over dimensions R0, R1.
1686 // 2. D0 is the most minor dimension.
1687 // 3. VS is the vectorization stride (we want to reduce this many elements at
1688 // once)
1689 //
1690 // for (d1 in D1) {
1691 // for (d0 in D0 with stride VS) {
1692 // vector_acc = init
1693 // for (r1 in R1) {
1694 // for (r0 in R0) {
  //         vector_acc = elementwise_reduce(vector_acc, input[d1, d0, r1, r0])
1696 // }
1697 // }
1698 // output[d1, d0] = vector_acc
1699 // }
1700 // }
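  //
  // Concrete instance (added illustration, assuming the vectorization factor
  // works out to 8 f32 elements): reducing f32[5,100]{1,0} over dimension 0
  // yields output f32[100]; the main loop below walks the innermost output
  // dimension in strides of 8 over the first 96 elements, and the epilogue
  // further down handles the remaining 100 % 8 == 4 elements.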
1701
1702 llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_);
1703 std::vector<llvm::Value*> array_multi_index(
1704 reduce->shape().dimensions_size());
1705 for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
1706 --i) {
1707 int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
1708 int64 start_index = 0;
1709 int64 end_index = reduce->shape().dimensions(dimension);
1710 std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
1711 start_index, end_index, absl::StrFormat("dim.%d", dimension));
1712 array_multi_index[dimension] = loop->GetIndVarValue();
1713 }
1714
1715 int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0);
1716 int64 innermost_dimension_size =
1717 reduce->shape().dimensions(innermost_dimension);
1718
1719 if (llvm::BasicBlock* innermost_body_bb =
1720 loop_nest.GetInnerLoopBodyBasicBlock()) {
1721 SetToFirstInsertPoint(innermost_body_bb, &b_);
1722 }
1723
1724 auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock();
1725
1726 if (innermost_dimension_size >= vectorization_factor) {
1727 int64 start_index = 0;
1728 int64 end_index = (innermost_dimension_size / vectorization_factor) *
1729 vectorization_factor;
1730 std::unique_ptr<llvm_ir::ForLoop> loop =
1731 loop_nest.AddLoop(start_index, end_index, vectorization_factor,
1732 absl::StrFormat("dim.%d", innermost_dimension));
1733 array_multi_index[innermost_dimension] = loop->GetIndVarValue();
1734
1735 SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_);
1736
1737 ShardedVectorType vector_type = CreateShardedVectorType(
1738 reduce->shape().element_type(), vectorization_factor);
1739 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1740 b_.getInt64Ty());
1741 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1742 EmitInnerLoopForVectorizedReduction(
1743 reduction_generator, array_index, vector_type,
1744 init_value, arg, dimensions, element_alignment));
1745
1746 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1747 llvm::Value* output_address =
1748 target_array.EmitArrayElementAddress(array_index, &b_);
1749 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1750 target_array);
1751
1752 if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) {
1753 CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1754 b_.SetInsertPoint(exit_terminator);
1755 } else {
1756 CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1);
1757 b_.SetInsertPoint(loop->GetExitBasicBlock());
1758 }
1759 }
1760
  // Since we step the innermost dimension by the vectorization factor rather
  // than by 1, we may need to peel off an "epilogue" iteration to handle the
  // remaining elements when the innermost dimension size is not a multiple of
  // that factor:
1764 if (innermost_dimension_size % vectorization_factor) {
1765 // TODO(b/63775531): Consider using a scalar loop here to save on code size.
1766 array_multi_index[innermost_dimension] =
1767 b_.getInt64(innermost_dimension_size -
1768 (innermost_dimension_size % vectorization_factor));
1769
1770 ShardedVectorType vector_type = CreateShardedVectorType(
1771 reduce->shape().element_type(),
1772 innermost_dimension_size % vectorization_factor);
1773 llvm_ir::IrArray::Index array_index(array_multi_index, reduce->shape(),
1774 b_.getInt64Ty());
1775 TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accumulator,
1776 EmitInnerLoopForVectorizedReduction(
1777 reduction_generator, array_index, vector_type,
1778 init_value, arg, dimensions, element_alignment));
1779
1780 llvm_ir::IrArray target_array = GetIrArrayFor(reduce);
1781 llvm::Value* output_address =
1782 target_array.EmitArrayElementAddress(array_index, &b_);
1783 EmitShardedVectorStore(output_address, accumulator, element_alignment,
1784 target_array);
1785 }
1786
1787 if (outermost_loop_exit_block) {
1788 b_.SetInsertPoint(outermost_loop_exit_block);
1789 }
1790
1791 return true;
1792 }
1793
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
1795 auto arg = reduce->mutable_operand(0);
1796 auto init_value = reduce->mutable_operand(1);
1797 absl::Span<const int64> dimensions(reduce->dimensions());
1798 HloComputation* function = reduce->to_apply();
1799 if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
1800 string vectorization_failure_reason;
1801 TF_ASSIGN_OR_RETURN(
1802 bool vectorization_successful,
1803 EmitVectorizedReduce(reduce, arg, init_value, dimensions, function,
1804 &vectorization_failure_reason));
1805 if (vectorization_successful) {
1806 VLOG(1) << "Successfully vectorized reduction " << reduce->ToString()
1807 << "\n";
1808 return Status::OK();
1809 } else {
1810 VLOG(1) << "Could not vectorize reduction " << reduce->ToString() << ": "
1811 << vectorization_failure_reason;
1812 }
1813 }
1814
1815 return DefaultAction(reduce);
1816 }
1817
Status IrEmitter::HandleSend(HloInstruction* send) {
1819 // TODO(b/33942983): Support Send/Recv on CPU.
1820 return Unimplemented("Send is not implemented on CPU.");
1821 }
1822
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
1824 // TODO(b/33942983): Support Send/Recv on CPU.
1825 return Unimplemented("Send-done is not implemented on CPU.");
1826 }
1827
Status IrEmitter::HandleScatter(HloInstruction*) {
1829 return Unimplemented("Scatter is not implemented on CPUs.");
1830 }
1831
Status IrEmitter::HandleSlice(HloInstruction* slice) {
1833 VLOG(2) << "HandleSlice: " << slice->ToString();
1834 auto operand = slice->operand(0);
1835 // The code below emits a sequential loop nest. For the parallel backend, use
1836 // ParallelLoopEmitter which respects dynamic loop bounds.
1837 if (ShouldEmitParallelLoopFor(*slice)) {
1838 return DefaultAction(slice);
1839 }
1840
1841 // The code below assumes the layouts are equal.
1842 if (!LayoutUtil::Equal(operand->shape().layout(), slice->shape().layout())) {
1843 return DefaultAction(slice);
1844 }
1845
1846 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
1847
1848 if (ShapeUtil::IsZeroElementArray(slice->shape())) {
1849 return Status::OK();
1850 }
1851
1852 const Layout& layout = operand->shape().layout();
1853 const int64 num_dims = operand->shape().dimensions_size();
1854
  // The slice lowering finds maximal contiguous blocks of memory that can be
  // copied from the source to the target. This is done by looking at the
  // source/target layout in minor-to-major order and doing the following:
1858 //
1859 // * Find an initial segment of dimensions along which the slice uses the
1860 // whole dimension. These are the "inner" dimensions and can be folded into
1861 // the memcpy.
1862 //
1863 // * Of the remaining dimensions decide which ones require loops.
1864 //
1865 // * Implement the memcpy within the innermost loop.
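  //
  // Worked example (illustrative only): slicing f32[10,20,30]{2,1,0} with
  // starts {0,5,0}, limits {10,10,30} and unit strides. Dimension 2 is copied
  // whole, so inner_dims == {2}; dimension 1 becomes memcpy_dim with 5
  // contiguous logical elements of 30 primitives each; dimension 0 is lowered
  // as an outer loop, giving 10 memcpy calls of 150 f32 elements apiece.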
1866
1867 absl::flat_hash_set<int64> inner_dims;
1868 for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
1869 if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
1870 break;
1871 }
1872 inner_dims.insert(dim);
1873 }
1874
1875 const bool is_trivial_copy = (inner_dims.size() == num_dims);
1876 if (is_trivial_copy) {
1877 if (ShapeUtil::IsEffectiveScalar(slice->shape())) {
1878 return DefaultAction(slice);
1879 } else {
1880 return EmitMemcpy(*slice, *operand);
1881 }
1882 }
1883
1884 // The memcpy will copy elements that are logically this shape (allowed to be
1885 // scalar).
1886 const Shape logical_element_shape = ShapeUtil::FilterDimensions(
1887 [&inner_dims](int64 dim) { return inner_dims.contains(dim); },
1888 operand->shape());
1889
1890 const int64 primitive_elements_per_logical_element =
1891 ShapeUtil::ElementsIn(logical_element_shape);
1892
1893 // memcpy_dim is the innermost (in terms of layout) dimension for which the
1894 // slice does *not* just copy all the elements along the dimension.
1895 const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size());
1896
1897 const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1;
1898 // The number of logical elements that can be copied in a single call
1899 // to memcpy. We can only copy 1 element at a time if there is a non-trivial
1900 // stride.
1901 const int64 memcpy_logical_elements =
1902 memcpy_is_contiguous
1903 ? slice->slice_limits(memcpy_dim) - slice->slice_starts(memcpy_dim)
1904 : 1;
1905
1906 // Determine the dimensions that get lowered as loops.
1907 std::vector<int64> outer_dims;
1908 for (int64 i = 0; i < num_dims - inner_dims.size() - 1; ++i) {
1909 outer_dims.push_back(LayoutUtil::Major(layout, i));
1910 }
1911
1912 // Is the slice along the memcpy dimension contiguous? If not, then memcpy_dim
1913 // needs to be wrapped around a loop as well.
1914 if (!memcpy_is_contiguous) {
1915 outer_dims.push_back(memcpy_dim);
1916 }
1917
1918 llvm_ir::IrArray target_array = GetIrArrayFor(slice);
1919
1920 const int64 num_outer_loops = outer_dims.size();
1921 llvm_ir::ForLoopNest loops(IrName(slice), &b_);
1922 std::vector<llvm::Value*> target_multi_index =
1923 loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice");
1924
1925 // Only the indices for the outer dimensions have been initialized in
1926 // target_index. The rest of the indices should get initialized to 0, since
1927 // for the rest of the dimensions the copy writes to the full dimension.
1928 std::replace(target_multi_index.begin(), target_multi_index.end(),
1929 static_cast<llvm::Value*>(nullptr),
1930 static_cast<llvm::Value*>(b_.getInt64(0)));
1931 llvm_ir::IrArray::Index target_index(target_multi_index, slice->shape(),
1932 b_.getInt64Ty());
1933
1934 if (num_outer_loops > 0) {
1935 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
1936 }
1937
1938 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
1939 const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice(
1940 /*operand_shape=*/operand->shape(), /*starts=*/slice->slice_starts(),
1941 /*strides=*/slice->slice_strides(), /*builder=*/&b_);
1942
1943 llvm::Value* memcpy_dest =
1944 target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest");
1945 llvm::Value* memcpy_source =
1946 source_array.EmitArrayElementAddress(source_index, &b_, "slice.source");
1947
1948 const int64 memcpy_elements =
1949 primitive_elements_per_logical_element * memcpy_logical_elements;
1950
1951 EmitTransferElements(memcpy_dest, memcpy_source, memcpy_elements,
1952 slice->shape().element_type(), target_array,
1953 source_array);
1954
1955 if (VLOG_IS_ON(2)) {
1956 const int64 memcpy_bytes =
1957 ShapeUtil::ByteSizeOf(logical_element_shape) * memcpy_elements;
1958 VLOG(2) << " emitted copy of " << memcpy_bytes << " bytes inside "
1959 << num_outer_loops << " loops";
1960 }
1961
1962 if (num_outer_loops > 0) {
1963 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
1964 }
1965
1966 return Status::OK();
1967 }
1968
Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) {
1970 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
1971 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice));
1972 return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice);
1973 }
1974 return DefaultAction(dynamic_slice);
1975 }
1976
Status IrEmitter::HandleDynamicUpdateSlice(
1978 HloInstruction* dynamic_update_slice) {
1979 auto update = dynamic_update_slice->operand(1);
1980 if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
1981 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1982 return EmitMemcpy(*update, *dynamic_update_slice);
1983 } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice,
1984 assignment_)) {
1985 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice));
1986 auto operands = GetIrArraysForOperandsOf(dynamic_update_slice);
1987 return llvm_ir::EmitDynamicUpdateSliceInPlace(
1988 operands, GetIrArrayFor(dynamic_update_slice),
1989 IrName(dynamic_update_slice, "in_place"), &b_);
1990 }
1991 return DefaultAction(dynamic_update_slice);
1992 }
1993
Status IrEmitter::HandleRecv(HloInstruction* recv) {
1995 // TODO(b/33942983): Support Send/Recv on CPU.
1996 return Unimplemented("Recv is not implemented on CPU.");
1997 }
1998
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
2000 // TODO(b/33942983): Support Send/Recv on CPU.
2001 return Unimplemented("Recv-done is not implemented on CPU.");
2002 }
2003
Status IrEmitter::HandlePad(HloInstruction* pad) {
  // The CPU backend does not properly handle negative padding, but this is OK
  // because negative padding should have been removed by the algebraic
  // simplifier.
2007 for (auto& padding_dimension : pad->padding_config().dimensions()) {
2008 if (padding_dimension.edge_padding_low() < 0 ||
2009 padding_dimension.edge_padding_high() < 0) {
2010 return InternalErrorStrCat(
2011 "Encountered negative padding in IrEmitter on CPU. "
2012 "This should have been eliminated at the HLO level. ",
2013 pad->ToString());
2014 }
2015 }
2016
2017 // First, fill in the padding value to all output elements.
2018 TF_RETURN_IF_ERROR(EmitTargetElementLoop(
2019 pad, "initialize",
2020 [this, pad](const llvm_ir::IrArray::Index& target_index) {
2021 const HloInstruction* padding_value = pad->operand(1);
2022 llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
2023 return Load(padding_value_addr);
2024 }));
2025
2026 // Create a loop to iterate over the operand elements and update the output
2027 // locations where the operand elements should be stored.
2028 llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_);
2029 const HloInstruction* operand = pad->operand(0);
2030 const llvm_ir::IrArray::Index operand_index =
2031 loops.AddLoopsForShape(operand->shape(), "operand");
2032
2033 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2034
2035 // Load an element from the operand.
2036 llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
2037 llvm::Value* operand_data =
2038 operand_array.EmitReadArrayElement(operand_index, &b_);
2039
2040 // Compute the output index the operand element should be assigned to.
2041 // output_index := edge_padding_low + operand_index * (interior_padding + 1)
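  // For instance (added example): with edge_padding_low == 2 and
  // interior_padding == 1 on a dimension, operand index 3 lands at output
  // index 2 + 3 * (1 + 1) == 8.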
2042 const PaddingConfig& padding_config = pad->padding_config();
2043 std::vector<llvm::Value*> output_multi_index;
2044 for (size_t i = 0; i < operand_index.size(); ++i) {
2045 llvm::Value* offset =
2046 Mul(operand_index[i],
2047 b_.getInt64(padding_config.dimensions(i).interior_padding() + 1));
2048 llvm::Value* index = Add(
2049 offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low()));
2050 output_multi_index.push_back(index);
2051 }
2052
2053 // Store the operand element to the computed output location.
2054 llvm_ir::IrArray output_array(GetIrArrayFor(pad));
2055 llvm_ir::IrArray::Index output_index(
2056 output_multi_index, output_array.GetShape(), operand_index.GetType());
2057 output_array.EmitWriteArrayElement(output_index, operand_data, &b_);
2058
2059 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2060 return Status::OK();
2061 }
2062
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
2064 auto* root = fusion->fused_expression_root();
2065 if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
2066 VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
2067 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2068 FusedIrEmitter fused_emitter(&elemental_emitter);
2069 BindFusionArguments(fusion, &fused_emitter);
2070
2071 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2072 // Delegate to common implementation of fused in-place dynamic-update-slice.
2073 return llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
2074 fusion, GetIrArrayFor(fusion), &fused_emitter, &b_);
2075 } else if (fusion->IsLoopFusion()) {
2076 VLOG(3) << "HandleFusion kLoop";
2077 CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
2078 FusedIrEmitter fused_emitter(&elemental_emitter);
2079 BindFusionArguments(fusion, &fused_emitter);
2080 TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
2081 fusion->fused_expression_root()));
2082 return EmitTargetElementLoop(fusion, generator);
2083 } else if (fusion->IsOutputFusion()) {
2084 VLOG(3) << "HandleFusion kOutput";
2085 int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1;
2086 const HloInstruction* dot = root->operand(dot_op_index);
2087 CHECK_EQ(dot->opcode(), HloOpcode::kDot)
2088 << dot->ToString() << " "
2089 << fusion->fused_instructions_computation()->ToString();
2090
2091 int64 dot_lhs_param_number = dot->operand(0)->parameter_number();
2092 int64 dot_rhs_param_number = dot->operand(1)->parameter_number();
2093 int64 addend_param_number =
2094 root->operand(1 - dot_op_index)->parameter_number();
2095
2096 Shape target_shape = fusion->shape();
2097 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
2098 llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
2099
2100 llvm_ir::IrArray lhs_array(
2101 GetIrArrayFor(fusion->operand(dot_lhs_param_number)));
2102 llvm_ir::IrArray rhs_array(
2103 GetIrArrayFor(fusion->operand(dot_rhs_param_number)));
2104 llvm_ir::IrArray addend_array(
2105 GetIrArrayFor(fusion->operand(addend_param_number)));
2106
2107 TF_RETURN_IF_ERROR(EmitDotOperation(
2108 *dot, target_array, lhs_array, rhs_array, &addend_array,
2109 GetExecutableRunOptionsArgument(), &b_, mlir_context_,
2110 hlo_module_config_, target_machine_features_));
2111 return Status::OK();
2112 } else {
2113 return Unimplemented("Fusion kind not implemented on CPU");
2114 }
2115 }
2116
Status IrEmitter::HandleCall(HloInstruction* call) {
2118 HloComputation* computation = call->to_apply();
2119 llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
2120
2121 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
2122
2123 if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
2124 // ParallelTaskAssignment assigned partitions, emit call to
2125 // ParallelForkJoin.
2126 std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
2127 {}, &b_, computation->name(),
2128 /*return_value_buffer=*/emitted_value_[call],
2129 /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
2130 /*buffer_table_arg=*/GetBufferTableArgument(),
2131 /*profile_counters_arg=*/GetProfileCountersArgument());
2132
2133 HloInstruction* root = computation->root_instruction();
2134 TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin(
2135 call_args, root->shape(), root->outer_dimension_partitions(), &b_,
2136 call_ir_function, computation->name()));
2137 } else {
2138 EmitGlobalCall(*computation, computation->name());
2139 }
2140
2141 return Status::OK();
2142 }
2143
Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
2145 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2146 std::vector<llvm::Value*> dynamic_dims;
2147 int32 raw_data_size =
2148 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
2149 llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
2150 llvm::Value* raw_buffer =
2151 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
2152 for (int64 i = 1; i < hlo->operand_count(); ++i) {
2153 const int64 dim_index = i - 1;
2154 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
2155 llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
2156
2157 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2158 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2159 b_.CreateStore(dyn_dim_size,
2160 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
2161 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2162 /*isSigned=*/true,
2163 "i64_dyn_dim_size"));
2164 }
2165
2166 llvm_ir::IrArray data_array = GetIrArrayFor(hlo);
2167 // Pseudo code for sliceToDynamic:
2168 //
2169 // for (index i in dynamic_dim)
2170 // dest_index = delinearize(linearize(i, dynamic_dim), static_dim)
2171 // dest[dest_index] = source[i]
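  //
  // Illustrative trace (added example, assuming the default row-major layout):
  // for a static output shape f32[2,4] with dynamic bounds [2,3], the element
  // at dynamic index i = (1,2) linearizes to 1*3 + 2 == 5 over the dynamic
  // bounds and delinearizes to (1,1) over the static shape, so source[1,2] is
  // stored at dest[1,1], i.e. at flat offset 5, the last of the six valid
  // elements of the compacted dynamic data.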
2172 auto loop_body_emitter =
2173 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2174 llvm::Value* source_element =
2175 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_);
2176 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2177 // Delinearize the index based on the static shape.
2178 llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(),
2179 &b_);
2180 data_array.EmitWriteArrayElement(dest_index, source_element, &b_);
2181 return Status::OK();
2182 };
2183 return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(),
2184 dynamic_dims, &b_)
2185 .EmitLoop(IrName(hlo));
2186 }
2187
Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
2189 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2190
2191 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
2192 assignment_.GetUniqueSlice(hlo, {0}));
2193 std::vector<llvm::Value*> dynamic_dims;
2194 std::vector<llvm::Value*> tuple_operand_ptrs;
2195 const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
2196 const Shape& input_shape = hlo->operand(0)->shape();
2197 llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
2198 llvm_ir::IrArray data_array(data_address, data_shape);
2199 llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
2200 llvm::Value* raw_buffer =
2201 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
2202 int64 raw_data_size =
2203 ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
2204
2205 // Put a placeholder for the data array's pointer
2206 tuple_operand_ptrs.push_back(data_array.GetBasePointer());
  // PadToStatic takes a dynamic tensor as input and produces a variadic number
  // of outputs:
  //   (static_tensor, dynamic_dim_0, dynamic_dim_1, ... )
  // Dynamic dimension sizes start at output index 1.
2210 for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
2211 // Read from the metadata section of the dynamic input (operand 0).
2212 const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
2213 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
2214 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
2215 assignment_.GetUniqueSlice(hlo, {i}));
2216 llvm::Value* dest_dim_size_address =
2217 EmitBufferPointer(dim_size_slice, data_shape);
2218 const int64 dim_index = i - 1;
2219 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
2220 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
2221 llvm::Value* dyn_dim_size = b_.CreateLoad(
2222 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
2223 "dyn_dim_size");
2224 b_.CreateStore(dyn_dim_size,
2225 b_.CreateBitCast(dest_dim_size_address,
2226 b_.getInt32Ty()->getPointerTo()));
2227 dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(),
2228 /*isSigned=*/true,
2229 "i64_dyn_dim_size"));
2230 tuple_operand_ptrs.push_back(dest_dim_size_address);
2231 }
2232
2233 // Pseudo code for padToStatic:
2234 //
2235 // for (index i in dynamic_dim)
  // source_index = delinearize(linearize(i, dynamic_dim), static_dim)
2237 // dest[i] = source[source_index]
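  //
  // With the same illustrative shapes as in HandleSliceToDynamic above
  // (static f32[2,4], dynamic bounds [2,3]), this is the inverse mapping:
  // i = (1,2) linearizes to 5 over the dynamic bounds and delinearizes to
  // (1,1) over the static input, so dest[1,2] = source[1,1].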
2238 auto loop_body_emitter =
2239 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
2240 llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_);
2241 llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_);
2242 llvm::Value* source_element =
2243 GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_);
2244 data_array.EmitWriteArrayElement(array_index, source_element, &b_);
2245 return Status::OK();
2246 };
2247 TF_RETURN_IF_ERROR(
2248 llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_)
2249 .EmitLoop(IrName(hlo)));
2250
2251 // Emit static tensor and dynamic sizes as one tuple.
2252 llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
2253 return Status::OK();
2254 }
2255
Status IrEmitter::HandleTopK(HloInstruction* hlo) {
2257 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
2258 const HloInstruction* input = hlo->operand(0);
2259 const int64 k = hlo->shape().tuple_shapes(0).dimensions().back();
2260 const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
2261 TF_RET_CHECK(input->shape().element_type() == F32);
2262 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2263 hlo->shape().tuple_shapes(0).layout()));
2264 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
2265 hlo->shape().tuple_shapes(1).layout()));
2266 TF_RET_CHECK(
2267 LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
2268
2269 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
2270 assignment_.GetUniqueSlice(hlo->operand(0), {}));
2271 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice,
2272 assignment_.GetUniqueSlice(hlo, {0}));
2273 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice,
2274 assignment_.GetUniqueSlice(hlo, {1}));
2275 llvm::Value* values_ptr =
2276 EmitBufferPointer(values_slice, hlo->operand(0)->shape());
2277 llvm::Value* out_values_ptr =
2278 EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0));
2279 llvm::Value* out_indices_ptr =
2280 EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1));
2281 EmitCallToFunc(
2282 runtime::kTopKF32SymbolName,
2283 {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1),
2284 b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k),
2285 BitCast(values_ptr, b_.getFloatTy()->getPointerTo()),
2286 BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()),
2287 BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())},
2288 b_.getVoidTy());
2289
2290 llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr},
2291 &b_);
2292 return Status::OK();
2293 }
2294
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
2296 if (custom_call->custom_call_target() == "PadToStatic") {
2297 return HandlePadToStatic(custom_call);
2298 }
2299 if (custom_call->custom_call_target() == "SliceToDynamic") {
2300 return HandleSliceToDynamic(custom_call);
2301 }
2302 if (custom_call->custom_call_target() == "TopK") {
2303 return HandleTopK(custom_call);
2304 }
2305 absl::Span<HloInstruction* const> operands(custom_call->operands());
2306 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2307 llvm::AllocaInst* operands_alloca =
2308 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2309 i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_);
2310 for (size_t i = 0; i < operands.size(); ++i) {
2311 const HloInstruction* operand = operands[i];
2312 llvm::Value* operand_as_i8ptr =
2313 PointerCast(GetEmittedValueFor(operand), i8_ptr_type);
2314 llvm::Value* slot_in_operands_alloca =
2315 InBoundsGEP(operands_alloca, {b_.getInt64(i)});
2316 Store(operand_as_i8ptr, slot_in_operands_alloca);
2317 }
2318 if (emit_code_for_msan_) {
2319 // Mark the alloca as initialized for msan. The buffer gets read by the
2320 // custom callee, which might be msan-instrumented.
2321 // TODO(b/66051036): Run the msan instrumentation pass instead.
2322 const llvm::DataLayout& dl = module_->getDataLayout();
2323 llvm::Type* intptr_type = b_.getIntPtrTy(dl);
2324 EmitCallToFunc(
2325 "__msan_unpoison",
2326 {PointerCast(operands_alloca, i8_ptr_type),
2327 llvm::ConstantInt::get(
2328 intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)},
2329 b_.getVoidTy());
2330 }
2331
2332 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
2333 // Write the tuple table if the output is a tuple.
2334 if (custom_call->shape().IsTuple()) {
2335 std::vector<llvm::Value*> base_ptrs;
2336 for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
2337 ++i) {
2338 const Shape& elem_shape =
2339 ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
2340 TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented";
2341 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
2342 assignment_.GetUniqueSlice(custom_call, {i}));
2343 llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
2344 base_ptrs.push_back(addr);
2345 }
2346 llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_);
2347 }
2348 auto* output_address_arg =
2349 PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
2350
2351 EmitCallToFunc(custom_call->custom_call_target(),
2352 {output_address_arg, operands_alloca}, b_.getVoidTy());
2353
2354 return Status::OK();
2355 }
2356
Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
2358 // Precondition: Condition computation must return a scalar bool.
2359 HloComputation* condition = xla_while->while_condition();
2360 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2361 condition->root_instruction()->shape().element_type() == PRED)
2362 << "While condition computation must return bool; got: "
2363 << ShapeUtil::HumanString(condition->root_instruction()->shape());
2364 // Check that all while-related buffers share an allocation slice.
2365 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2366 xla_while->shape(),
2367 [this, &xla_while](const Shape& /*subshape*/,
2368 const ShapeIndex& index) -> Status {
2369 auto check = [this](const HloInstruction* a, const HloInstruction* b,
2370 const ShapeIndex& index) {
2371 const BufferAllocation::Slice slice_a =
2372 assignment_.GetUniqueSlice(a, index).ConsumeValueOrDie();
2373 const BufferAllocation::Slice slice_b =
2374 assignment_.GetUniqueSlice(b, index).ConsumeValueOrDie();
2375 if (slice_a != slice_b) {
2376 return InternalError(
2377 "instruction %s %s does not share slice with "
2378 "instruction %s %s",
2379 a->ToString(), slice_a.ToString(), b->ToString(),
2380 slice_b.ToString());
2381 }
2382 return Status::OK();
2383 };
2384 TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
2385 TF_RETURN_IF_ERROR(check(
2386 xla_while, xla_while->while_condition()->parameter_instruction(0),
2387 index));
2388 TF_RETURN_IF_ERROR(
2389 check(xla_while, xla_while->while_body()->parameter_instruction(0),
2390 index));
2391 TF_RETURN_IF_ERROR(check(
2392 xla_while, xla_while->while_body()->root_instruction(), index));
2393 return Status::OK();
2394 }));
2395
2396 // Set emitted value to that of 'init' with which it shares an allocation.
2397 const HloInstruction* init = xla_while->operand(0);
2398 emitted_value_[xla_while] = GetEmittedValueFor(init);
2399
2400 // Generating:
2401 // while (Condition(while_result)) {
2402 // // CopyInsertion pass inserts copies which enable 'while_result' to
2403 // // be passed back in as 'Body' parameter.
2404 // while_result = Body(while_result); // Insert
2405 // }
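  //
  // Sketch of the basic-block structure emitted below (added for clarity):
  //
  //   header: call condition computation; br (cond != 0), body, exit
  //   body:   call body computation; br header
  //   exit:   continue with the rest of the enclosing function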
2406
2407 // Terminates the current block with a branch to a while header.
2408 llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
2409 module_->getContext(), IrName(xla_while, "header"),
2410 compute_function_->function());
2411 Br(header_bb);
2412 b_.SetInsertPoint(header_bb);
2413
2414 // Calls the condition function to determine whether to proceed with the
2415 // body. It must return a bool, so use the scalar call form.
2416 EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
2417 llvm::Value* while_predicate = ICmpNE(
2418 Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
2419 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
2420
2421 // Branches to the body or to the while exit depending on the condition.
2422 llvm::BasicBlock* body_bb =
2423 llvm::BasicBlock::Create(module_->getContext(), IrName(xla_while, "body"),
2424 compute_function_->function());
2425 llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
2426 module_->getContext(), IrName(xla_while, "exit"));
2427 CondBr(while_predicate, body_bb, exit_bb);
2428
2429 // Calls the body function from the body block.
2430 b_.SetInsertPoint(body_bb);
2431
2432 // Calls the body function.
2433 EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
2434
2435 // Finishes with a branch back to the header.
2436 Br(header_bb);
2437
2438 // Adds the exit block to the function and sets the insert point there.
2439 compute_function_->function()->getBasicBlockList().push_back(exit_bb);
2440 b_.SetInsertPoint(exit_bb);
2441
2442 return Status::OK();
2443 }
2444
StatusOr<bool> IrEmitter::EmitFastConcatenate(
2446 HloInstruction* concatenate, absl::Span<HloInstruction* const> operands,
2447 string* failure_reason) {
2448 if (ShouldEmitParallelLoopFor(*concatenate)) {
2449 *failure_reason =
2450 "cannot generate memcpy-based concat for the parallel CPU backend";
2451 return false;
2452 }
2453
2454 const Shape& output_shape = concatenate->shape();
2455 for (auto* op : operands) {
2456 if (!LayoutUtil::Equal(op->shape().layout(), output_shape.layout())) {
2457 *failure_reason = "operand has mismatching layouts";
2458 return false;
2459 }
2460 }
2461
2462 // We split the dimensions into three categories: the dimension over which we
2463 // are concatenating (concat_dim), the dimensions that are minor to it
2464 // (inner_dims) and the dimensions that are major to it (outer_dims).
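  //
  // Example (added illustration): concatenating f32[2,3,4]{2,1,0} and
  // f32[2,5,4]{2,1,0} along dimension 1 gives concat_dim == 1,
  // inner_dims == {2} and outer_dims == {0}; inside the loop over dimension 0
  // the code memcpys 4*3 == 12 elements from the first operand and then
  // 4*5 == 20 elements from the second, advancing the byte offset into the
  // target region by 48 and then 80 bytes.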
2465
2466 int64 concat_dim = concatenate->dimensions(0);
2467 const Layout& output_layout = output_shape.layout();
2468 auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
2469 auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
2470
2471 std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
2472 std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
2473 output_min2maj.end());
2474
2475 llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
2476
2477 TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate));
2478 llvm_ir::IrArray target_array = GetIrArrayFor(concatenate);
2479
2480 llvm_ir::ForLoopNest loops(IrName(concatenate), &b_);
2481 std::vector<llvm::Value*> target_multi_index =
2482 loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat");
2483 std::replace(target_multi_index.begin(), target_multi_index.end(),
2484 static_cast<llvm::Value*>(nullptr),
2485 static_cast<llvm::Value*>(b_.getInt64(0)));
2486 llvm_ir::IrArray::Index target_index(target_multi_index, output_shape,
2487 b_.getInt64Ty());
2488
2489 if (!outer_dims.empty()) {
2490 SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
2491 }
2492
2493 PrimitiveType primitive_type = output_shape.element_type();
2494 unsigned primitive_type_size =
2495 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2496
2497 // Contiguous subregions from each operand to the concatenate contribute to a
2498 // contiguous subregion in the target buffer starting at target_region_begin.
2499 llvm::Value* target_region_begin = BitCast(
2500 target_array.EmitArrayElementAddress(target_index, &b_, "target_region"),
2501 i8_ptr_type);
2502 int64 byte_offset_into_target_region = 0;
2503
2504 int64 inner_dims_product =
2505 std::accumulate(inner_dims.begin(), inner_dims.end(), 1l,
2506 [&](int64 product, int64 inner_dim) {
2507 return product * output_shape.dimensions(inner_dim);
2508 });
2509
2510 // For each operand, emit a memcpy from the operand to the target of size
2511 // equal to the product of inner dimensions.
2512 for (HloInstruction* operand : operands) {
2513 const Shape& input_shape = operand->shape();
2514 llvm_ir::IrArray source_array = GetIrArrayFor(operand);
2515 llvm_ir::IrArray::Index source_index(target_multi_index, operand->shape(),
2516 b_.getInt64Ty());
2517 llvm::Value* copy_source_address = BitCast(
2518 source_array.EmitArrayElementAddress(source_index, &b_, "src_addr"),
2519 i8_ptr_type);
2520
2521 llvm::Value* copy_target_address =
2522 GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region));
2523
2524 EmitTransferElements(
2525 copy_target_address, copy_source_address,
2526 inner_dims_product * input_shape.dimensions(concat_dim), primitive_type,
2527 target_array, source_array);
2528
2529 byte_offset_into_target_region += inner_dims_product *
2530 input_shape.dimensions(concat_dim) *
2531 primitive_type_size;
2532 }
2533
2534 if (!outer_dims.empty()) {
2535 SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
2536 }
2537
2538 return true;
2539 }
2540
llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
2542 absl::Span<llvm::Value* const> arguments) {
2543 llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
2544 std::vector<llvm::Value*> call_args;
2545 call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
2546 absl::c_copy(arguments, std::back_inserter(call_args));
2547 return b_.CreateCall(
2548 b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
2549 "printf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty},
2550 /*isVarArg=*/true)),
2551 call_args);
2552 }
2553
llvm::Value* IrEmitter::EmitCallToFunc(
2555 std::string func_name, const std::vector<llvm::Value*>& arguments,
2556 llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
2557 bool only_accesses_inaccessible_mem_or_arg_mem) {
2558 std::vector<llvm::Type*> types;
2559 types.reserve(arguments.size());
2560 absl::c_transform(arguments, std::back_inserter(types),
2561 [&](llvm::Value* val) { return val->getType(); });
2562 llvm::FunctionType* func_type =
2563 llvm::FunctionType::get(return_type, types, /*isVarArg=*/false);
2564 auto func = llvm::dyn_cast<llvm::Function>(
2565 module_->getOrInsertFunction(func_name, func_type).getCallee());
2566 func->setCallingConv(llvm::CallingConv::C);
2567 if (does_not_throw) {
2568 func->setDoesNotThrow();
2569 }
2570 if (only_accesses_arg_memory) {
2571 func->setOnlyAccessesArgMemory();
2572 }
2573 if (only_accesses_inaccessible_mem_or_arg_mem) {
2574 func->setOnlyAccessesInaccessibleMemOrArgMem();
2575 }
2576 return b_.CreateCall(func, arguments);
2577 }
2578
void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
2580 int64 element_count,
2581 PrimitiveType primitive_type,
2582 const llvm_ir::IrArray& target_array,
2583 const llvm_ir::IrArray& source_array) {
2584 unsigned primitive_type_size =
2585 ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
2586 llvm::Align element_alignment(tensorflow::MathUtil::GCD<unsigned>(
2587 primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)));
2588 llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
2589 llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
2590
2591 if (element_count == 1) {
2592 auto* load_instruction =
2593 AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment);
2594 source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction);
2595 auto* store_instruction =
2596 AlignedStore(load_instruction, BitCast(target, primitive_ptr_type),
2597 element_alignment);
2598 target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
2599 } else {
2600 auto* memcpy_instruction = b_.CreateMemCpy(
2601 target, /*DstAlign=*/llvm::Align(element_alignment), source,
2602 /*SrcAlign=*/llvm::Align(element_alignment),
2603 element_count * primitive_type_size);
2604
2605 // The memcpy does the load and the store internally. The aliasing related
2606 // metadata has to reflect that.
2607 std::map<int, llvm::MDNode*> merged_metadata =
2608 llvm_ir::MergeMetadata(&module_->getContext(), source_array.metadata(),
2609 target_array.metadata());
2610 for (const auto& kind_md_pair : merged_metadata) {
2611 memcpy_instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
2612 }
2613 }
2614 }
2615
Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
  absl::Span<HloInstruction* const> operands(concatenate->operands());
  string failure_reason;
  TF_ASSIGN_OR_RETURN(
      bool successful,
      EmitFastConcatenate(concatenate, operands, &failure_reason));
  if (successful) {
    VLOG(1) << "Emitted fast concatenate for " << concatenate->ToString();
    return Status::OK();
  }

  VLOG(1) << "Could not emit fast concatenate for " << concatenate->ToString()
          << ": " << failure_reason;

  return DefaultAction(concatenate);
}

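// Lowers kConditional: a PRED branch index becomes an if/else on the loaded
// predicate, and an S32 branch index becomes a switch whose default case is
// the last branch computation.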
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
  auto branch_index = conditional->operand(0);
  int num_branches = conditional->branch_count();
  TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) &&
               (branch_index->shape().element_type() == PRED ||
                branch_index->shape().element_type() == S32))
      << "Branch index on a conditional must be scalar bool or int32; got: "
      << ShapeUtil::HumanString(branch_index->shape());

  for (int b = 0; b < num_branches; ++b) {
    HloComputation* br_computation = conditional->branch_computation(b);
    TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(),
                                  br_computation->root_instruction()->shape()))
        << "Shape of conditional should be same as the shape of the " << b
        << "th branch computation; got: "
        << ShapeUtil::HumanString(conditional->shape()) << " and "
        << ShapeUtil::HumanString(br_computation->root_instruction()->shape());
  }

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));

  if (branch_index->shape().element_type() == PRED) {
    // Emit an if-else to LLVM:
    //   if (pred)
    //     cond_result = true_computation(true_operand)
    //   else
    //     cond_result = false_computation(false_operand)
    llvm::LoadInst* pred_value = Load(
        GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value");
    llvm::Value* pred_cond =
        ICmpNE(pred_value,
               llvm::ConstantInt::get(
                   llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
               "boolean_predicate");
    llvm_ir::LlvmIfData if_data =
        llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);

    SetToFirstInsertPoint(if_data.true_block, &b_);
    EmitGlobalCall(*conditional->branch_computation(0),
                   IrName(conditional, "_true"));

    SetToFirstInsertPoint(if_data.false_block, &b_);
    EmitGlobalCall(*conditional->branch_computation(1),
                   IrName(conditional, "_false"));

    SetToFirstInsertPoint(if_data.after_block, &b_);
    return Status::OK();
  }
  // We emit a switch statement to LLVM:
  //   switch (branch_index) {
  //     default:
  //       result = branch_computations[num_branches-1](operands[num_branches-1]);
  //       break;
  //     case 0:
  //       result = branch_computations[0](operands[0]); break;
  //     case 1:
  //       result = branch_computations[1](operands[1]); break;
  //     ...
  //     case [[num_branches-2]]:
  //       result = branch_computations[num_branches-2](operands[num_branches-2]);
  //       break;
  //   }
  llvm::LoadInst* branch_index_value = Load(
      GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value");

  auto case_block = b_.GetInsertBlock();
  llvm::BasicBlock* after_block;
  // Add a terminator to the case block, if necessary.
  if (case_block->getTerminator() == nullptr) {
    after_block = llvm_ir::CreateBasicBlock(nullptr, "case-after", &b_);
    b_.SetInsertPoint(case_block);
    b_.CreateBr(after_block);
  } else {
    after_block =
        case_block->splitBasicBlock(b_.GetInsertPoint(), "case-after");
  }
  // Our basic block should now end with an unconditional branch. Remove it;
  // we're going to replace it with a switch-based branch.
  case_block->getTerminator()->eraseFromParent();

  // Lower the default branch computation.
  auto default_block = llvm_ir::CreateBasicBlock(nullptr, "case-default", &b_);
  b_.SetInsertPoint(default_block);
  EmitGlobalCall(*conditional->branch_computation(num_branches - 1),
                 IrName(conditional, "_default"));
  b_.CreateBr(after_block);

  // Prepare the switch (branch_index) { ... } instruction.
  b_.SetInsertPoint(case_block);
  llvm::SwitchInst* case_inst =
      b_.CreateSwitch(branch_index_value, default_block, num_branches - 1);
  // Lower each branch's computation.
  for (int b = 0; b < num_branches - 1; ++b) {  // last branch is default
    // Lower the case b: { ... ; break; } computation.
    auto branch_block =
        llvm_ir::CreateBasicBlock(nullptr, absl::StrCat("case-branch", b), &b_);
    b_.SetInsertPoint(branch_block);
    EmitGlobalCall(*conditional->branch_computation(b),
                   IrName(conditional, absl::StrCat("_branch", b)));
    b_.CreateBr(after_block);
    case_inst->addCase(b_.getInt32(b), branch_block);
  }

  SetToFirstInsertPoint(after_block, &b_);
  return Status::OK();
}

Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
  TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
  // No code to generate, but we need to emit an address for book-keeping.
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
  return Status::OK();
}

Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
  // AddDependency just forwards its zeroth operand.
  emitted_value_[add_dependency] =
      GetEmittedValueFor(add_dependency->operand(0));
  return Status::OK();
}

Status IrEmitter::HandleRng(HloInstruction* rng) {
  return Unimplemented("Rng should be expanded for CPU.");
}

Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) {
  VLOG(2) << "RngGetAndUpdateState: " << rng_state->ToString();
  llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
      Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
      &b_);

  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(rng_state));
  llvm::Value* address = GetEmittedValueFor(rng_state);

  // The buffer has an array type while the value is an i128. Cast the
  // buffer to i128 type to store the value.
  address = BitCast(address, llvm::PointerType::get(
                                 old_state->getType()->getScalarType(),
                                 address->getType()->getPointerAddressSpace()));
  llvm::StoreInst* store = Store(old_state, address);
  store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType(
      rng_state->shape().element_type())));

  return Status::OK();
}

Status IrEmitter::FinishVisit(HloInstruction* root) {
  // When this method is called, we should have already emitted an IR value for
  // the root (return) op. The IR value holds the address of the buffer holding
  // the value. If the root is a constant or parameter, we perform a memcpy from
  // this buffer to the retval buffer of the computation. Otherwise, there's
  // nothing to do since the result was already written directly into the output
  // buffer.
  VLOG(2) << "FinishVisit root: " << root->ToString();
  if (root->opcode() == HloOpcode::kOutfeed) {
    VLOG(2) << " outfeed with value: "
            << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0)));
  } else {
    VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root));
  }

  auto record_complete_computation = [&](llvm::Value* prof_counter) {
    if (prof_counter) {
      profiling_state_.RecordCompleteComputation(&b_, prof_counter);
    }
  };

  // For the entry computation this increment is cumulative of embedded
  // computations since it includes cycles spent in computations invoked by
  // While, Call etc.
  record_complete_computation(GetProfileCounterFor(*root->parent()));
  return Status::OK();
}

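// Returns a GEP into the profile counters array for `hlo` (an instruction or
// a computation), or nullptr if no profile index was assigned to it.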
template <typename T>
llvm::Value* IrEmitter::GetProfileCounterCommon(
    const T& hlo,
    const std::unordered_map<const T*, int64>& profile_index_map) {
  auto it = profile_index_map.find(&hlo);
  if (it == profile_index_map.end()) {
    return nullptr;
  }

  int64 prof_counter_idx = it->second;
  string counter_name = IrName("prof_counter", hlo.name());
  return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx),
             counter_name);
}

llvm::Value* IrEmitter::GetProfileCounterFor(
    const HloInstruction& instruction) {
  return GetProfileCounterCommon<HloInstruction>(instruction,
                                                 instruction_to_profile_idx_);
}

llvm::Value* IrEmitter::GetProfileCounterFor(
    const HloComputation& computation) {
  return GetProfileCounterCommon<HloComputation>(computation,
                                                 computation_to_profile_idx_);
}

void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b,
                                                     llvm::Value* prof_counter,
                                                     llvm::Value* cycle_end,
                                                     llvm::Value* cycle_start) {
  auto* cycle_diff = b->CreateSub(cycle_end, cycle_start);
  llvm::LoadInst* old_cycle_count =
      b->CreateLoad(prof_counter, "old_cycle_count");
  auto* new_cycle_count =
      b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
  b->CreateStore(new_cycle_count, prof_counter);
}

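// Reads the CPU cycle counter: the x86 rdtscp intrinsic when use_rdtscp_ is
// set (taking only the TSC value from the returned struct), and LLVM's
// generic readcyclecounter intrinsic otherwise.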
llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  if (!use_rdtscp_) {
    llvm::Function* func_llvm_readcyclecounter =
        llvm::Intrinsic::getDeclaration(module,
                                        llvm::Intrinsic::readcyclecounter);
    return b->CreateCall(func_llvm_readcyclecounter);
  }
  llvm::Function* func_llvm_x86_rdtscp =
      llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
  llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp);
  return b->CreateExtractValue(rdtscp_call, {0});
}

void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo) {
  auto* cycle_start = ReadCycleCounter(b);
  cycle_start->setName(IrName(hlo, "cycle_start"));
  cycle_starts_[hlo] = cycle_start;
  if (first_read_cycle_start_ == nullptr) {
    first_read_cycle_start_ = cycle_start;
  }
}

void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b,
                                                 HloInstruction* hlo,
                                                 llvm::Value* prof_counter) {
  auto* cycle_end = ReadCycleCounter(b);
  cycle_end->setName(IrName(hlo, "cycle_end"));
  auto* cycle_start = cycle_starts_[hlo];
  UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start);
  last_read_cycle_end_ = cycle_end;
}

void IrEmitter::ProfilingState::RecordCompleteComputation(
    llvm::IRBuilder<>* b, llvm::Value* prof_counter) {
  if (last_read_cycle_end_ && first_read_cycle_start_) {
    UpdateProfileCounter(b, prof_counter, last_read_cycle_end_,
                         first_read_cycle_start_);
  }
}

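// Emits a call to the CPU runtime's tracing-start entry point (resolved via
// runtime::kTracingStartSymbolName) and records the returned activity id so
// EmitTracingEnd can close it. Does nothing when tracing is disabled.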
void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b,
                                               HloInstruction* hlo,
                                               llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo();
  llvm::Type* void_ptr_type =
      int8_ptr_type;  // LLVM does not have a void*, so we use an int8* instead.
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingStartSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }
  auto* hlo_name = b->CreateGlobalStringPtr(hlo->name());
  auto* activity_id =
      b->CreateCall(trace_func, {b->CreateBitCast(run_options, void_ptr_type),
                                 b->CreateBitCast(hlo_name, int8_ptr_type)});
  activity_id->setName(IrName(hlo, "activity_id"));
  activity_ids_[hlo] = activity_id;
}

void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b,
                                             HloInstruction* hlo,
                                             llvm::Value* run_options) {
  if (!enabled_) {
    return;
  }

  llvm::Type* void_ptr_type =
      b->getInt8Ty()->getPointerTo();  // LLVM does not have a void*, so we use
                                       // an int8* instead.
  llvm::FunctionType* fn_type =
      llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()},
                              /*isVarArg=*/false);

  llvm::Function* function = b->GetInsertBlock()->getParent();
  llvm::Module* module = function->getParent();
  const char* fn_name = runtime::kTracingEndSymbolName;
  llvm::FunctionCallee trace_func =
      module->getOrInsertFunction(fn_name, fn_type);
  if (auto* fn = llvm::dyn_cast<llvm::Function>(trace_func.getCallee())) {
    fn->setCallingConv(llvm::CallingConv::C);
    fn->setDoesNotThrow();
    fn->setOnlyAccessesArgMemory();
  }
  auto* activity_id = activity_ids_.at(hlo);
  b->CreateCall(trace_func,
                {b->CreateBitCast(run_options, void_ptr_type), activity_id});
}

namespace {
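// Returns true for HLOs that emit essentially no work (address or tuple
// bookkeeping only); these are excluded from the cpu_traceme tracing path
// unless they also carry an explicit profile index.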
bool IsHloVeryCheap(const HloInstruction* hlo) {
  return hlo->opcode() == HloOpcode::kBitcast ||
         hlo->opcode() == HloOpcode::kTuple ||
         hlo->opcode() == HloOpcode::kGetTupleElement ||
         hlo->opcode() == HloOpcode::kParameter ||
         hlo->opcode() == HloOpcode::kConstant;
}
}  // namespace

Status IrEmitter::Preprocess(HloInstruction* hlo) {
  VLOG(3) << "Visiting: " << hlo->ToString();
  // When profiling is enabled, trace the same HLOs that the profiler does.
  if (instruction_to_profile_idx_.count(hlo) ||
      (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
    tracing_state_.EmitTracingStart(&b_, hlo,
                                    GetExecutableRunOptionsArgument());
    profiling_state_.RecordCycleStart(&b_, hlo);
  }
  return Status::OK();
}

Status IrEmitter::Postprocess(HloInstruction* hlo) {
  if (auto* prof_counter = GetProfileCounterFor(*hlo)) {
    profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter);
  }
  // When profiling is enabled, trace the same HLOs that the profiler does.
  if (instruction_to_profile_idx_.count(hlo) ||
      (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) {
    tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument());
  }
  return Status::OK();
}

llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) {
  llvm::Value* value_for_op = GetEmittedValueFor(hlo);

  llvm_ir::IrArray array(value_for_op, hlo->shape());
  AddAliasingInformationToIrArray(*hlo, &array);
  return array;
}

std::vector<llvm_ir::IrArray> IrEmitter::GetIrArraysForOperandsOf(
    const HloInstruction* hlo) {
  std::vector<llvm_ir::IrArray> arrays;
  std::transform(
      hlo->operands().begin(), hlo->operands().end(),
      std::back_inserter(arrays),
      [&](const HloInstruction* operand) { return GetIrArrayFor(operand); });
  return arrays;
}

llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
  auto it = emitted_value_.find(hlo);
  if (it == emitted_value_.end()) {
    LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
  }
  return it->second;
}

llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
  return llvm_ir::ShapeToIrType(shape, module_);
}

llvm::Value* IrEmitter::GetProfileCountersArgument() {
  return compute_function_->profile_counters_arg();
}

llvm::Value* IrEmitter::GetBufferTableArgument() {
  return compute_function_->buffer_table_arg();
}

llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
  return compute_function_->exec_run_options_arg();
}

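// Returns the address of the buffer backing `slice` when the allocation is
// thread-local: entry parameters are loaded from the params argument, while
// other thread-local slices get a function-entry alloca cached per
// (function, slice) pair.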
llvm::Value* IrEmitter::EmitThreadLocalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
    auto param_it =
        computation_parameter_allocations_.find(slice.allocation()->index());
    if (param_it != computation_parameter_allocations_.end()) {
      int64 param_number = param_it->second;
      // We have to access the parameter at offset param_number in the params
      // array. The code generated here is equivalent to this C code:
      //
      //   i8* param_address_untyped = params[param_number];
      //   Param* param_address_typed = (Param*)param_address_untyped;
      //
      // Where Param is the actual element type of the underlying buffer (for
      // example, float for an XLA F32 element type).
      llvm::Value* params = compute_function_->parameters_arg();
      llvm::Value* param_address_offset =
          llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
      llvm::LoadInst* param_address_untyped = Load(param_address_offset);

      if (!target_shape.IsOpaque()) {
        AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
        AttachDereferenceableMetadataForLoad(param_address_untyped,
                                             target_shape);
      }
      return param_address_untyped;
    }

    // Thread-local allocations should only be assigned a single buffer.
    const auto& assigned_buffers = allocation.assigned_buffers();
    CHECK_EQ(1, assigned_buffers.size());
    const Shape& shape = assigned_buffers.begin()->first->shape();

    std::pair<llvm::Function*, BufferAllocation::Slice> key = {
        compute_function_->function(), slice};
    auto buf_it = thread_local_buffers_.find(key);
    if (buf_it == thread_local_buffers_.end()) {
      llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
          IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()),
          &b_, MinimumAlignmentForShape(target_shape));
      auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
      CHECK(it_inserted_pair.second);
      buf_it = it_inserted_pair.first;
    }
    return buf_it->second;
  }();
  return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo());
}

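// Returns the address of the buffer backing `slice` for a globally allocated
// buffer: a load from the buffer table (marked invariant when the debug
// option allows), adjusted by the slice offset and cast to the IR type of
// `target_shape`.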
llvm::Value* IrEmitter::EmitGlobalBufferPointer(
    const BufferAllocation::Slice& slice, const Shape& target_shape) {
  const BufferAllocation& allocation = *slice.allocation();
  llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
      GetBufferTableArgument(), slice.index(), &b_);
  llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr);
  if (hlo_module_config_.debug_options()
          .xla_llvm_enable_invariant_load_metadata()) {
    tempbuf_address_base->setMetadata(
        llvm::LLVMContext::MD_invariant_load,
        llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
  }
  AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size());
  AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size());

  llvm::Value* tempbuf_address_untyped = tempbuf_address_base;
  if (slice.offset() > 0) {
    // Adjust the address to account for the slice offset.
    tempbuf_address_untyped =
        InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
  }
  return BitCast(tempbuf_address_untyped,
                 IrShapeType(target_shape)->getPointerTo());
}

llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice,
                                          const Shape& target_shape) {
  if (slice.allocation()->is_thread_local()) {
    return EmitThreadLocalBufferPointer(slice, target_shape);
  } else if (slice.allocation()->is_constant()) {
    return BitCast(
        FindOrDie(constant_buffer_to_global_, slice.allocation()->index()),
        IrShapeType(target_shape)->getPointerTo());
  } else {
    return EmitGlobalBufferPointer(slice, target_shape);
  }
}

Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
  const Shape& target_shape = op->shape();
  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
                      assignment_.GetUniqueTopLevelSlice(op));
  llvm::Value* addr = EmitBufferPointer(slice, target_shape);
  addr->setName(IrName(op));
  emitted_value_[op] = addr;
  return Status::OK();
}

Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op,
    const llvm_ir::ElementGenerator& element_generator) {
  return EmitTargetElementLoop(target_op, /*desc=*/"", element_generator);
}

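// Emits a loop nest that computes every element of `target_op` using
// `element_generator`. Tuple-shaped fusion/reduce/reduce-window outputs get
// one IrArray per tuple element followed by the root tuple; everything else
// gets a single loop, parallelized with dynamic bounds when
// ShouldEmitParallelLoopFor says so.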
Status IrEmitter::EmitTargetElementLoop(
    HloInstruction* target_op, absl::string_view desc,
    const llvm_ir::ElementGenerator& element_generator) {
  VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();

  const Shape& target_shape = target_op->shape();
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op));
  llvm_ir::IrArray target_array = GetIrArrayFor(target_op);

  if (target_shape.IsTuple() &&
      (target_op->opcode() == HloOpcode::kFusion ||
       target_op->opcode() == HloOpcode::kReduce ||
       target_op->opcode() == HloOpcode::kReduceWindow)) {
    // For multi-output ops (fusion, variadic reduce/reduce-window), emit each
    // tuple element and then the root tuple itself.
    TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
    std::vector<llvm_ir::IrArray> output_arrays;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
      TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
                          assignment_.GetUniqueSlice(target_op, {i}));
      const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
      llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape);
      output_arrays.push_back(
          llvm_ir::IrArray(op_target_address, element_shape));
    }
    TF_RETURN_IF_ERROR(
        llvm_ir::LoopEmitter(element_generator, output_arrays, &b_)
            .EmitLoop(IrName(target_op)));

    std::vector<llvm::Value*> tuple_operand_ptrs;
    for (int64 i = 0; i < output_arrays.size(); ++i) {
      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
    }
    llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_);

  } else {
    if (ShouldEmitParallelLoopFor(*target_op)) {
      // Emit code to read dynamic loop bounds from compute function argument.
      std::vector<std::pair<llvm::Value*, llvm::Value*>> dynamic_loop_bounds =
          compute_function_->GetDynamicLoopBounds();
      // Emit parallel loop with dynamic loop bounds for most-major dimensions.
      TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
                                             &dynamic_loop_bounds, &b_)
                             .EmitLoop(IrName(target_op)));
    } else {
      TF_RETURN_IF_ERROR(
          llvm_ir::LoopEmitter(element_generator, target_array, &b_)
              .EmitLoop(IrName(target_op)));
    }
  }
  return Status::OK();
}

Status IrEmitter::EmitMemcpy(const HloInstruction& source,
                             const HloInstruction& destination) {
  llvm::Value* source_value = GetEmittedValueFor(&source);
  llvm::Value* destination_value = GetEmittedValueFor(&destination);
  int64 source_size = ByteSizeOf(source.shape());
  // TODO(b/63762267): Be more aggressive about specifying alignment.
  MemCpy(destination_value, /*DstAlign=*/llvm::Align(1), source_value,
         /*SrcAlign=*/llvm::Align(1), source_size);
  return Status::OK();
}

Status IrEmitter::ElementTypesSameAndSupported(
    const HloInstruction& instruction,
    absl::Span<const HloInstruction* const> operands,
    absl::Span<const PrimitiveType> supported_types) {
  for (auto operand : operands) {
    TF_RET_CHECK(
        ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
  }

  TF_RET_CHECK(!operands.empty());
  PrimitiveType primitive_type = operands[0]->shape().element_type();
  if (!absl::c_linear_search(supported_types, primitive_type)) {
    return Unimplemented("unsupported operand type %s in op %s",
                         PrimitiveType_Name(primitive_type),
                         HloOpcodeString(instruction.opcode()));
  }
  return Status::OK();
}

Status IrEmitter::DefaultAction(HloInstruction* hlo) {
  ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
      return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
    };
  }
  CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
  return EmitTargetElementLoop(
      hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}

llvm::Value* IrEmitter::EmitScalarReturningThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  std::vector<llvm::Value*> return_value =
      EmitThreadLocalCall(callee, parameters, name);
  CHECK_EQ(return_value.size(), 1);
  return return_value[0];
}

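// Calls a nested computation whose buffers are thread-local. The callee must
// return a scalar or a tuple of scalars; parameters are spilled to
// function-entry allocas, results are read back from allocas, and the buffer
// table argument is null since thread-local callees do not use it.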
std::vector<llvm::Value*> IrEmitter::EmitThreadLocalCall(
    const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
    absl::string_view name) {
  CHECK(absl::c_binary_search(thread_local_computations_, &callee));
  const Shape& return_shape = callee.root_instruction()->shape();
  bool is_scalar_return = ShapeUtil::IsScalar(return_shape);
  bool is_tuple_of_scalars_return =
      return_shape.IsTuple() &&
      absl::c_all_of(return_shape.tuple_shapes(), [&](const Shape& shape) {
        return ShapeUtil::IsScalar(shape);
      });
  CHECK(is_scalar_return || is_tuple_of_scalars_return);

  std::vector<llvm::Value*> parameter_addrs;
  for (llvm::Value* parameter : parameters) {
    CHECK(!parameter->getType()->isPointerTy());
    llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
        parameter->getType(), "arg_addr", &b_);
    Store(parameter, parameter_addr);
    parameter_addrs.push_back(parameter_addr);
  }

  llvm::Type* return_value_buffer_type =
      llvm_ir::ShapeToIrType(return_shape, module_);
  std::string retval_alloca_name = absl::StrCat(name, "_return_value_addr");
  int retval_alignment =
      is_scalar_return
          ? MinimumAlignmentForPrimitiveType(return_shape.element_type())
          : 0;
  llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
      return_value_buffer_type, retval_alloca_name, &b_, retval_alignment);

  std::vector<llvm::Value*> allocas_for_returned_scalars;
  if (is_scalar_return) {
    allocas_for_returned_scalars.push_back(return_value_buffer);
  } else {
    constexpr int max_tuple_size = 1000;
    CHECK_LT(return_shape.tuple_shapes_size(), max_tuple_size)
        << "Multivalue function cannot return more than 1000 elements to avoid"
        << " stack smashing";
    allocas_for_returned_scalars =
        llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
    llvm_ir::IrArray tuple_array(return_value_buffer, return_shape);

    EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
  }

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           parameter_addrs, &b_, name,
           /*return_value_buffer=*/return_value_buffer,
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
           /*profile_counters_arg=*/GetProfileCountersArgument()));

  std::vector<llvm::Value*> returned_scalars;
  returned_scalars.reserve(allocas_for_returned_scalars.size());
  for (llvm::Value* addr : allocas_for_returned_scalars) {
    returned_scalars.push_back(Load(addr));
  }
  return returned_scalars;
}

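// Calls a nested computation that reads and writes through the global buffer
// table; the return-value pointer is null because the callee writes its
// result directly into its assigned buffers.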
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
                               absl::string_view name) {
  CHECK(absl::c_binary_search(global_computations_, &callee));

  Call(FindOrDie(emitted_functions_, &callee),
       GetArrayFunctionCallArguments(
           /*parameter_addresses=*/{}, &b_, name,
           /*return_value_buffer=*/
           llvm::Constant::getNullValue(b_.getInt8PtrTy()),
           /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
           /*buffer_table_arg=*/GetBufferTableArgument(),
           /*profile_counters_arg=*/GetProfileCountersArgument()));
}

llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
    const HloComputation& callee) {
  const HloInstruction* root_inst = callee.root_instruction();
  if (root_inst->opcode() == HloOpcode::kOutfeed) {
    return llvm::Constant::getNullValue(b_.getInt8PtrTy());
  }

  const BufferAllocation::Slice root_buffer =
      assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
  return EmitBufferPointer(root_buffer, root_inst->shape());
}

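// Binds a generator to each fusion parameter so that reads of parameter i
// inside the fused computation become element reads from the i-th fusion
// operand's IrArray.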
void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
                                    FusedIrEmitter* fused_emitter) {
  for (int i = 0; i < fusion->operand_count(); i++) {
    const HloInstruction* operand = fusion->operand(i);
    fused_emitter->BindGenerator(
        fusion->fused_parameter(i),
        [this, operand](llvm_ir::IrArray::Index index) {
          return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_);
        });
  }
}

}  // namespace cpu
}  // namespace xla